Optimal Model Design for Reinforcement Learning
This repository contains JAX code for the paper
Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation
by Evgenii Nikishin, Romina Abachi, Rishabh Agarwal, and Pierre-Luc Bacon.
Summary
Model-based reinforcement learning typically trains the dynamics and reward functions by minimizing the error of their predictions. This error is only a proxy for maximizing the sum of rewards, the ultimate goal of the agent, which leads to the objective mismatch. We propose an end-to-end algorithm called Optimal Model Design (OMD) that optimizes the returns directly for model learning. OMD leverages the implicit function theorem to optimize the model parameters directly with respect to the returns (see the paper for the full computational graph).
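To illustrate the idea, below is a minimal, self-contained JAX sketch of differentiating an outer objective through the solution of an inner problem with the implicit function theorem. The loss functions, names, and toy quadratic forms are illustrative assumptions for this sketch and are not the actual OMD losses implemented in this repository.

```python
# Hypothetical sketch: implicit differentiation of an outer objective J(theta*(phi))
# where theta*(phi) minimizes an inner loss that depends on "model" parameters phi.
import jax
import jax.numpy as jnp

def inner_loss(theta, phi):
    # Toy inner loss whose minimizer theta*(phi) depends on the model params phi.
    return jnp.sum((theta - jnp.tanh(phi)) ** 2) + 0.1 * jnp.sum(theta ** 2)

def outer_objective(theta):
    # Toy stand-in for the return that the model should ultimately improve.
    return -jnp.sum((theta - 1.0) ** 2)

def solve_inner(phi, theta0, steps=100, lr=0.1):
    # Approximate theta*(phi) with plain gradient descent on the inner loss.
    grad_in = jax.grad(inner_loss)
    def step(theta, _):
        return theta - lr * grad_in(theta, phi), None
    theta, _ = jax.lax.scan(step, theta0, None, length=steps)
    return theta

def outer_grad_ift(phi, theta0):
    # Implicit function theorem: d theta*/d phi = -H^{-1} C, where
    # H = d^2 L_in / d theta^2 and C = d^2 L_in / d theta d phi,
    # so dJ/dphi = -(dJ/dtheta) H^{-1} C.
    theta_star = solve_inner(phi, theta0)
    g_outer = jax.grad(outer_objective)(theta_star)
    hess = jax.hessian(inner_loss, argnums=0)(theta_star, phi)
    v = jnp.linalg.solve(hess, g_outer)  # H^{-1} dJ/dtheta
    cross = jax.jacfwd(jax.grad(inner_loss, argnums=0), argnums=1)(theta_star, phi)
    return -v @ cross  # gradient of the returns w.r.t. the model parameters

phi = jnp.array([0.5, -0.3])
theta0 = jnp.zeros(2)
print(outer_grad_ift(phi, theta0))
```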
Installation
We assume that you use Python 3. To install the necessary dependencies, run the following commands:
1. virtualenv ~/env_omd
2. source ~/env_omd/bin/activate
3. pip install -r requirements.txt
To use JAX with GPU, follow the official instructions. To install MuJoCo, check the instructions.
Run
For historical reasons, the code is divided into 3 parts.
Tabular
All results for the tabular experiments can be reproduced by running the tabular.ipynb notebook.
To open the notebook in Google Colab, use this link.
CartPole
To train the OMD agent on CartPole, use the following commands:
cd cartpole
python train.py --agent_type omd
We also provide implementations of the corresponding MLE and VEP baselines. To train these agents, change the --agent_type flag to mle or vep.
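If you want to train all three agents in sequence, a small convenience loop like the one below could be used; it is a hypothetical helper that relies only on the --agent_type flag documented above.

```python
# Hypothetical helper: train the OMD agent and both baselines back to back.
import subprocess

for agent_type in ["omd", "mle", "vep"]:
    subprocess.run(["python", "train.py", "--agent_type", agent_type], check=True)
```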
MuJoCo
To train the OMD agent on MuJoCo HalfCheetah-v2, use the following commands:
cd mujoco
python train.py --config.algo=omd
To train the MLE baseline, change the --config.algo flag to mle.
Acknowledgements
- Tabular experiments are based on the code from the library for fixed points in JAX
- Code for MuJoCo is based on the implementation of SAC in JAX
- Code for CartPole reuses parts of the SAC implementation in PyTorch
- For experimentation, we used a modification of the slurm runner