Reinforcement learning library in JAX.

Overview

Magi RL library in JAX

Installation | Agents | Examples | Contributing | Documentation


Magi is an RL library developed on top of Acme.

Note: Magi is in alpha development, so expect breaking changes!

Installation

  1. Create a new Python virtual environment
python3 -m venv venv
source venv/bin/activate
  2. Install dependencies and the package in editable mode by running
pip install -U pip setuptools wheel
pip install -r requirements.txt  # Uses pinned dependencies; adjust for your needs.
pip install -e .

If installation fails, first check the GitHub Actions badge to see whether the latest CI run also fails. If CI is green, the problem is likely with your local environment; refer to .github/workflows/ci.yaml as the authoritative reference for setting up the environment.

Agents

Magi includes implementations of popular RL algorithms such as SAC, DrQ, SAC-AE, and PETS. Refer to magi/agents for the full list of agents.
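Constructing and running an agent might look roughly like the sketch below. The factory and constructor names here (make_networks, SACAgent) are hypothetical and merely mirror Acme conventions; consult magi/agents for the actual API.

# Hypothetical usage sketch; actual class and argument names may differ.
import acme
from acme import specs

from magi.agents import sac  # agent modules live under magi/agents

# `make_environment` is a placeholder for any dm_env.Environment factory.
environment = make_environment()
spec = specs.make_environment_spec(environment)

networks = sac.make_networks(spec)            # assumed network factory
agent = sac.SACAgent(spec, networks, seed=0)  # assumed constructor

loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)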

Examples

Check out magi/examples, where we include examples of running our RL agents on popular benchmark tasks.

Testing

On Linux, you can run tests with

JAX_PLATFORM_NAME=cpu pytest -n `grep -c ^processor /proc/cpuinfo` magi
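The -n flag is provided by pytest-xdist, which the command above assumes is installed. With pytest-xdist, the worker count can also be chosen automatically:

JAX_PLATFORM_NAME=cpu pytest -n auto magi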

Contributing

Refer to CONTRIBUTING.md.

Acknowledgements

Magi is inspired by many of the open-source RL projects out there. Here is a (non-exhaustive) list of related libraries and packages that Magi references:

License

Apache License 2.0

Citation

If you use Magi in your work, please cite us according to the CITATION file. You can learn more about CITATION files here.

Comments
  • PETS initial solution is zero for symmetric bounds

    The attempt to normalize the initial solution here causes a multiplication by zero when the bounds are symmetric about zero: https://github.com/ethanluoyc/magi/blob/646976d04ce9f59498994f71961d1a8dd2e206e9/magi/agents/pets/acting.py#L69-L72
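    A minimal sketch of the failure mode (variable names are illustrative, not the actual code from acting.py):

    import jax.numpy as jnp

    # Bounds symmetric about zero, e.g. a [-1, 1] action space.
    lower, upper = jnp.array([-1.0]), jnp.array([1.0])
    midpoint = (lower + upper) / 2.0   # == 0 for symmetric bounds

    # Any initial solution scaled by the midpoint collapses to zero.
    initial_solution = jnp.ones_like(midpoint) * midpoint  # always 0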

    question wontfix 
    opened by Sicelukwanda 4
  • Adopt a configuration library

    Our current examples use different approaches to configuring the RL agents. We can make things more scalable and easier to maintain by adopting a library for writing configurations.

    There are two options I considered: Hydra and ml_collections.

    • ml_collections, https://github.com/google/ml_collections, a configuration library from Google, adopted by many Google Research and DeepMind projects. Configuration files are just Python modules with a get_config function. It is very non-intrusive and requires minimal changes to the rest of the codebase beyond the entry point. Compared to Hydra, it does not provide sweeps or multi-run functionality out of the box, but this should be easy to achieve with some metaprogramming that generates the sweeps.
    • Hydra, http://hydra.cc/docs/, the configuration library used by FAIR projects. Hydra is nice since it provides many useful utilities out of the box (e.g., sweeps), and configs are written in YAML. I initially thought Hydra would be a good fit; however, after adopting it in some of my personal projects, I found that tailoring it to specific needs becomes difficult.

    I expect that we will incrementally move to ml_collections for writing the configuration in the examples. Users of Magi can easily opt out of ml_collections if they prefer their own approach to configuration.
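    For reference, an ml_collections config module looks roughly like the following; the file name and hyperparameters are illustrative, not Magi's actual configs.

    # config.py -- an illustrative ml_collections config module.
    from ml_collections import config_dict

    def get_config():
        config = config_dict.ConfigDict()
        config.env_name = 'cheetah_run'
        config.seed = 0
        config.batch_size = 256
        config.learning_rate = 3e-4
        return config

    An entry point can then load this with ml_collections.config_flags.DEFINE_config_file and override fields from the command line, e.g. --config=config.py --config.seed=1.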

    Moving on, we should also start thinking about how to let users easily sweep over hyperparameters. I have some experience with closed-source approaches to this and would like to transfer that experience into a (sub)package for it.

    enhancement 
    opened by ethanluoyc 4
  • Typo in run_drq example.

    Line reference: https://github.com/ethanluoyc/magi/blob/cff26ddb87165bb6e19796dc77521e3191afcffc/magi/examples/run_drq.py#L32

    The default value should be "10000".

    opened by akbir 2
  • Add DrQ-v2 agent

    Add a JAX implementation of DrQ-v2 from

    Yarats, D., Fergus, R., Lazaric, A., & Pinto, L. (2021). Mastering visual continuous control: Improved data-augmented reinforcement learning. arXiv preprint arXiv:2107.09645.

    The official PyTorch implementation can be found at

    https://github.com/facebookresearch/drqv2
    

    As with other agents, it uses Reverb as the replay backend. However, this may be problematic for reproducing the results from the paper, which uses a 1M-transition replay buffer, since Reverb has no mechanism for storing part of the replay on disk. It is worth measuring the actual memory usage of a 1M replay; it may still be feasible to use this implementation, since Reverb also does compression under the hood.
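    As a rough back-of-the-envelope estimate, assuming 84x84 RGB observations with a frame stack of 3 (as in the paper) and ignoring Reverb's compression:

    # Uncompressed memory estimate for a 1M-transition pixel replay.
    obs_bytes = 84 * 84 * 3 * 3           # HxWxC uint8, frame stack of 3
    per_transition = 2 * obs_bytes        # observation + next observation
    total_gb = 1_000_000 * per_transition / 1e9
    print(f'{total_gb:.0f} GB')           # ~127 GB before compression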

    There are some remaining TODOs, listed in magi/agents/drq_v2/README.md. Nevertheless, this implementation should match the original in most details.

    The next step is to benchmark this implementation on the set of tasks used in the original paper. It would also be worth investigating whether adding prefetching speeds up the implementation. Right now, this version runs at ~56 FPS on an Nvidia 3080; factoring in the action repeat (of 2), that is ~120 FPS in real environment steps.

    Closes #52

    opened by ethanluoyc 1
  • Reward/Costs Naming Convention

    https://github.com/ethanluoyc/magi/blob/4e29dfdeb3d0705ca95f0043ce6485cafab68f0f/magi/agents/pets/models/model.py#L319 In the line above, we seem to be adding costs to total_rewards, but the unroll method is expected to return rewards, not costs. It would make sense to rename costs to something more apt, such as non_terminated_rewards.

    enhancement 
    opened by Sicelukwanda 1
  • Error in running tests on machines with GPU

    __________________________ ERROR collecting magi/agents/drq/agent_test.py ___________________________
    magi/agents/drq/agent_test.py:7: in <module>
        from magi.agents.drq import networks
    magi/agents/drq/networks.py:9: in <module>
        orthogonal_init = hk.initializers.Orthogonal(scale=jnp.sqrt(2.0))
    venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:373: in <lambda>
        fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
    venv/lib/python3.8/site-packages/jax/_src/lax/lax.py:312: in sqrt
        return sqrt_p.bind(x)
    venv/lib/python3.8/site-packages/jax/core.py:259: in bind
        out = top_trace.process_primitive(self, tracers, params)
    venv/lib/python3.8/site-packages/jax/core.py:597: in process_primitive
        return primitive.impl(*tracers, **params)
    venv/lib/python3.8/site-packages/jax/interpreters/xla.py:230: in apply_primitive
        compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
    venv/lib/python3.8/site-packages/jax/_src/util.py:197: in wrapper
        return cached(bool(config.x64_enabled), *args, **kwargs)
    venv/lib/python3.8/site-packages/jax/_src/util.py:190: in cached
        return f(*args, **kwargs)
    venv/lib/python3.8/site-packages/jax/interpreters/xla.py:280: in xla_primitive_callable
        compiled = backend_compile(backend, built_c, options)
    venv/lib/python3.8/site-packages/jax/interpreters/xla.py:344: in backend_compile
        return backend.compile(built_c, compile_options=options)
    E   RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
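    Not part of the original report, but a common workaround for this class of error is to point XLA at the CUDA installation that ships libdevice, assuming CUDA lives under /usr/local/cuda:

    export XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda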
    
    opened by ethanluoyc 1
  • Improve Wandb logging

    • [x] run wandb init in the logger instead of outside
    • [x] configure wandb to use step_key for step, similar to the tfsummary logger
    • [ ] handle wandb finish in close function of logger
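    Not from the issue itself, but a sketch of what such a logger might look like, assuming Acme's logger interface (write/close) and the public wandb API:

    # Illustrative sketch only, not Magi's actual logger.
    import wandb
    from acme.utils.loggers import base

    class WandbLogger(base.Logger):
        """Logs Acme-style metric dicts to Weights & Biases."""

        def __init__(self, step_key: str = 'steps', **wandb_kwargs):
            # Run wandb.init inside the logger rather than at the call site.
            self._run = wandb.init(**wandb_kwargs)
            self._step_key = step_key

        def write(self, data: base.LoggingData):
            data = dict(data)
            # Use step_key for the step, similar to the tf.summary logger.
            step = data.pop(self._step_key, None)
            self._run.log(data, step=step)

        def close(self):
            # Finish the run when the logger is closed.
            self._run.finish()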
    opened by ethanluoyc 1
  • [WIP] PETS

    Add PETS, a model-based RL algorithm based on NN ensembles.

    For #15.

    We still need to sort out the interfaces for MBRL in Magi so that they are more extensible. We can draw some inspiration from FAIR's mbrl-lib, but only to a limited extent, since it is closely tied to PyTorch. Another option is to follow Acme's MCTS agent implementation; however, there are some nuances in implementing that interface with JAX, as we would like to JIT the entire model environment, which requires explicit wiring of the parameters due to JAX's functional nature.

    As of 22f5f08, the agent gets reasonable behaviour on the cartpole task used in the original paper (which I ported), but we need to benchmark more thoroughly. For this, I implemented CEM with MPC in the style of mbrl-lib rather than the original paper: it introduces momentum and samples in standard-normal space instead of the action space. The two are equivalent except for the handling of the minimum variance. A sketch of the CEM loop is given after the TODOs below.

    @Sicelukwanda this may be of interest.

    [no ci]

    TODOs

    • [ ] Support early stopping with validation dataset
    • [x] Add additional losses for the min/max logvar trick used in the original paper
    • [ ] Support trajectory sampling schemes besides TSInf (IIRC, TSInf is what I implemented: it propagates each particle with the same network in the ensemble during rollouts).
    • [x] Normalization of the data is not performed right now, but maybe it is useful.
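    A minimal sketch of the CEM loop described above (names and defaults are illustrative, not Magi's actual implementation):

    # Illustrative CEM-with-momentum sketch.
    import jax
    import jax.numpy as jnp

    def cem(cost_fn, key, lower, upper, num_samples=500, num_elites=50,
            num_iterations=5, momentum=0.1):
        """Minimizes cost_fn over actions within [lower, upper] bounds."""
        mean = (lower + upper) / 2.0
        stddev = (upper - lower) / 4.0
        for _ in range(num_iterations):
            key, subkey = jax.random.split(key)
            # Sample in standard-normal space, then shift, scale and clip.
            noise = jax.random.normal(subkey, (num_samples,) + mean.shape)
            samples = jnp.clip(mean + stddev * noise, lower, upper)
            costs = jax.vmap(cost_fn)(samples)
            elites = samples[jnp.argsort(costs)[:num_elites]]
            # Momentum: blend the old distribution with the elite statistics.
            mean = momentum * mean + (1 - momentum) * elites.mean(axis=0)
            stddev = momentum * stddev + (1 - momentum) * elites.std(axis=0)
        return mean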
    opened by ethanluoyc 1
  • PETS

    We would love to have a good implementation of PETS, from the paper:

    Kurtland Chua, Roberto Calandra, Rowan McAllister, Sergey Levine. Deep Reinforcement Learning in a Handful of Trials using Probabilistic Dynamics Models. NeurIPS 2018. arXiv:1805.12114.

    I will put up a WIP PR to track the status of a proof-of-concept implementation; we can then iterate on the design and test it against the environments used in the original paper.

    opened by ethanluoyc 1
  • Support macOS

    Currently, the agents do not work on macOS because of the dependency on dm-reverb, which only supports Linux.

    Fortunately, there is an open PR in Reverb, https://github.com/deepmind/reverb/pull/24, that aims to add macOS support and looks close to being merged. We can announce macOS support once it lands.

    wontfix 
    opened by ethanluoyc 1
  • Add distributed IMPALA.

    This adds both a single-process and a distributed (multi-process) implementation of IMPALA. The distributed implementation is written with Launchpad (https://github.com/deepmind/launchpad).
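    For context, a distributed Launchpad program is typically wired up along these lines; the node layout below is illustrative, not Magi's actual topology, and make_learner/make_actor are placeholder factories:

    # Illustrative Launchpad topology.
    import launchpad as lp

    def make_program():
        program = lp.Program('impala')
        with program.group('learner'):
            learner = program.add_node(lp.CourierNode(make_learner))
        with program.group('actor'):
            for _ in range(4):
                program.add_node(lp.CourierNode(make_actor, learner))
        return program

    lp.launch(make_program())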

    opened by ethanluoyc 1
  • WIP: F/benchmark

    • [x] Add bsuite example for IMPALA + catch
    • [ ] Add bsuite script for generating sweeps over multiple tasks
    • [ ] Add Colab / script for generating reports
    opened by akbir 0
  • Migrate agents to layouts.

    The next release of Acme will include layouts

    https://github.com/deepmind/acme/blob/master/acme/jax/layouts/local_layout.py.

    We should migrate our off-policy agents to follow something similar. This would improve the modularity of agent definitions.
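    For context, the layout pattern composes an agent from an Acme-style builder that declares each component separately. A sketch, assuming Acme's ActorLearnerBuilder interface (the class body is illustrative, not Magi code):

    from acme.agents.jax import builders

    class SACBuilder(builders.ActorLearnerBuilder):
        """Declares how to construct each component of the agent."""

        def make_replay_tables(self, environment_spec):
            ...  # Reverb tables for off-policy replay

        def make_dataset_iterator(self, replay_client):
            ...  # data pipeline reading from replay

        def make_adder(self, replay_client):
            ...  # adder that writes transitions to replay

        def make_learner(self, random_key, networks, dataset, **kwargs):
            ...  # the SGD learner

        def make_actor(self, random_key, policy_network, **kwargs):
            ...  # the acting/inference component

    A layout such as LocalLayout then wires these components into a single-process agent, so each algorithm only has to specify what differs.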

    Agents awaiting migration:

    • [x] SAC
    • [x] TD3
    • [x] CRR
    • [x] DrQ
    • [ ] SAC-AE
    • [x] TD3-BC
    • [ ] PETS

    Note: the IMPALA agent already has a local definition of layouts. We should make that more general and unify it with the other agents.

    opened by ethanluoyc 0
  • Offline RL agents

    Wishlist of algorithms

    • [x] TD3-BC #46
    • [ ] BEAR
    • [ ] BCQ
    • [x] CQL #56
    • [ ] ABM
    • [ ] BRAC
    • [x] CRR https://github.com/ethanluoyc/magi/pull/49
    • [x] IQL #57
    • [ ] AWAC #47
    opened by ethanluoyc 2
  • Moving the cost function inside the OptimizerBasedActor

    Given that different OptimizerBasedActor subclasses may need different cost functions (e.g., scalar-valued vs. vector-valued), maybe we should make the cost function a method of OptimizerBasedActor in magi/agents/pets/acting.py. Users could then override the method to implement their own cost functions.
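    A sketch of what that might look like (the method name _cost_fn and the helper my_custom_cost are hypothetical):

    class OptimizerBasedActor:
        def _cost_fn(self, observations, actions):
            """Default scalar-valued cost; subclasses may override."""
            raise NotImplementedError

    class MyTaskActor(OptimizerBasedActor):
        def _cost_fn(self, observations, actions):
            # Custom cost, e.g. vector-valued then reduced per particle.
            return my_custom_cost(observations, actions)  # hypothetical helper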

    enhancement 
    opened by Sicelukwanda 1