DrQ-v2: Improved Data-Augmented Reinforcement Learning

Related tags

Deep Learning, drqv2
Overview

DrQ-v2: Improved Data-Augmented RL Agent

Method

DrQ-v2 is a model-free off-policy algorithm for image-based continuous control. DrQ-v2 builds on DrQ, an actor-critic approach that uses data augmentation to learn directly from pixels. We introduce several improvements including:

  • Switch the base RL learner from SAC to DDPG.
  • Incorporate n-step returns to estimate TD error.
  • Introduce a decaying schedule for exploration noise.
  • Make the implementation 3.5 times faster.
  • Find better hyper-parameters.

These changes allow us to significantly improve sample efficiency and wall-clock training time on a set of challenging tasks from the DeepMind Control Suite compared to prior methods. Furthermore, DrQ-v2 is able to solve complex humanoid locomotion tasks directly from pixel observations, previously unattained by model-free RL.
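
The decaying exploration noise is configured as a schedule string such as linear(1.0, 0.1, 500000), evaluated at the current environment step. A minimal sketch of how such a string can be parsed and evaluated (the repo's own schedule helper may differ in details, and the numbers here are illustrative):

import re

def linear_schedule(schedule: str, step: int) -> float:
    # parse 'linear(init, final, duration)': decay linearly from init to final
    # over duration steps, then stay at final
    match = re.match(r'linear\((.+),(.+),(.+)\)', schedule)
    init, final, duration = (float(g) for g in match.groups())
    mix = min(max(step / duration, 0.0), 1.0)
    return (1.0 - mix) * init + mix * final

# e.g. linear_schedule('linear(1.0, 0.1, 500000)', 250000) -> 0.55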

Citation

If you use this repo in your research, please consider citing the paper as follows:

@article{yarats2021drqv2,
  title={Mastering Visual Continuous Control: Improved Data-Augmented Reinforcement Learning},
  author={Denis Yarats and Rob Fergus and Alessandro Lazaric and Lerrel Pinto},
  journal={arXiv preprint arXiv:},
  year={2021}
}

Instructions

Install dependencies:

conda env create -f conda_env.yml
conda activate drqv2

Train the agent:

python train.py task=quadruped_walk

Monitor results:

tensorboard --logdir exp_local

License

The majority of DrQ-v2 is licensed under the MIT license; however, portions of the project are available under separate license terms: code from DeepMind is licensed under the Apache 2.0 license.

Comments
  • Replayloader doesn't work for Atari

    Have you tried using this replay loader with Atari? I keep getting this error unless I set the num replay workers to 1:

    File "/u/slerman/miniconda3/envs/agi/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
        data = fetcher.fetch(index)
      File "/u/slerman/miniconda3/envs/agi/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 28, in fetch
        data.append(next(self.dataset_iter))
      File "/home/cxu-serve/u1/slerman/drqv2/replay_buffer.py", line 176, in __iter__
        yield self._sample()
      File "/home/cxu-serve/u1/slerman/drqv2/replay_buffer.py", line 159, in _sample
        episode = self._sample_episode()
      File "/home/cxu-serve/u1/slerman/drqv2/replay_buffer.py", line 99, in _sample_episode
        eps_fn = random.choice(self._episode_fns)
      File "/u/slerman/miniconda3/envs/agi/lib/python3.8/random.py", line 290, in choice
        raise IndexError('Cannot choose from an empty sequence') from None
    IndexError: Cannot choose from an empty sequence
    
    
    Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
    

    Edit: Sorry, originally posted the wrong trace.
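
    One possible workaround (just a sketch of an idea, not a maintainer fix): make each dataloader worker wait until it has fetched at least one episode file before sampling. This assumes the _try_fetch helper and _episode_fns list from replay_buffer.py and would replace the method shown in the traceback:

    import random
    import time

    def _sample_episode(self):
        # hypothetical guard inside ReplayBuffer: with num_workers > 1, a worker can
        # reach _sample() before any episode file is visible in its local state
        while len(self._episode_fns) == 0:
            self._try_fetch()
            time.sleep(0.1)
        eps_fn = random.choice(self._episode_fns)
        return self._episodes[eps_fn]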

    opened by slerman12 9
  • Get to 96 frame throughput?

    Hi! First of all, thank you very much for the nice paper and open sourcing this great codebase, I think it's really nice and clean.

    I have been experimenting with the codebase recently. However, I found that on a V100 GPU the maximum FPS I can get is around 60, not the 96 reported in the paper. I see this on multiple environments, and on walker, where a larger batch size is used, it is a bit slower still. So I'm just wondering whether there is any detail I should pay attention to in order to reach 96 FPS. For example, do we need to set the number of replay workers to a larger value, or disable TensorBoard?

    I understand that it's always possible that there is something wrong with my own environment or my hardware, but I just want to check with you to see if I missed anything important.

    Thank you so much!

    opened by watchernyu 8
  • reproducing results for humanoid run

    Thanks for open sourcing this! I wanted to replicate the results on the humanoid stand/run tasks. Using the default parameters I'm able to get a similar result on the humanoid stand task, i.e. ~400 reward after 15 million steps. However, for the humanoid run task no learning seems to happen; see the attached figure, which is averaged across 5 runs. The only change I've made to the parameters is setting stddev_schedule: 'linear(1.0,0.1,2000000)'. Are there any other modifications necessary for the humanoid run task? Thanks!!

    [attached figure: perf (learning curve averaged across 5 runs)]
    opened by snailrowen1337 7
  • Not meant as an issue, but a little perplexed by feature_dim...

    I noticed feature_dim is set to 50, which is quite a bottleneck between the encoder output dim of 32 * 35 * 35 and the downstream hidden_dim of 1024. Very interesting. Do you think the bottleneck helps create some kind of better compression for learning?
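
    For context, the bottleneck sits in the actor/critic trunk, which projects the flattened encoder output down to feature_dim before the wide MLP heads. A rough sketch of the shapes involved (as I read the code; details may differ):

    import torch
    import torch.nn as nn

    repr_dim = 32 * 35 * 35   # flattened encoder output for 84x84 frame stacks
    feature_dim = 50          # the bottleneck in question
    hidden_dim = 1024         # width of the downstream MLP heads

    # low-dimensional projection shared by the actor and critic heads
    trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
                          nn.LayerNorm(feature_dim),
                          nn.Tanh())

    h = torch.randn(8, repr_dim)   # dummy batch of encoder features
    z = trunk(h)                   # -> (8, 50), then fed into hidden_dim-wide MLPs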

    opened by slerman12 5
  • Make ExtendedTimeStep __getitem__ more robust

    I ran into the following error when running drqv2 training. I am not sure if it's caused by a Python version difference (I'm on Python 3.7.11).

    Traceback (most recent call last):
      File "train.py", line 229, in main
        workspace.train()
      File "/home/desaixiecvc/github/drqv2original/train.py", line 136, in train
        self.replay_storage.add(time_step)
      File "/home/desaixiecvc/github/drqv2original/replay_buffer.py", line 50, in add
        value = time_step[spec.name]
      File "/home/desaixiecvc/github/drqv2original/dmc.py", line 33, in __getitem__
        return getattr(self, attr)
      File "/home/desaixiecvc/github/drqv2original/dmc.py", line 33, in __getitem__
        return getattr(self, attr)
    TypeError: getattr(): attribute name must be string

    __getitem__ first gets spec.name (a string) as attr; getattr then resolves that named field through its positional index in the tuple, which calls __getitem__ again with the index (an int) as attr, causing the error. I added a type check on attr to solve this.
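
    A sketch of the kind of type check described, so that string keys resolve field names while integer indices fall back to plain tuple indexing (field names as in dmc.py; the exact PR code may differ):

    from typing import Any, NamedTuple

    class ExtendedTimeStep(NamedTuple):
        step_type: Any
        reward: Any
        discount: Any
        observation: Any
        action: Any

        def __getitem__(self, attr):
            # strings come from spec.name lookups; ints come from namedtuple's own
            # positional field access on some Python versions (e.g. 3.7)
            if isinstance(attr, str):
                return getattr(self, attr)
            return tuple.__getitem__(self, attr)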

    CLA Signed 
    opened by desaixie 3
  • Question regarding indices for replay buffer

    I am a little confused about how the indices work for the replay buffer. Specifically, a "dummy" transition is repeatedly referenced in replay_buffer.py, and there are some +/- 1 adjustments made to the indices:

    idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
    obs = episode['observation'][idx - 1]
    action = episode['action'][idx]
    next_obs = episode['observation'][idx + self._nstep - 1]
    reward = np.zeros_like(episode['reward'][idx])
    

    From reading the code, it seems like the storage layout is

    ----------------------------------------------------------
    rewards           |   None   |  reward_0  |  reward_1  |   ........
    ----------------------------------------------------------
    observations      |   obs_0  |   obs_1    |   obs_2    |   ........
    ----------------------------------------------------------
    actions           |    None  |  action_0  |  action_1  |   .......
    ----------------------------------------------------------
    

    Is that the correct interpretation of the memory layout? And if so, why is the offset used?

    Thanks for open-sourcing this!!
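
    One consistent reading is that index 0 stores the reset step, whose action and reward are placeholders, so observations are offset by one relative to actions and rewards. Under that reading, reward starts at zeros because the n-step return is accumulated afterwards, roughly as in this continuation of the snippet above (paraphrased; details may differ):

    import numpy as np

    # accumulate the n-step return and the corresponding discount
    discount = np.ones_like(episode['discount'][idx])
    for i in range(self._nstep):
        step_reward = episode['reward'][idx + i]
        reward += discount * step_reward
        discount *= episode['discount'][idx + i] * self._discount
    # (obs, action, reward, discount, next_obs) then forms one n-step transition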

    opened by snailrowen1337 3
  • Can't run on school compute...

    Hi, I've been going back and forth with a system admin trying to get this to run, but we can't get past an OpenGL error. This is how we installed everything:

    module load mesa  
    module load glfw/3.3.2/b2   
    module load mujoco/200 
    
    git clone [email protected]:facebookresearch/drqv2.git
    cd drqv2 
    
    conda env create -f conda_env.yml
    conda activate drqv2
    
    python train.py task=quadruped_walk
    

    Then this returns the following error no matter what we do:

    OpenGL.raw.EGL._errors.EGLError: EGLError(
            err = EGL_BAD_PARAMETER,
            baseOperation = eglGetPlatformDisplayEXT,
            cArguments = (
                    12607,
                    <OpenGL._opaque.EGLDeviceEXT_pointer object at 0x2ab4078bd240>,
                    None,
            ),
            result = <OpenGL._opaque.EGLDisplay_pointer object at 0x2ab4078bd5c0>
    )
    

    Here is what the system admin sent me:

    hi Sam, unfortunately I've had no luck getting this to work despite spending quite a bit of time on it today. I get the same EGL error unless I'm on a visual node (bhx nodes..), but if I do everything (including the build/install) on a visual node I get a core dump with Illegal Instruction. It is possible that's due to one of the dependencies from pip/conda, and if it were compiled from source on a bhx node that might go away, but I have no idea which one...

    Any ideas?

    opened by slerman12 3
  • Parallelize ensemble of Q functions into a single model

    Parallelizing yields a 5-10% speedup on my machine with the standard batch size and network width, and should not affect any of the functionality. This change should also make it trivial to efficiently scale up to a larger number of Q-functions.
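
    Not the PR itself, but a generic sketch of the idea: keep each layer's weights for all Q-heads in one batched parameter, so the whole ensemble is evaluated with a single batched matmul (names, sizes, and init here are illustrative):

    import torch
    import torch.nn as nn

    class EnsembleLinear(nn.Module):
        """One linear layer applied independently by every ensemble member."""
        def __init__(self, num_q, in_dim, out_dim):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(num_q, in_dim, out_dim) * 0.01)
            self.bias = nn.Parameter(torch.zeros(num_q, 1, out_dim))

        def forward(self, x):
            # x: (num_q, batch, in_dim) -> (num_q, batch, out_dim) in one bmm
            return torch.bmm(x, self.weight) + self.bias

    num_q, batch, in_dim, hidden = 2, 256, 100, 1024
    critic = nn.Sequential(EnsembleLinear(num_q, in_dim, hidden), nn.ReLU(),
                           EnsembleLinear(num_q, hidden, 1))
    x = torch.randn(batch, in_dim).unsqueeze(0).expand(num_q, -1, -1)  # shared input
    qs = critic(x)  # (num_q, batch, 1): all Q-estimates computed together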

    CLA Signed 
    opened by Aladoro 3
  • Simple modular configuration refactoring

    Refactored the hydra configurations to be divided into agent-specific (agent_cfg) and environment-specific (env_cfg), to ease experimentation. In particular, the only argument that needs to be overridden for any of the experiments with DrQ-v2 is env_cfg.

    I have also added the corresponding configurations for all medium benchmark tasks, following the details from the paper.

    CLA Signed 
    opened by Aladoro 3
  • Evaluation video overwritten

    I have noticed, in this line when you are saving evaluation episodes, that you are actually overwriting previous evaluation videos of the same global frame when num_eval_episodes > 1.

                self.video_recorder.save(f'{self.global_frame}.mp4')
    

    The attribute self.global_frame doesn't change during the self.cfg.num_eval_episodes evaluations in the loop.

    I would propose this line instead

                self.video_recorder.save(f'{self.global_frame}_{episode}.mp4')
    

    or other things like that.

    opened by medric49 2
  • Truncation for exploration?

    I'm reading the paper and code, and can't follow the truncation process. Table 2 sets exploration stddev. clip equal to 0.3, so I assume that the exploration noise is clipped. However, the action seems to be selected by action = dist.sample(clip=None) which does no clipping. Instead clipping is seemingly applied during training with dist.sample(clip=self.stddev_clip). Am I misunderstanding something here? Thanks!!
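
    For reference, that matches my reading of the code you quote: clip bounds the sampled noise around the mean (used with stddev_clip when computing training targets), clip=None is used when acting, and the resulting action is clamped to the valid range either way. A generic sketch of that sampling logic (not the repo's exact TruncatedNormal class):

    import torch

    def sample_action(mean, stddev, clip=None, low=-1.0, high=1.0):
        # Gaussian exploration noise around the deterministic policy output
        eps = torch.randn_like(mean) * stddev
        if clip is not None:
            # bound the noise itself (e.g. to +/- 0.3), as done during training
            eps = eps.clamp(-clip, clip)
        # the action is always clamped to the valid action range
        return (mean + eps).clamp(low, high)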

    opened by AaronLiu1997 2
  • Why using 'F.grid_sample()' in class RandomShiftsAug?

    Actually, for the random shift operation, we could simply draw two random variables as the shift offsets, like this:

    # origin_image.shape == (512, 3, 84, 84)
    # pad_image.shape == (512, 3, 92, 92)
    shift_x = random.randint(0, 92 - 84)
    shift_y = random.randint(0, 92 - 84)
    aug_image = pad_image[:, :, shift_x:shift_x + 84, shift_y:shift_y + 84]
    

    Then we can get the augmented image.

    I guess interpolation is the reason you chose F.grid_sample()?
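
    For comparison, here is a sketch of the grid_sample route: the per-image shift is expressed as an offset of a normalized sampling grid, so the same mechanism stays fully batched on the GPU and supports sub-pixel shifts via bilinear interpolation (this is along the lines of the repo's RandomShiftsAug, but details may differ):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class RandomShiftSketch(nn.Module):
        def __init__(self, pad=4):
            super().__init__()
            self.pad = pad

        def forward(self, x):
            n, c, h, w = x.size()
            assert h == w
            # replicate-pad, then resample an h x w window at a random per-image offset
            x = F.pad(x, (self.pad,) * 4, mode='replicate')
            eps = 1.0 / (h + 2 * self.pad)
            arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad,
                                    device=x.device, dtype=x.dtype)[:h]
            arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
            base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
            base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
            # integer pixel shift per image, converted to normalized grid units
            shift = torch.randint(0, 2 * self.pad + 1, size=(n, 1, 1, 2),
                                  device=x.device, dtype=x.dtype)
            shift *= 2.0 / (h + 2 * self.pad)
            # bilinear resampling: a fractional shift would interpolate between pixels
            return F.grid_sample(x, base_grid + shift, padding_mode='zeros',
                                 align_corners=False)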

    opened by Guozheng-Ma 0
  • Does drqv2 support multi-processing for the env interaction?

    Hi there, I wonder if drqv2 supports multi-processing for the env interaction during training and evaluation, since it seems to take up most of the time.

    opened by HeegerGao 0
  • Error when Running Training Command

    Hello,

    When I ran the training command given in the readme, python train.py task=quadruped_walk, I got this error:

    File "/home/anavani/anaconda3/lib/python3.9/site-packages/hydra/_internal/defaults_list.py", line 168, in ensure_overrides_used raise ConfigCompositionException(msg) hydra.errors.ConfigCompositionException: Could not override 'task'. Did you mean to override task@_global_? To append to your default list use +task=quadruped_walk

    I changed the command to python train.py +task=quadruped_walk and this seemed to fix the issue. However, after I let it train for a bit, I got this error

    It seems as if +task=quadruped_walk is causing an EOF error, but I'm not sure what is causing the second error. I would really appreciate any help. @denisyarats @Aladoro @desaixie @medric49

    opened by oofmeister27 6
  • Dreamerv2 learning curve

    Hi, @denisyarats, in the paper you said that you ran DreamerV2 to get learning curves on 12 DeepMind Control Suite tasks. Could you kindly share the DreamerV2 learning curves?

    BR,

    opened by zivzone 0
  •  Hyperparameters optimization

    🚀 Feature Request

    Hyperparameter optimization for the fish environment of dm_control.

    Motivation

    I tried the easy, medium, and hard hyper-parameter sets for the upright and swim tasks of the fish environment, but none of them seemed to work for the swim task.

    Could you share the procedure (or script) you used to find the hyper-parameters for the other envs? Otherwise, I can open a pull request if I succeed for the fish env.

    opened by ss555 0
  • A question about num_train_frames

    Hi,

    Great work, thanks for making it open-source!

    Could you tell me why the configs use more num_train_frames than reported in the paper? In the paper, the number of frames for easy/medium/hard tasks is 1/3/30 * 10^6; however, in the config files (easy.yaml, medium.yaml, hard.yaml) the numbers are 1.1/3.1/30.1 * 10^6.

    opened by AwesomeLemon 0
Owner
Facebook Research