PyTorch implementation of DreamerV2 model-based RL algorithm

Overview

PyDreamer

Reimplementation of DreamerV2 model-based RL algorithm in PyTorch.

The official DreamerV2 implementation (TensorFlow) can be found at https://github.com/danijar/dreamerv2.

Features

...

Running the code

Running locally

Install dependencies

pip3 install -r requirements.txt

Get Atari ROMs

pip3 install atari-py==0.2.9
wget -L -nv http://www.atarimania.com/roms/Roms.rar
apt-get install unrar                                   # brew install unar (Mac)
unrar x Roms.rar                                        # unar -D Roms.rar  (Mac)
unzip ROMS.zip
python3 -m atari_py.import_roms ROMS
rm -rf Roms.rar *ROMS.zip ROMS
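
To double-check that the ROMs were imported correctly before starting a run, a quick sanity check with atari-py can help (a minimal sketch, not part of the repository):

import atari_py

# List the games atari-py can see after import_roms; 'pong' should be present.
games = atari_py.list_games()
print(f"{len(games)} ROMs available")
assert "pong" in games, "Pong ROM missing - re-run atari_py.import_roms"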

Run training (debug CPU mode)

python pydreamer/train.py --configs defaults atari debug --env_id Atari-Pong

Run training (full GPU mode)

python pydreamer/train.py --configs defaults atari atari_pong --run_name atari_pong_1
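
The --configs argument selects named sections of the YAML configuration, which are merged in order so that later sections override earlier ones. Below is a minimal sketch of that merging idea; the config.yaml filename and the loader function are illustrative, not the repository's actual code.

import yaml

def load_config(path, sections):
    # Merge the requested sections in order; later sections override earlier ones.
    with open(path) as f:
        all_sections = yaml.safe_load(f)
    conf = {}
    for name in sections:
        conf.update(all_sections[name])
    return conf

conf = load_config("config.yaml", ["defaults", "atari", "atari_pong"])
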
Comments
  • Dataset in test environment

    Hi! Thanks for your work on this PyTorch implementation of DreamerV2. It helps me a lot.

    I am just curious about the reason for building two separate datasets, one for the train environment and one for the test environment. It does not seem crucial for training, so why do we need a dataset for the test env? Please let me know if this is obviously necessary and I am missing something.

    opened by gimme1dollar 2
  • Minigrid environments not working

    Thanks for this PyTorch implementation of DreamerV2! I'm trying to get the code working with the MiniGrid environments. However, I'm encountering the following error:

    Traceback (most recent call last):
      File "/home/steph/projects/pydreamer/train.py", line 612, in <module>
        run(conf)
      File "/home/steph/projects/pydreamer/train.py", line 227, in run
        model.training_step(obs,
      File "/home/steph/projects/pydreamer/pydreamer/models/dreamer.py", line 122, in training_step
        loss_probe, metrics_probe, tensors_probe = self.probe_model.training_step(features.detach(), obs)
      File "/home/steph/projects/pydreamer/pydreamer/models/probes.py", line 38, in training_step
        map_coord = insert_dim(obs['map_coord'], 2, I)
    KeyError: 'map_coord'

    Here's the command I ran: xvfb-run -a -s "-screen 0 1400x900x24" python train.py --config defaults minigrid debug --env_id MiniGrid-Empty-8x8-v0 --device cuda
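
    One possible stopgap until the probe model handles environments without map observations (a hedged sketch, not an official fix) is to wrap the environment so the observation dict always carries a placeholder map_coord entry; the key's shape and dtype below are assumptions, not necessarily what pydreamer expects:

    import gym
    import numpy as np

    class DummyMapCoordWrapper(gym.ObservationWrapper):
        # Adds a zero-filled 'map_coord' entry so the probe's dict lookup does not fail.
        def observation(self, obs):
            if isinstance(obs, dict) and "map_coord" not in obs:
                obs["map_coord"] = np.zeros(4, dtype=np.float32)
            return obs

    # Usage: env = DummyMapCoordWrapper(env)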

    opened by stephmilani 2
  • Docker image build fail

    Greetings.

    Thank you very much for your reproduction of the original Dreamer-v2 code in PyTorch, as well as for open-sourcing it on GitHub. I have been looking through various Torch ports of the Dreamer-v2 algorithm for some experiments in my research, and your implementation definitely caught my attention, as it is very complete.

    While I managed to get the PyDreamer agent running locally using the provided instructions, I encountered a problem during the build of the Docker image that I thought you might want to know about.

    Namely, after executing docker build . -f Dockerfile -t pydreamer, it returns the following error:

    
    (pydreamer) d055@akira:~/random/rl/pydreamer$ docker build . -f Dockerfile -t pydreamer
    Sending build context to Docker daemon  330.2kB
    Step 1/32 : ARG ENV=standard
    Step 2/32 : FROM pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel AS base
     ---> c7e20104018e
    Step 3/32 : RUN apt-get update && apt-get install -y     git xvfb python3.7-dev python3-setuptools     libglu1-mesa libglu1-mesa-dev libgl1-mesa-dev libosmesa6-dev mesa-utils freeglut3 freeglut3-dev     libglew2.0 libglfw3 libglfw3-dev zlib1g zlib1g-dev libsdl2-dev libjpeg-dev lua5.1 liblua5.1-0-dev libffi-dev     build-essential cmake g++-4.8 pkg-config software-properties-common gettext     ffmpeg patchelf swig unrar unzip zip curl wget tmux     && rm -rf /var/lib/apt/lists/*
     ---> Using cache
     ---> 2f7651b27698
    Step 4/32 : FROM base AS standard-env
     ---> 2f7651b27698
    Step 5/32 : RUN pip3 install atari-py==0.2.9
     ---> Using cache
     ---> f7bc1331fcbb
    Step 6/32 : RUN wget -L -nv http://www.atarimania.com/roms/Roms.rar &&     unrar x Roms.rar &&     unzip ROMS.zip &&     python3 -m atari_py.import_roms ROMS &&     rm -rf Roms.rar ROMS.zip ROMS
     ---> Using cache
     ---> 5cd00f9e095d
    Step 7/32 : RUN mkdir -p /root/.mujoco &&     cd /root/.mujoco &&     wget -nv https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz &&     tar -xf mujoco.tar.gz &&     rm mujoco.tar.gz
     ---> Using cache
     ---> b1decf2d8fd8
    Step 8/32 : RUN pip3 install dm_control
     ---> Using cache
     ---> ab0179162e2a
    Step 9/32 : FROM base AS dmlab-env
     ---> 2f7651b27698
    Step 10/32 : RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" |     tee /etc/apt/sources.list.d/bazel.list &&     curl https://bazel.build/bazel-release.pub.gpg |     apt-key add - &&     apt-get update && apt-get install -y bazel
     ---> Using cache
     ---> 9851f7e81633
    Step 11/32 : RUN git clone https://github.com/deepmind/lab.git /dmlab
     ---> Using cache
     ---> e8a16a62e893
    Step 12/32 : WORKDIR /dmlab
     ---> Using cache
     ---> 56da3d82b379
    Step 13/32 : RUN git checkout "937d53eecf7b46fbfc56c62e8fc2257862b907f2"
     ---> Using cache
     ---> cf2e70fb4e1a
    Step 14/32 : RUN ln -s '/opt/conda/lib/python3.7/site-packages/numpy/core/include/numpy' /usr/include/numpy &&     sed -i '[email protected]@python3.7@g' python.BUILD &&     sed -i 's@glob(\[@glob(["include/numpy/\*\*/*.h", @g' python.BUILD &&     sed -i 's@: \[@: ["include/numpy", @g' python.BUILD &&     sed -i 's@650250979303a649e21f87b5ccd02672af1ea6954b911342ea491f351ceb7122@1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730@g' WORKSPACE &&     sed -i 's@rules_cc-master@rules_cc-main@g' WORKSPACE &&     sed -i 's@rules_cc/archive/master@rules_cc/archive/main@g' WORKSPACE &&     bazel build -c opt python/pip_package:build_pip_package --incompatible_remove_legacy_whole_archive=0
     ---> Running in 189631cf8219
    Extracting Bazel installation...
    Starting local Bazel server and connecting to it...
    Loading: 
    Loading: 0 packages loaded
    Analyzing: target //python/pip_package:build_pip_package (1 packages loaded, 0 targets configured)
    Analyzing: target //python/pip_package:build_pip_package (7 packages loaded, 15 targets configured)
    INFO: SHA256 (https://github.com/bazelbuild/rules_cc/archive/main.zip) = 3839996049629e6377abdfd04681ddeeb0cc3db13b9d2ff81bf46700cb4529f7
    DEBUG: Rule 'rules_cc' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = "3839996049629e6377abdfd04681ddeeb0cc3db13b9d2ff81bf46700cb4529f7"
    DEBUG: Repository rules_cc instantiated at:
      /dmlab/WORKSPACE:11:13: in <toplevel>
    Repository rule http_archive defined at:
      /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
    Analyzing: target //python/pip_package:build_pip_package (22 packages loaded, 213 targets configured)
    INFO: SHA256 (https://github.com/abseil/abseil-cpp/archive/master.zip) = 6d33798883560650cb9484a915e5085d251b61c14d8937ad714448577786c0fa
    DEBUG: Rule 'com_google_absl' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = "6d33798883560650cb9484a915e5085d251b61c14d8937ad714448577786c0fa"
    DEBUG: Repository com_google_absl instantiated at:
      /dmlab/WORKSPACE:17:13: in <toplevel>
    Repository rule http_archive defined at:
      /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
    INFO: Repository jpeg_archive instantiated at:
      /dmlab/WORKSPACE:45:13: in <toplevel>
    Repository rule http_archive defined at:
      /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
    WARNING: Download from http://www.ijg.org/files/jpegsrc.v9c.tar.gz failed: class com.google.devtools.build.lib.bazel.repository.downloader.UnrecoverableHttpException Checksum was 682aee469c3ca857c4c38c37a6edadbfca4b04d42e56613b11590ec6aa4a278d but wanted 1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730
    ERROR: An error occurred during the fetch of repository 'jpeg_archive':
       Traceback (most recent call last):
            File "/root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl", line 111, column 45, in _http_archive_impl
                    download_info = ctx.download_and_extract(
    Error in download_and_extract: java.io.IOException: Error downloading [http://www.ijg.org/files/jpegsrc.v9c.tar.gz] to /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/jpeg_archive/temp13980778680031230739/jpegsrc.v9c.tar.gz: Checksum was 682aee469c3ca857c4c38c37a6edadbfca4b04d42e56613b11590ec6aa4a278d but wanted 1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730
    ERROR: Error fetching repository: Traceback (most recent call last):
            File "/root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl", line 111, column 45, in _http_archive_impl
                    download_info = ctx.download_and_extract(
    Error in download_and_extract: java.io.IOException: Error downloading [http://www.ijg.org/files/jpegsrc.v9c.tar.gz] to /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/jpeg_archive/temp13980778680031230739/jpegsrc.v9c.tar.gz: Checksum was 682aee469c3ca857c4c38c37a6edadbfca4b04d42e56613b11590ec6aa4a278d but wanted 1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730
    Analyzing: target //python/pip_package:build_pip_package (31 packages loaded, 2049 targets configured)
    INFO: Repository glib_archive instantiated at:
      /dmlab/WORKSPACE:34:13: in <toplevel>
    Repository rule http_archive defined at:
      /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
    INFO: Repository png_archive instantiated at:
      /dmlab/WORKSPACE:64:13: in <toplevel>
    Repository rule http_archive defined at:
      /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
    ERROR: /dmlab/q3map2/BUILD:54:10: //q3map2:q3map2 depends on @jpeg_archive//:jpeg in repository @jpeg_archive which failed to fetch. no such package '@jpeg_archive//': java.io.IOException: Error downloading [http://www.ijg.org/files/jpegsrc.v9c.tar.gz] to /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/jpeg_archive/temp13980778680031230739/jpegsrc.v9c.tar.gz: Checksum was 682aee469c3ca857c4c38c37a6edadbfca4b04d42e56613b11590ec6aa4a278d but wanted 1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730
    ERROR: Analysis of target '//python/pip_package:build_pip_package' failed; build aborted: Analysis failed
    INFO: Elapsed time: 7.022s
    INFO: 0 processes.
    FAILED: Build did NOT complete successfully (31 packages loaded, 2049 targets configured)
    FAILED: Build did NOT complete successfully (31 packages loaded, 2049 targets configured)
    The command '/bin/sh -c ln -s '/opt/conda/lib/python3.7/site-packages/numpy/core/include/numpy' /usr/include/numpy &&     sed -i '[email protected]@python3.7@g' python.BUILD &&     sed -i 's@glob(\[@glob(["include/numpy/\*\*/*.h", @g' python.BUILD &&     sed -i 's@: \[@: ["include/numpy", @g' python.BUILD &&     sed -i 's@650250979303a649e21f87b5ccd02672af1ea6954b911342ea491f351ceb7122@1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730@g' WORKSPACE &&     sed -i 's@rules_cc-master@rules_cc-main@g' WORKSPACE &&     sed -i 's@rules_cc/archive/master@rules_cc/archive/main@g' WORKSPACE &&     bazel build -c opt python/pip_package:build_pip_package --incompatible_remove_legacy_whole_archive=0' returned a non-zero code: 1
    

    I am not very familiar with bazel, but it seems that one of the dependencies, namely jpeg_archive, which bazel is in charge of downloading, cannot be fetched: the checksum of the downloaded archive does not match the expected one.

    Did you happen to encounter this problem during your experiments?

    Looking forward to hearing back from you. Best regards.

    opened by dosssman 2
  • about the generators

    Hi!

    I don't fully understand how the code manages the available hardware resources, and I could use some advice on how to accelerate training. For example, in an environment with multiple GPUs and multiple CPUs, what changes should I make to be sure I am using these resources?

    Thank you very much!

    opened by roger-creus 1
  • Consistently getting generator process error when training model

    Hi, I'm trying to train a model on one of the standard OpenAI Gym environments. However, even when I use the "debug" settings, I consistently encounter a generator process error (screenshot attached in the original issue).

    Any tips for how to resolve this?

    Edit: It was an issue on my end :)

    opened by stephmilani 0
  • Batchnorm: Expected more than 1 value per channel

    When the generator is trying to run a NetworkPolicy, it has a batch size of 1, which does not work well with BatchNorm layers (see the error in the title). The error arises because the running mean and variance of the BatchNorm layers are still updated even under torch.no_grad, and the batch-size-1 check is presumably there to signal that this is not the intended behavior.

    Do you think it is okay to run the NetworkPolicy in eval mode?
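
    For reference, switching the module to eval mode does freeze the BatchNorm running statistics, and a batch of one then passes through without the error. A minimal PyTorch sketch of that pattern follows; the policy module here is a stand-in, not pydreamer's NetworkPolicy:

    import torch
    import torch.nn as nn

    # Stand-in policy network with a BatchNorm layer (illustrative only).
    policy = nn.Sequential(nn.Linear(8, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 4))

    obs = torch.randn(1, 8)   # batch size 1, as produced by the generator
    policy.eval()             # BatchNorm uses stored running stats and does not update them
    with torch.no_grad():     # no gradients needed while acting
        logits = policy(obs)  # works even with a single sample
    policy.train()            # switch back before gradient updates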

    opened by truncs 1
  • Trouble with image_categorical, map_categorical for MiniGrid envs

    Hi! I'm trying to use your codebase for various MiniGrid environments. Do you have any tips for setting the values for image_channels and map_channels?

    If I just use the default config, I get an out-of-bounds error in the img_to_onehot function.
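
    In case it helps diagnose the out-of-bounds error: with categorical image observations, the channel count has to exceed the largest category value, since one-hot encoding assigns one channel per value. A small sketch, assuming img_to_onehot behaves like torch.nn.functional.one_hot (the numbers are illustrative, not the repository's defaults):

    import torch
    import torch.nn.functional as F

    img = torch.randint(0, 11, (7, 7))  # e.g. MiniGrid object ids go up to 10
    image_channels = 33                 # must satisfy image_channels > img.max()

    # Raises an out-of-bounds error if image_channels <= img.max().
    onehot = F.one_hot(img.long(), num_classes=image_channels)
    print(onehot.shape)                 # (7, 7, image_channels)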

    opened by stephmilani 2