A PyTorch re-implementation of the model-based reinforcement learning algorithm MBPO

Overview

This is a re-implementation of the model-based RL algorithm MBPO in PyTorch, as described in the paper When to Trust Your Model: Model-Based Policy Optimization.

This code is based on a previous paper from the NeurIPS reproducibility challenge that reproduced the original results with a TensorFlow ensemble model but reported a significant drop in performance with a PyTorch ensemble model. This code re-implements the ensemble dynamics model in PyTorch and closes that gap.
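
To make the overview concrete, here is a minimal sketch of an ensemble Gaussian dynamics model in the spirit of MBPO. The class name, layer sizes, and activation below are illustrative assumptions, not this repo's actual code; the idea is that each ensemble member predicts a Gaussian over the state change and reward, and is trained by maximizing that Gaussian's log-likelihood on real transitions.

```python
import torch
import torch.nn as nn

class EnsembleDynamicsModel(nn.Module):
    """Illustrative sketch of an MBPO-style ensemble of Gaussian models."""

    def __init__(self, ensemble_size, state_dim, action_dim, hidden_dim=200):
        super().__init__()
        # Each member outputs the mean and log-variance of
        # (next_state - state, reward).
        out_dim = 2 * (state_dim + 1)
        self.members = nn.ModuleList(
            nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_dim), nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
                nn.Linear(hidden_dim, out_dim),
            )
            for _ in range(ensemble_size)
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        out = torch.stack([m(x) for m in self.members])  # (E, B, out_dim)
        mean, log_var = out.chunk(2, dim=-1)             # two (E, B, S+1) tensors
        return mean, log_var
```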

Reproduced results

The comparison is done on two tasks (Hopper-v2 and Walker2d-v2); other tasks are not tested. On these two tasks, the PyTorch implementation achieves performance similar to the official TensorFlow code.

[Figures: learning curves on Hopper-v2 and Walker2d-v2 comparing this PyTorch implementation with the official TensorFlow code]

Dependencies

MuJoCo 1.5 & MuJoCo 2.0
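
One of the issues below asks for a requirements.txt; none ships with the repo, so the following is only a guessed sketch of the package set implied by the code. Every version choice is an assumption: pin versions that match whichever MuJoCo release you install.

```
# requirements.txt (sketch; all pins are assumptions, adjust to your setup)
torch       # PyTorch ensemble model and SAC agent
gym         # Hopper-v2 / Walker2d-v2 environments
mujoco-py   # bindings for MuJoCo 1.5 / 2.0; pick the matching release
numpy       # use a version compatible with both torch and mujoco-py
```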

Usage

python main_mbpo.py --env_name 'Walker2d-v2' --num_epoch 300 --model_type 'pytorch'

python main_mbpo.py --env_name 'Hopper-v2' --num_epoch 300 --model_type 'pytorch'

Reference

Janner, M., Fu, J., Zhang, M., and Levine, S. When to Trust Your Model: Model-Based Policy Optimization. NeurIPS 2019.

Comments
  • Could you please add a requirements.txt file?

Hi, really appreciate your re-implementation of MBPO in PyTorch! However, there are several versions of TF and PyTorch, and the numpy versions they depend on differ from the one mujoco_py requires, which leads to a dependency conflict.

Could you add the requirements.txt for your environment so that I can reproduce the experiments? Thanks a lot!

    opened by Joy1112 1
  • Missing utils

Hello, thanks for your awesome PyTorch re-implementation! I'd like to try it, but I cannot find the utils module that main_mbpo.py imports. Could you help me out? Thanks!

    opened by TsuTikgiau 1
• test infinite loop problem & epoch length multiplied bug fix

1. The epoch length was multiplied twice -> fixed; related to #3. Verified by testing on Hopper & Walker; #4 solved.
2. Environments that do not return done within a given epoch could get stuck in an infinite loop -> fixed.
3. Minor argument fixes for easier interpretation.
    opened by songminjae 0
  • Very slow runtime caused by `torch.autograd.set_detect_anomaly(True)`

I found the line that causes the extreme slowdown in runtime (thousands of times slower): https://github.com/Xingyu-Lin/mbpo_pytorch/blob/fe3c78c474d188c16a026051b92f8a2e84fa9387/sac/sac.py#L11

Setting it to False restores normal running speed; just a heads-up.

    opened by mickelliu 0
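
For context: anomaly detection makes autograd record extra bookkeeping for every operation and check it on each backward pass, which is why leaving it enabled slows training so dramatically. A minimal sketch of the usual pattern (the tensors here are just placeholders):

```python
import torch

# Keep anomaly detection off for normal training runs; it is a debugging
# tool, not something to leave enabled globally.
torch.autograd.set_detect_anomaly(False)

# When hunting an in-place error, enable it only around the failing code:
with torch.autograd.detect_anomaly():
    x = torch.randn(4, requires_grad=True)
    loss = (x * 2).sum()
    loss.backward()
```
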
  • rollout_batch_size

rollout_batch_size defaults to 100k, which I don't understand. Does this mean that even if the real data pool holds only something like 5k transitions, you still sample each transition roughly 20 times and produce 100k model-generated transitions each time you call that function?

    opened by ChenDRAG 0
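
For readers puzzled by the same point: in MBPO-style branched rollouts, start states are sampled with replacement from the real-data buffer, so rollout_batch_size can legitimately exceed the buffer size. A hedged sketch of the idea; all function and buffer names below are hypothetical, not this repo's API:

```python
# Hypothetical sketch of an MBPO-style model rollout. rollout_batch_size
# start states are drawn *with replacement* from the real-data buffer
# (env_pool), so repeats are expected when the buffer is small.
def rollout_model(predict_env, agent, env_pool, model_pool,
                  rollout_batch_size=100_000, rollout_length=1):
    states, *_ = env_pool.sample(rollout_batch_size)  # with replacement
    for _ in range(rollout_length):
        actions = agent.select_action(states)
        next_states, rewards, dones = predict_env.step(states, actions)
        model_pool.push_batch(states, actions, rewards, next_states, dones)
        states = next_states[~dones]  # continue only non-terminal branches
        if len(states) == 0:
            break
```
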
  • Error when trying to run

Hey, thank you for your work, but sadly I'm not able to run your code: I'm getting the in-place operation error below. Weirdly, this seems to happen only to me; I just cloned the repo and ran your example command.

    File "mbpo.py", line 267, in
    main() File "mbpo.py", line 263, in main train(args, env_sampler, predict_env, agent, env_pool, model_pool) File "mbpo.py", line 124, in train train_policy_steps += train_policy_repeats(args, total_step, train_policy_steps, cur_step, env_pool, model_pool, agent) File "mbpo.py", line 220, in train_policy_repeats agent.update_parameters((batch_state, batch_action, batch_reward, batch_next_state, batch_done), args.policy_train_batch_size, i ) File "/shared/sebastian/replication-mbpo/sac/sac.py", line 89, in update_parameters policy_loss.backward() File "/shared/sebastian/miniconda3/envs/rrc_simulation/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph) File "/shared/sebastian/miniconda3/envs/rrc_simulation/lib/python3.6/site-packages/torch/autograd/init.py", line 132, in backw ard allow_unreachable=True) # allow_unreachable flag RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTenso r [256, 1]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows th e operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

    opened by BY571 0
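
For anyone hitting the same traceback: it means a tensor that autograd saved for the backward pass was mutated in place before backward() ran. A minimal standalone reproduction, unrelated to this repo's code:

```python
import torch

# sigmoid's backward pass reuses its *output*, so an in-place edit to that
# output bumps its version counter and invalidates the saved tensor.
a = torch.ones(3, requires_grad=True)
b = a.sigmoid()     # autograd saves b for use in sigmoid's backward
b.add_(1)           # in-place modification: b's version no longer matches
b.sum().backward()  # RuntimeError: ... modified by an inplace operation
```
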
  • cannot reproduce

hi, I ran the Hopper experiment with the provided command, and the reward during env steps 65k-68k is between 400 and 700, which is much lower than the provided figure shows. Is there anything I might have missed?

    opened by lichuminglcm 2
  • Epoch length?

    Hi,

    Thank you for your code. It is really helpful.

Could you please check line 115 in main_mbpo.py? Since start_step keeps growing, with the condition cur_step >= start_step + epoch_length the effective epoch length also keeps growing. So, is this a bug? Should we use

cur_step >= args.epoch_length

instead?

    Correct me if I am wrong.

    Thanks

    https://github.com/Xingyu-Lin/mbpo_pytorch/blob/43c8a55fa7353c6aed97525d0ecd5cb903b55377/main_mbpo.py#L115

```python
cur_step = total_step - start_step

if cur_step >= start_step + args.epoch_length and len(env_pool) > args.min_pool_size:
    break
```

    opened by xfdywy 2
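
A sketch of the fix the reporter is proposing; whether this matches the author's intended epoch semantics is an assumption:

```python
# Proposed fix (sketch): compare the in-epoch step count against the fixed
# epoch length rather than against the ever-growing start_step.
cur_step = total_step - start_step

if cur_step >= args.epoch_length and len(env_pool) > args.min_pool_size:
    break
```
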
Owner
Xingyu Lin
PhD student in the field of Reinforcement Learning, Vision and Robotics @ CMU