PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.

Overview


Stable Baselines3

Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of Stable Baselines.

You can read a detailed presentation of Stable Baselines3 in the v1.0 blog post.

These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.

Note: despite its simplicity of use, Stable Baselines3 (SB3) assumes you have some knowledge about Reinforcement Learning (RL). You should not use this library without some practice. To that end, we provide good resources in the documentation to get started with RL.

Main Features

The performance of each algorithm was tested (see the Results section on its respective documentation page); you can take a look at issues #48 and #49 for more details.

Features Stable-Baselines3
State of the art RL methods ✔️
Documentation ✔️
Custom environments ✔️
Custom policies ✔️
Common interface ✔️
Ipython / Notebook friendly ✔️
Tensorboard support ✔️
PEP8 code style ✔️
Custom callback ✔️
High code coverage ✔️
Type hints ✔️

Planned features

Please take a look at the Roadmap and Milestones.

Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)

A migration guide from SB2 to SB3 can be found in the documentation.

Documentation

Documentation is available online: https://stable-baselines3.readthedocs.io/

RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents

RL Baselines3 Zoo is a training framework for Reinforcement Learning (RL).

It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.

In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings.

Goals of this repository:

  1. Provide a simple interface to train and enjoy RL agents
  2. Benchmark the different Reinforcement Learning algorithms
  3. Provide tuned hyperparameters for each environment and RL algorithm
  4. Have fun with the trained agents!

Github repo: https://github.com/DLR-RM/rl-baselines3-zoo

Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html

SB3-Contrib: Experimental RL Features

We implement experimental features in a separate contrib repository: SB3-Contrib

This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC) or Quantile Regression DQN (QR-DQN).

Documentation is available online: https://sb3-contrib.readthedocs.io/
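
To give an idea of the usage, here is a minimal, hedged sketch of training an SB3-Contrib algorithm, assuming sb3-contrib is installed and exposes TQC at the package top level (as its documentation describes) and that the Pendulum-v1 environment id matches your Gym version:

from sb3_contrib import TQC

# Train Truncated Quantile Critics (TQC) on a continuous-control task,
# keeping all hyperparameters at their defaults for brevity.
model = TQC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10_000)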

Installation

Note: Stable-Baselines3 supports PyTorch 1.4+.

Prerequisites

Stable Baselines3 requires Python 3.6+.

Windows 10

To install Stable Baselines3 on Windows, please look at the documentation.

Install using pip

Install the Stable Baselines3 package:

pip install stable-baselines3[extra]

This includes optional dependencies like Tensorboard, OpenCV or atari-py to train on Atari games. If you do not need those, you can use:

pip install stable-baselines3

Please read the documentation for more details and alternatives (from source, using docker).
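
As a quick sanity check of the installation (a minimal sketch, not a substitute for the test suite), you can import the package and print its version:

import stable_baselines3

# Print the installed SB3 version to confirm the package can be imported.
print(stable_baselines3.__version__)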

Example

Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.

Here is a quick example of how to train and run PPO on a cartpole environment:

import gym

from stable_baselines3 import PPO

# Create the environment
env = gym.make("CartPole-v1")

# Instantiate the agent and train it for 10,000 timesteps
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)

# Enjoy the trained agent
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()

env.close()

Or just train a model with a one-liner if the environment is registered in Gym and the policy is registered:

from stable_baselines3 import PPO

model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)

Please read the documentation for more examples.

Try it online with Colab Notebooks!

All the examples can be executed online using Google Colab notebooks.

Implemented Algorithms

Name | Recurrent | Box | Discrete | MultiDiscrete | MultiBinary | Multi Processing
A2C | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️
DDPG | ❌ | ✔️ | ❌ | ❌ | ❌ | ❌
DQN | ❌ | ❌ | ✔️ | ❌ | ❌ | ❌
HER | ❌ | ✔️ | ✔️ | ❌ | ❌ | ❌
PPO | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️
SAC | ❌ | ✔️ | ❌ | ❌ | ❌ | ❌
TD3 | ❌ | ✔️ | ❌ | ❌ | ❌ | ❌

Actions gym.spaces (a short construction example follows this list):

  • Box: An N-dimensional box that contains every point in the action space.
  • Discrete: A list of possible actions, where at each timestep only one of the actions can be used.
  • MultiDiscrete: A list of possible actions, where at each timestep only one action from each discrete set can be used.
  • MultiBinary: A list of possible actions, where at each timestep any of the actions can be used in any combination.
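
The sketch below shows how each of these spaces can be constructed with the gym.spaces module; the sizes and bounds are arbitrary examples:

import numpy as np
from gym import spaces

# Box: continuous actions, here 2 values each bounded in [-1, 1]
box_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

# Discrete: one of 3 mutually exclusive actions per timestep
discrete_space = spaces.Discrete(3)

# MultiDiscrete: one action from each of three discrete sets (of sizes 2, 3 and 4)
multi_discrete_space = spaces.MultiDiscrete([2, 3, 4])

# MultiBinary: 4 binary switches that can be combined freely
multi_binary_space = spaces.MultiBinary(4)

# Sample a random action from each space
for space in (box_space, discrete_space, multi_discrete_space, multi_binary_space):
    print(space, space.sample())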

Testing the installation

All unit tests in Stable Baselines3 can be run using the pytest runner:

pip install pytest pytest-cov
make pytest

You can also do a static type check using pytype:

pip install pytype
make type

Codestyle check with flake8:

pip install flake8
make lint

Projects Using Stable-Baselines3

We try to maintain a list of projects using stable-baselines3 in the documentation; please tell us if you want your project to appear on this page ;)

Citing the Project

To cite this repository in publications:

@misc{stable-baselines3,
  author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
  title = {Stable Baselines3},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
}

Maintainers

Stable-Baselines3 is currently maintained by Ashley Hill (aka @hill-a), Antonin Raffin (aka @araffin), Maximilian Ernestus (aka @ernestum), Adam Gleave (@AdamGleave) and Anssi Kanervisto (@Miffyli).

Important Note: We do not provide technical support or consulting and do not answer personal questions by email. Please post your question on the RL Discord, Reddit or Stack Overflow instead.

How To Contribute

To anyone interested in making the baselines better: there is still some documentation that needs to be done. If you want to contribute, please read the CONTRIBUTING.md guide first.

Acknowledgments

The initial work to develop Stable Baselines3 was partially funded by the project Reduced Complexity Models from the Helmholtz-Gemeinschaft Deutscher Forschungszentren.

The original version, Stable Baselines, was created in the robotics lab U2IS (INRIA Flowers team) at ENSTA ParisTech.

Logo credits: L.M. Tenkes

Comments
  • Support for MultiBinary / MultiDiscrete spaces

    Support for MultiBinary / MultiDiscrete spaces

    Description

    • Added support for MultiDiscrete and MultiBinary observation / action spaces for PPO and A2C
    • Added MultiCategorical and Bernoulli distributions
    • Added tests for MultiCategorical and Bernoulli distributions and actions spaces

    Motivation and Context

    • [x] I have raised an issue to propose this change (required for new features and bug fixes)

    closes #5 closes #4

    Types of changes

    • [ ] Bug fix (non-breaking change which fixes an issue)
    • [x] New feature (non-breaking change which adds functionality)
    • [ ] Breaking change (fix or feature that would cause existing functionality to change)
    • [ ] Documentation (update in the documentation)

    Checklist:

    • [x] I've read the CONTRIBUTION guide (required)
    • [x] I have updated the changelog accordingly (required).
    • [ ] My change requires a change to the documentation.
    • [x] I have updated the tests accordingly (required for a bug fix or a new feature).
    • [ ] I have updated the documentation accordingly.
    • [x] I have checked the codestyle using make lint
    • [x] I have ensured pytest and pytype both pass.
    opened by rolandgvc 54
  • Roadmap to Stable-Baselines3 V1.0

    Roadmap to Stable-Baselines3 V1.0

    This issue is meant to be updated as the list of changes is not exhaustive

    Dear all,

    Stable-Baselines3 beta is now out :tada: ! This issue is meant to reference what is implemented and what is missing before a first major version.

    As mentioned in the README, before v1.0, breaking changes may occur. I would like to encourage contributors (especially the maintainers) to make comments on how to improve the library before v1.0 (and maybe make some internal changes).

    I will try to review the features mentioned in https://github.com/hill-a/stable-baselines/issues/576 (and https://github.com/hill-a/stable-baselines/issues/733) and I will create issues soon to reference what is missing.

    What is implemented?

    • [x] basic features (training/saving/loading/predict)
    • [x] basic set of algorithms (A2C/PPO/SAC/TD3)
    • [x] basic pre-processing (Box and Discrete observation/action spaces are handled)
    • [x] callback support
    • [x] complete benchmark for the continuous action case
    • [x] basic rl zoo for training/evaluating/plotting (https://github.com/DLR-RM/rl-baselines3-zoo)
    • [x] consistent api
    • [x] basic tests and most type hints
    • [x] continuous integration (I'm in discussion with the organization admins for that)
    • [x] handle more observation/action spaces #4 and #5 (thanks @rolandgvc)
    • [x] tensorboard integration #9 (thanks @rolandgvc)
    • [x] basic documentation and notebooks
    • [x] automatic build of the documentation
    • [x] Vanilla DQN #6 (thanks @Artemis-Skade)
    • [x] Refactor off-policy critics to reduce code duplication #3 (see #78 )
    • [x] DDPG #3
    • [x] do a complete benchmark for the discrete case #49 (thanks @Miffyli !)
    • [x] performance check for continuous actions #48 (even better than gSDE paper)
    • [x] get/set parameters for the base class (#138 )
    • [x] clean up type-hints in docs #10 (cumbersome to read)
    • [x] documenting the migration between SB and SB3 #11
    • [x] finish typing some methods #175
    • [x] HER #8 (thanks @megan-klaiber)
    • [x] finish updating and cleaning the doc #166 (help is wanted)
    • [x] finish updating the notebooks and the tutorial #7 (I will do that, only the HER notebook is missing)

    What are the new features?

    • [x] much cleaner base code (and no more warnings =D )
    • [x] independent saving/loading/predict for policies
    • [x] State-Dependent Exploration (SDE) for using RL directly on real robots (this is a unique feature, it was the starting point of SB3, I published a paper on that: https://arxiv.org/abs/2005.05719)
    • [x] proper evaluation (using separate env) is included in the base class (using EvalCallback)
    • [x] all environments are VecEnv
    • [x] better saving/loading (now can include the replay buffer and the optimizers)
    • [x] any number of critics are allowed for SAC/TD3
    • [x] custom actor/critic net arch for off-policy algos (#113 )
    • [x] QR-DQN in SB3-Contrib
    • [x] Truncated Quantile Critics (TQC) (see #83 ) in SB3-Contrib
    • @Miffyli suggested a "contrib" repo for experimental features (it is here)

    What is missing?

    • [x] syncing some files with Stable-Baselines to remain consistent (we may be good now, but it needs to be checked)
    • [x] finish code review of existing code #17

    Checklist for v1.0 release

    • [x] Update Readme
    • [x] Prepare blog post
    • [x] Update doc: add links to the stable-baselines3 contrib
    • [x] Update docker image to use newer Ubuntu version
    • [x] Populate RL zoo

    What is next? (for V1.1+)

    • basic dict/tuple support for observations (#243 )
    • simple recurrent policies? (https://github.com/DLR-RM/stable-baselines3/issues/18)
    • DQN extensions (double, PER, IQN) (https://github.com/DLR-RM/stable-baselines3/issues/622)
    • Implement TRPO (https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/40)
    • multi-worker training for all algorithms (#179 )
    • n-step returns for off-policy algorithms #47 (@PartiallyTyped )
    • SAC discrete #157 (need to be discussed, benefit vs DQN+extensions?)
    • Energy Based Prioritisation? (@RyanRizzo96)
    • implement action_proba in the base class?
    • test the doc snippets #14 (help is welcomed)
    • noisy networks (https://arxiv.org/abs/1706.10295) @PartiallyTyped ? exploration in parameter space? (https://github.com/DLR-RM/stable-baselines3/issues/622)
    • Munchausen Reinforcement Learning (MDQN) (probably in the contrib first, e.g. https://github.com/pfnet/pfrl/pull/74)

    side note: should we change the default start_method to fork? (now that we don't have tf anymore)

    enhancement 
    opened by araffin 46
  • Dictionary Observations

    Dictionary Observations

    In machine learning, input comes in the form of matrices. Typically, models take in one matrix at a time, as in image classification, where a matrix containing the input image is given to the model and the model classifies the image. However, there are many situations in which taking multiple inputs is necessary. One example is training a reinforcement learning agent whose observations come in the form of an image (e.g., camera, grid sensor, etc.) and a vector describing the agent's state (e.g., current position, health, etc.). In this situation, it is necessary to feed 2 inputs to the model. This PR addresses this.

    Description

    • added example environments with multi-input observations
    • added DictReplayBuffer and DictRolloutBuffer to handle dictionaries
    • added CombinedExtractor feature extractor that handles generic dictionary data
    • added StackedObservations and StackedDictObservations to decouple data stacking from the VecFrameStack wrapper
    • added test_dict_env.py test
    • added an is_vectorized_env() method per observation space type in common/utils.py

    Motivation and Context

    • [x] I have raised an issue to propose this change (link)

    closes #216

    closes #287 (image support for HER) closes #284 (for off-policy algorithms)

    Types of changes

    • [x] New feature (non-breaking change which adds functionality)
    • [x] Breaking change (fix or feature that would cause existing functionality to change)
    • [x] Documentation (update in the documentation)

    Checklist:

    • [x] I've read the CONTRIBUTION guide (required)
    • [x] I have updated the changelog accordingly (required).
    • [x] My change requires a change to the documentation.
    • [x] I have updated the tests accordingly
    • [x] I have updated the documentation accordingly.
    • [x] I have reformatted the code using make format
    • [x] I have checked the codestyle using make check-codestyle and make lint
    • [x] I have ensured make pytest and make type both pass. (required)
    • [x] I have checked that the documentation builds using make doc (required)

    TODOs

    • [x] check that documentation is properly updated
    • [x] check that dict with vectors only is the same as mlp policy + vector flatten
    • [x] Update env checker
    • [x] (optional) refactor HER: https://github.com/DLR-RM/stable-baselines3/tree/feat/refactor-her
    • [x] test A2C/PPO/SAC alone with GoalEnv
    opened by J-Travnik 45
  • Match performance with stable-baselines (discrete case)

    Match performance with stable-baselines (discrete case)

    This PR will be done when stable-baselines3 agent performance matches stable-baselines in discrete envs. Will be tested on discrete control tasks and Atari environments.

    Closes #49 Closes #105

    PS: Sorry about the confusing branch-name.

    Changes

    • Fix storing correct dones (#105, credits to AndyShih12)
    • Fix number of filters in NatureCNN
    • Add common.sb2_compat.RMSpropTFLike, which is a modification of RMSprop that matches TF version, and is required for matching performance in A2C.

    TODO

    • [x] Match performance of A2C and PPO.

    • [x] A2C Cartpole matches (mostly, see this. Averaged over 10 random seeds for both. Requires the TF-like RMSprop, and even still in the very end SB3 seems more unstable.)

    • [x] A2C Atari matches (mostly, see sb2 and sb3. Original sb3 result here. Three random seeds, each line separate run (ignore legend). Using TF-like RMSprop. Performance and stability mostly matches, except sb2 has sudden spike in performance in Q*Bert. Something to do with stability in distributions?)

    • [x] PPO Cartpole (using rl-zoo parameters, see learning curves, averaged over 20 random seeds)

    • [x] PPO Atari (mostly, see sb2 and sb3 results (shaded curves averaged over two seeds). Q*Bert still seems to have an edge on SB2 for unknown reasons)

    • [x] Check and match performance of DQN. Seems ok. See following learning curves, each curve is an average over three random seeds: atari_spaceinvaders.pdf atari_qbert.pdf atari_breakout.pdf atari_pong.pdf

    • [x] Check if "dones" fix can (and should) be moved to computing GAE side.

    • [x] ~~Write docs on how to match A2C and PPO settings to stable-baselines ("moving from stable-baselines"). There are some important quirks to note here.~~ Move this to migration guide PR #123 .

    Types of changes

    • [x] Bug fix (non-breaking change which fixes an issue)
    • [ ] New feature (non-breaking change which adds functionality)
    • [ ] Breaking change (fix or feature that would cause existing functionality to change)
    • [x] Documentation (update in the documentation)

    Checklist:

    • [x] I've read the CONTRIBUTION guide (required)
    • [x] I have updated the changelog accordingly (required).
    • [ ] My change requires a change to the documentation.
    • [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
    • [ ] I have updated the documentation accordingly.
    • [x] I have reformatted the code using make format (required)
    • [x] I have checked the codestyle using make check-codestyle and make lint (required)
    • [x] I have ensured make pytest and make type both pass. (required)
    opened by Miffyli 28
  • [Question/Discussion] Comparing stable-baselines3 vs stable-baselines

    [Question/Discussion] Comparing stable-baselines3 vs stable-baselines

    Did anybody compare the training speed (or other performance metrics) of SB and SB3 for the implemented algorithms (e.g., PPO?) Is there a reason to prefer either one for developing a new project?

    question 
    opened by AlessandroZavoli 28
  • Tensorboard integration

    Tensorboard integration

    Description

    Adding support for logging to tensorboard.

    Missing:

    • [x] Documentation
    • [x] More tests
    • [x] check we don't make the same mistakes as SB2 (https://github.com/hill-a/stable-baselines/issues/855 https://github.com/hill-a/stable-baselines/issues/56 )

    Motivation and Context

    • [x] I have raised an issue to propose this change (required for new features and bug fixes)

    closes #9

    Types of changes

    • [ ] Bug fix (non-breaking change which fixes an issue)
    • [x] New feature (non-breaking change which adds functionality)
    • [x] Breaking change (fix or feature that would cause existing functionality to change)
    • [ ] Documentation (update in the documentation)

    Checklist:

    • [x] I've read the CONTRIBUTION guide (required)
    • [x] I have updated the changelog accordingly (required).
    • [x] My change requires a change to the documentation.
    • [x] I have updated the tests accordingly (required for a bug fix or a new feature).
    • [x] I have updated the documentation accordingly.
    • [x] I have checked the codestyle using make lint
    • [x] I have ensured make pytest and make type both pass.
    opened by rolandgvc 27
  • Implement DQN

    Implement DQN

    Description

    Implementation of vanilla dqn

    closes #6 closes #37 closes #46

    Missing:

    • [x] Update examples to include DQN
    • [x] Add test for replay buffer truncation

    Motivation and Context

    • [x] I have raised an issue to propose this change (required for new features and bug fixes)

    Types of changes

    • [ ] Bug fix (non-breaking change which fixes an issue)
    • [x] New feature (non-breaking change which adds functionality)
    • [ ] Breaking change (fix or feature that would cause existing functionality to change)
    • [ ] Documentation (update in the documentation)

    Checklist:

    • [x] I've read the CONTRIBUTION guide (required)
    • [x] I have updated the changelog accordingly (required).
    • [x] My change requires a change to the documentation.
    • [x] I have updated the tests accordingly (required for a bug fix or a new feature).
    • [x] I have updated the documentation accordingly.
    • [x] I have checked the codestyle using make lint
    • [x] I have ensured make pytest and make type both pass.
    opened by ndormann 26
  • Memory allocation for buffers

    Memory allocation for buffers

    With the current implementation of buffers.py, one can request a buffer size that does not fit in the available memory: because of numpy's implementation of np.zeros(), the memory is not allocated before it is actually used. But since the buffer is meant to be filled completely (otherwise one could just use a smaller buffer), the computer will eventually run out of memory and start to swap heavily. Because only smaller parts of the buffer (minibatches) are accessed at once, the system will just swap the necessary pages in and out of memory. At that moment the progress of the run is most likely lost, and one has to start a new run with a smaller buffer.

    I would recommend using np.ones instead, as it will allocate the buffer at the beginning and fail if the system does not provide enough memory. The only issue is that there is no clear error description in the case where the system memory is exceeded: Python simply gets killed by the OS with a SIGKILL. Maybe one could catch that?
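
    To illustrate the difference, here is a tiny sketch (the buffer shape is a deliberately small, hypothetical example; real replay buffers are much larger):

    import numpy as np

    buffer_shape = (10_000, 84, 84, 4)  # hypothetical replay-buffer shape, ~0.3 GB as uint8

    # np.zeros can return immediately: the OS may back the array with lazily-mapped
    # zero pages, so physical memory is only committed once elements are written.
    lazy_buffer = np.zeros(buffer_shape, dtype=np.uint8)

    # np.ones writes a 1 into every element, touching every page up front, so an
    # oversized buffer fails immediately with a MemoryError (or is killed by the OS)
    # instead of swapping later during training.
    eager_buffer = np.ones(buffer_shape, dtype=np.uint8)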

    bug enhancement 
    opened by ndormann 25
  • [Bug]: Can't create the PPO model on Macbook M1

    [Bug]: Can't create the PPO model on Macbook M1

    🐛 Bug

    I'm using a Macbook Pro M1 and running the code in Jupyter Notebook. Every time I run the scripts, I get an error and the kernel crashes. I've tried running the code in Google Colab and it went well. It seems to me that the stable-baselines3 library doesn't support M1 chips. What do I need to do to run the code locally?

    To Reproduce

    
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    
    # Parallel environments
    env = make_vec_env("CartPole-v1", n_envs=4)
    
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=25000)
    model.save("ppo_cartpole")
    
    del model # remove to demonstrate saving and loading
    
    model = PPO.load("ppo_cartpole")
    
    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
    
    

    Relevant log output / Error message

    The Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details.
    
    Canceled future for execute_request message before replies were done
    

    System Info

    No response

    Checklist

    • [x] I have checked that there is no similar issue in the repo
    • [X] I have read the documentation
    • [X] I have provided a minimal working example to reproduce the bug
    • [X] I've used the markdown code blocks for both code and stack traces.
    bug 
    opened by nguyenhongson1902 24
  • Fix return type for load, learn in BaseAlgorithm

    Fix return type for load, learn in BaseAlgorithm

    Description

    Fixes the return type of .load() and .learn() methods in BaseAlgorithm so that they now use the Self type PEP 0673 instead of BaseAlgorithm, which breaks type checkers for use with any subclass.

    Motivation and Context

    Closes #1040.

    • [x] I have raised an issue to propose this change (required for new features and bug fixes)

    Types of changes

    • [x] Bug fix (non-breaking change which fixes an issue)
    • [ ] New feature (non-breaking change which adds functionality)
    • [ ] Breaking change (fix or feature that would cause existing functionality to change)
    • [ ] Documentation (update in the documentation)

    Checklist:

    • [x] I've read the CONTRIBUTION guide (required)
    • [x] I have updated the changelog accordingly (required).
    • [ ] My change requires a change to the documentation.
    • [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
    • [ ] I have updated the documentation accordingly.
    • [x] I have reformatted the code using make format (required)
    • [x] I have checked the codestyle using make check-codestyle and make lint (required)
    • [x] I have ensured make pytest and make type both pass. (required)
    • [x] I have checked that the documentation builds using make doc (required)

    Note: You can run most of the checks using make commit-checks.

    Note: we are using a maximum length of 127 characters per line

    opened by Rocamonde 21
  • [feature request] Add total episodes parameter to model learn method

    [feature request] Add total episodes parameter to model learn method

    Hi,

    TL;DR: I would like the option to pass either a total_episodes parameter or a total_timesteps to the model.learn() method.

    Now, for my reasoning. Currently, we can only define the total_timesteps when training an agent, as follows.

    model = A2C('MlpLstmPolicy', env, verbose=1, policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=1000)
    

    However, for some scenarios (e.g., stock trading), it is quite common to have a fixed number of timesteps per episode, given by the available time-series data points. It can also be quite valuable to scan all data points an equal number of times and to control the number of passes, which is defined by the number of episodes.

    Thus, to train for a given number of episodes with a fixed number of timesteps per episode, I have to compute the total_timesteps value before passing it to model.learn(), as follows:

    desired_total_episodes = 100
    n_points = train_df.shape[0]  # get the number of data points (timesteps per episode)
    total_timesteps = desired_total_episodes * n_points
    

    Even so, this answer on StackOverflow says that

    Where the episode length is known, set it to the desired number of episode you would like to train. However, it might be less because the agent might not (probably wont) reach max steps every time.

    I must admit I do not know how accurate this answer is, but it worries me that my model may not scan all the data equally.

    Another option, as discussed in this issue from the previous stable baselines repo, is to use a callback function. Still, for this callback approach, I would have to pass a total_timesteps value that is high enough to reach the desired number of episodes. Hence, the callback approach seems like a roundabout workaround.
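
    For reference, a minimal sketch of such a callback-based workaround, assuming the StopTrainingOnMaxEpisodes callback from stable_baselines3.common.callbacks is available; total_timesteps then only acts as an upper bound:

    from stable_baselines3 import A2C
    from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes

    # Stop training once 100 episodes have been played (counted per environment).
    callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)

    model = A2C("MlpPolicy", "CartPole-v1", verbose=1)
    # total_timesteps is set very high on purpose; the callback ends training earlier.
    model.learn(total_timesteps=int(1e10), callback=callback_max_episodes)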

    In conclusion, I believe that adding the option to pass total_episodes could be a simple and effective change that would broaden the range of use cases this project can serve.

    Thank you for your attention!

    enhancement 
    opened by xicocaio 21
  • Ensure `ProgressBarCallback` *better* matches *actual* `total_timesteps`

    Ensure `ProgressBarCallback` *better* matches *actual* `total_timesteps`

    Description

    Added logic to ProgressBarCallback that gives a better estimate of the total number of timesteps that an algorithm processes, allowing the progress bar to more accurately report the remaining proportion of training steps.

    Note that environments with variable-length episodes may still result in over/under-reported progress (potentially motivating the addition of an alternative total_episodes argument to .learn(...), to complement total_timesteps, for which ProgressBarCallback could report the proportion of episodes remaining instead?). Likely out of scope of this PR...

    Also, the logic for estimating the number of training timesteps might be better moved into OffPolicyAlgorithm._setup_learn and/or OnPolicyAlgorithm._setup_learn? It looks like this method already adjusts a provided total_timesteps for previously completed training, so could also round-up to the nearest update_freq/n_steps size? The function looks like it will need a bit of reordering though, so that the estimate can be calculated before the ProgressBarCallback is initialised in BaseAlgorithm._setup_learn...

    Motivation and Context

    Closes Issue #1259.

    • [x] I have raised an issue to propose this change (required for new features and bug fixes)

    Types of changes

    • [x] Bug fix (non-breaking change which fixes an issue)
    • [ ] New feature (non-breaking change which adds functionality)
    • [ ] Breaking change (fix or feature that would cause existing functionality to change)
    • [ ] Documentation (update in the documentation)

    Checklist

    • [x] I've read the CONTRIBUTION guide (required)
    • [x] I have updated the changelog accordingly (required).
    • [ ] My change requires a change to the documentation.
    • [ ] I have updated the tests accordingly (required for a bug fix or a new feature). See https://github.com/DLR-RM/stable-baselines3/issues/1259#issuecomment-1370202824
    • [ ] I have updated the documentation accordingly.
    • [ ] I have opened an associated PR on the SB3-Contrib repository (if necessary)
    • [ ] I have opened an associated PR on the RL-Zoo3 repository (if necessary)
    • [x] I have reformatted the code using make format (required)
    • [x] I have checked the codestyle using make check-codestyle and make lint (required)
    • [x] I have ensured make pytest and make type both pass. (required)
    • [x] I have checked that the documentation builds using make doc (required)

    Note: You can run most of the checks using make commit-checks. Note: we are using a maximum length of 127 characters per line

    opened by dominicgkerr 0
  • [Bug]: `ProgressBarCallback` doesn't match *actual* `total_timesteps`

    [Bug]: `ProgressBarCallback` doesn't match *actual* `total_timesteps`

    🐛 Bug

    In a similar way to Issue #1150, a ProgressBarCallback can be stepped more/fewer times than the total_timesteps value provided to an algorithm's .learn(...) method (screenshot omitted).

    As tqdm.rich doesn't always flush when self.pbar.update is called, ProgressBarCallback typically under-reports the number of steps actually processed. However, this can be easily fixed by adding a self.pbar.refresh() call to the ProgressBarCallback._on_training_end method (screenshot omitted).

    To fix the over-reporting of the final (PPO) progress bar, I think the total value passed into the callback's constructor (file: stable_baselines3/common/callbacks.py, line: 680) should first be rounded up to a multiple of the algorithm's n_steps. This should then correctly report the algorithm's progress while collecting experience - without this, a very large rollout or a computationally slow environment could result in a long wait without useful progress updates...

    More generally, for vectorized environments, I think the total should also be rounded up to a multiple of the callback's training_env.num_envs, as its ._on_step() method advances the displayed progress by more than one. I believe this would be a one/two line change (plus, maybe, a new unit test), which I'd be very happy to open!
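
    A hedged sketch of the rounding being proposed (the helper name and arguments are illustrative, not SB3 internals):

    import math

    def rounded_progress_total(total_timesteps: int, n_steps: int, num_envs: int) -> int:
        # Each rollout collection advances the counter by n_steps * num_envs,
        # so round the reported total up to the nearest full rollout.
        steps_per_rollout = n_steps * num_envs
        return math.ceil(total_timesteps / steps_per_rollout) * steps_per_rollout

    # Example: PPO defaults (n_steps=2048) with a single env and 10_000 requested steps
    print(rounded_progress_total(10_000, n_steps=2048, num_envs=1))  # -> 10240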

    To Reproduce

    import gym
    from stable_baselines3 import DQN, A2C, PPO
    
    env = gym.make('MountainCar-v0')
    
    DQN("MlpPolicy", env, seed=42).learn(total_timesteps=10_000, progress_bar=True) 
    A2C("MlpPolicy", env, seed=42).learn(total_timesteps=10_000, progress_bar=True)
    PPO("MlpPolicy", env, seed=42).learn(total_timesteps=10_000, progress_bar=True)
    

    Relevant log output / Error message

    No response

    System Info

    • OS: Linux-5.15.0-56-generic-x86_64-with-glibc2.35 # 62-Ubuntu SMP Tue Nov 22 19:54:14 UTC 2022
    • Python: 3.10.6
    • Stable-Baselines3: 1.7.0a11
    • PyTorch: 1.13.0+cu117
    • GPU Enabled: True
    • Numpy: 1.23.5
    • Gym: 0.21.0

    Checklist

    • [X] I have checked that there is no similar issue in the repo
    • [X] I have read the documentation
    • [X] I have provided a minimal working example to reproduce the bug
    • [X] I've used the markdown code blocks for both code and stack traces.
    bug 
    opened by dominicgkerr 1
  • [Feature Request] Implementation of QC_SANE

    [Feature Request] Implementation of QC_SANE

    🚀 Feature

    I have implemented the QC_SANE DRL approach (PyTorch framework) and want to add it as a feature to this repo. Research Paper: https://ieeexplore.ieee.org/document/9640528

    Motivation

    To contribute to open source (Stable Baselines3), as this repo and its TensorFlow implementations helped me a lot in understanding the implementation of DRL approaches.

    Pitch

    I want to provide the implementation of QC_SANE so that the DRL community can leverage it.

    Alternatives

    I only have a PyTorch implementation of QC_SANE; no other alternatives are available.

    Additional context

    No response

    Checklist

    • [X] I have checked that there is no similar issue in the repo
    enhancement 
    opened by surbhigupta1908 2
  • [Question] Resuming training of normalized environment decreases dramatically

    [Question] Resuming training of normalized environment decreases dramatically

    ❓ Question

    Hello everybody!

    I have trained Humanoid-v3 in Google Colab and the average episode reward looked really good. But when I test the agent locally, it does not perform well at all, and resuming the training on my local computer shows a HUGE decrease in average episode reward, from 4000 down to 1000.

    I think I did it just like in this example: Pybullet with normalized features

    To show you the most important parts:

    In Colab, I trained the agent this way:

    env = SubprocVecEnv([lambda: gym.make('Humanoid-v3')]*16)
    env = VecNormalize(env, norm_obs=True, norm_reward=False)
    env = VecMonitor(env)
    
    # Save a checkpoint every 10000 steps
    checkpoint_callback = CheckpointCallback(
        save_freq=10000,
        save_path="drive/MyDrive/Humanoid/checkpoints",
        save_vecnormalize=True
    )
    
    # Learning rate scheduler
    def linear_schedule(initial_value: float):
        def func(progress_remaining: float):
            return progress_remaining * initial_value
        return func
    
    model = PPO(
        "MlpPolicy", 
        env,
        learning_rate=0.0001,
        verbose=0,
        tensorboard_log="drive/MyDrive/Humanoid/tensorboard/"
    )
    
    model.learn(
        total_timesteps=10_000_000,
        callback=checkpoint_callback
    )
    

    When I resumed the training on my local computer, I did the following:

    env = SubprocVecEnv([lambda: gym.make('Humanoid-v3')]*16)
    env = VecNormalize.load("checkpoints/rl_model_vecnormalize_10000000_steps.pkl", env)
    env = VecMonitor(env)
    
    env.norm_obs = True
    env.norm_reward = False
    env.training = True
    
    model = PPO.load(
        "checkpoints/rl_model_10000000_steps.zip",
        env=env,
        custom_objects={
            "tensorboard_log": "tensorboard"
        }
    )
    
    model.learn(
        total_timesteps=2_000_000,
        reset_num_timesteps=False,
        callback=checkpoint_callback
    )
    

    This is how the curve of my average episode reward looks after resuming (plot omitted).

    Does anybody have an idea what is going wrong here?

    Checklist

    • [X] I have checked that there is no similar issue in the repo
    • [X] I have read the documentation
    • [X] If code there is, it is minimal and working
    • [X] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
    question 
    opened by mindacrobatic 1
  • [Question] can you continue training on test examples?

    [Question] can you continue training on test examples?

    ❓ Question

    [maybe this is actually a feature request]

    I'm interested in the setting where the model can continue to update its weights based on its experience at test-time. The model.predict function does not seem to support this, and the model.learn function does not seem to allow passing in the specific observations to learn from (e.g. those returned by the predict function) but rather generates new ones.

    Is there any chance I'm missing the way to utilize the returned values from model.predict in order to continue updating the model weights as I go?

    Checklist

    • [X] I have checked that there is no similar issue in the repo
    • [X] I have read the documentation
    • [X] If code there is, it is minimal and working
    • [X] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
    question trading warning 
    opened by alexholdenmiller 3
  • Add Examples Combining Useful Tools For Typical Research Workflows

    Add Examples Combining Useful Tools For Typical Research Workflows

    📚 Documentation

    I've noticed that your RL framework offers lots of nice tools. Nothing is really missing. But the thing is, you're not particularly good at giving a newcomer-friendly quickstart. Even for me, with a few years of experience in RL and DL, it took several hours of searching to find out that you have all kinds of useful callbacks to back up the model, plot TensorBoard graphs, etc.

    You've got all those great tools built into SB3, but people don't know because you don't show them how to combine them in a meaningful way. So IMO, you need some kind of sufficiently full-blown example that people can copy, change some lines and then they're good (not like those examples in your docs that outline specific features in isolation). For me, such an example should include a training progress display, model checkpoints, useful evaluation metrics, a live metrics display, proper setup of algorithms concerning hparams / scalability, etc.

    Let's think about how a researcher would actually use your RL framework. As a researcher you'd probably get really nervous when you don't see any progress especially when the training takes several days, so put a progress bar. Also, it's really important to see whether the training is going as expected, so put a live metrics display allowing for early abortion of failed experiments. And even if you see the reward graph, it's still an issue when you have to admit "the results were great at timestep x but dropped to a result that's worse than random towards the end, so the trained model is basically crap." Of course you want model checkpoints of some kind to minimize the risk of wasted training efforts. Finally when it comes to evaluating the research outcome, maybe you don't want to say "damn the training went really well, but I don't know how I got there". So sure enough you'll need to record all kinds of metrics/logs/configs per experiment to track down which setup mapped to which outcome.

    The outlined example is nothing too special, just basic stuff to support the everyday work of a researcher. Maybe you should expand your thinking to facilitate a better research experience where people can concentrate on the actual research.

    Checklist

    • [X] I have checked that there is no similar issue in the repo
    • [X] I have read the documentation
    documentation 
    opened by Bonifatius94 3
Releases (v1.6.2)
  • v1.6.2(Oct 10, 2022)

    SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib RL Zoo3: https://github.com/DLR-RM/rl-baselines3-zoo

    New Features:

    • Added progress_bar argument in the learn() method, displayed using TQDM and rich packages
    • Added progress bar callback

    RL Zoo3

    • The RL Zoo can now be installed as a package (pip install rl_zoo3)

    Bug Fixes:

    • self.num_timesteps was initialized properly only after the first call to on_step() for callbacks
    • Set importlib-metadata version to ~=4.13 to be compatible with gym=0.21

    Deprecations:

    • Added deprecation warning if parameters eval_env, eval_freq or create_eval_env are used (see #925) (@tobirohrer)

    Others:

    • Fixed type hint of the env_id parameter in make_vec_env and make_atari_env (@AlexPasqua)

    Documentation:

    • Extended docstring of the wrapper_class parameter in make_vec_env (@AlexPasqua)
    Source code(tar.gz)
    Source code(zip)
  • v1.6.1(Sep 29, 2022)

    SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

    Breaking Changes:

    • Switched minimum tensorboard version to 2.9.1

    New Features:

    • Support logging hyperparameters to tensorboard (@timothe-chaumont)
    • Added checkpoints for replay buffer and VecNormalize statistics (@anand-bala)
    • Added option for Monitor to append to existing file instead of overriding (@sidney-tio)
    • The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys

    SB3-Contrib

    • Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with RecurrentPPO (@mlodel)

    Bug Fixes:

    • Fixed issue where PPO gives NaN if rollout buffer provides a batch of size 1 (@hughperkins)
    • Fixed the issue that predict does not always return action as np.ndarray (@qgallouedec)
    • Fixed division by zero error when computing FPS when a small amount of time has elapsed, on operating systems with low-precision timers.
    • Added multidimensional action space support (@qgallouedec)
    • Fixed missing verbose parameter passing in the EvalCallback constructor (@burakdmb)
    • Fixed the issue that when updating the target network in DQN, SAC, TD3, the running_mean and running_var properties of batch norm layers are not updated (@honglu2875)
    • Fixed incorrect type annotation of the replay_buffer_class argument in common.OffPolicyAlgorithm initializer, where an instance instead of a class was required (@Rocamonde)
    • Fixed loading saved model with different number of environments
    • Removed forward() abstract method declaration from common.policies.BaseModel (already defined in torch.nn.Module) to fix type errors in subclasses (@Rocamonde)
    • Fixed the return type of .load() and .learn() methods in BaseAlgorithm so that they now use TypeVar (@Rocamonde)
    • Fixed an issue where keys with different tags but the same key raised an error in common.logger.HumanOutputFormat (@Rocamonde and @AdamGleave)

    Others:

    • Fixed DictReplayBuffer.next_observations typing (@qgallouedec)
    • Added support for device="auto" in buffers and made it default (@qgallouedec)
    • Updated ResultsWriter (used internally by the Monitor wrapper) to automatically create missing directories when filename is a path (@dominicgkerr)

    Documentation:

    • Added an example of a callback that logs hyperparameters to tensorboard. (@timothe-chaumont)
    • Fixed typo in docstring "nature" -> "Nature" (@Melanol)
    • Added info on split tensorboard logs into (@Melanol)
    • Fixed typo in ppo doc (@francescoluciano)
    • Fixed typo in install doc (@jlp-ue)
    • Clarified and standardized verbosity documentation
    • Added link to a GitHub issue in the custom policy documentation (@AlexPasqua)
    • Fixed typos (@Akhilez)
    Source code(tar.gz)
    Source code(zip)
  • v1.6.0(Jul 12, 2022)

    SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

    Breaking Changes:

    • Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former register_policy helper, policy_base parameter and using policy_aliases static attributes instead (@Gregwar)
    • SB3 now requires PyTorch >= 1.11
    • Changed the default network architecture when using CnnPolicy or MultiInputPolicy with SAC or DDPG/TD3: share_features_extractor is now set to False by default and net_arch=[256, 256] (instead of net_arch=[] as before)

    SB3-Contrib

    • Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53

    Bug Fixes:

    • Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
    • Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
    • Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
    • Fixed a bug in DummyVecEnv's and SubprocVecEnv's seeding function. None value was unchecked (@ScheiklP)
    • Fixed a bug where EvalCallback would crash when trying to synchronize VecNormalize stats when observation normalization was disabled
    • Added a check for unbounded actions
    • Fixed issues due to newer version of protobuf (tensorboard) and sphinx
    • Fix exception causes all over the codebase (@cool-RR)
    • Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede)
    • Fixed a bug in kl_divergence check that would fail when using numpy arrays with MultiCategorical distribution

    Others:

    • Upgraded to Python 3.7+ syntax using pyupgrade
    • Removed redundant double-check for nested observations from BaseAlgorithm._wrap_env (@TibiGG)

    Documentation:

    • Added link to gym doc and gym env checker
    • Fix typo in PPO doc (@bcollazo)
    • Added link to PPO ICLR blog post
    • Added remark about breaking Markov assumption and timeout handling
    • Added doc about MLFlow integration via custom logger (@git-thor)
    • Updated Huggingface integration doc
    • Added copy button for code snippets
    • Added doc about EnvPool and Isaac Gym support
    Source code(tar.gz)
    Source code(zip)
  • v1.5.0(Mar 25, 2022)

    SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

    Breaking Changes:

    • Switched minimum Gym version to 0.21.0.

    New Features:

    • Added StopTrainingOnNoModelImprovement to callback collection (@caburu)
    • Makes the length of keys and values in HumanOutputFormat configurable, depending on desired maximum width of output.
    • Allow PPO to turn off advantage normalization (see PR #763) @vwxyzjn

    SB3-Contrib

    • coming soon: Cross Entropy Method, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/62

    Bug Fixes:

    • Fixed a bug in VecMonitor. The monitor did not consider the info_keywords during stepping (@ScheiklP)
    • Fixed a bug in HumanOutputFormat. Distinct keys truncated to the same prefix would overwrite each others value, resulting in only one being output. This now raises an error (this should only affect a small fraction of use cases with very long keys.)
    • Routing all the nn.Module calls through implicit rather than explicit forward, as per PyTorch guidelines (@manuel-delverme)
    • Fixed a bug in VecNormalize where error occurs when norm_obs is set to False for environment with dictionary observation (@buoyancy99)
    • Set default env argument to None in HerReplayBuffer.sample (@qgallouedec)
    • Fix batch_size typing in DQN (@qgallouedec)
    • Fixed sample normalization in DictReplayBuffer (@qgallouedec)

    Others:

    • Fixed pytest warnings
    • Removed parameter remove_time_limit_termination in off policy algorithms since it was dead code (@Gregwar)

    Documentation:

    • Added doc on Hugging Face integration (@simoninithomas)
    • Added furuta pendulum project to project list (@armandpl)
    • Fix indentation 2 spaces to 4 spaces in custom env documentation example (@Gautam-J)
    • Update MlpExtractor docstring (@gianlucadecola)
    • Added explanation of the logger output
    • Update Directly Accessing The Summary Writer in tensorboard integration (@xy9485)

    Full Changelog: https://github.com/DLR-RM/stable-baselines3/compare/v1.4.0...v1.5.0

    Source code(tar.gz)
    Source code(zip)
  • v1.4.0(Jan 19, 2022)

    SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

    Breaking Changes:

    • Dropped python 3.6 support (as announced in previous release)
    • Renamed mask argument of the predict() method to episode_start (used with RNN policies only)
    • local variables action, done and reward were renamed to their plural form for off-policy algorithms (actions, dones, rewards); this may affect custom callbacks.
    • Removed episode_reward field from RolloutReturn() type

    Warning:

    An update to the HER algorithm is planned to support multi-env training and remove the max episode length constraint. (see PR #704) This will be a backward incompatible change (a model trained with a previous version of HER won't work with the new version).

    New Features:

    • Added norm_obs_keys param for VecNormalize wrapper to configure which observation keys to normalize (@kachayev)
    • Added experimental support to train off-policy algorithms with multiple envs (note: HerReplayBuffer currently not supported)
    • Handle timeout termination properly for on-policy algorithms (when using TimeLimit)
    • Added skip option to VecTransposeImage to skip transforming the channel order when the heuristic is wrong
    • Added copy() and combine() methods to RunningMeanStd

    SB3-Contrib

    • Added Trust Region Policy Optimization (TRPO) (@cyprienc)
    • Added Augmented Random Search (ARS) (@sgillen)
    • Coming soon: PPO LSTM, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53

    Bug Fixes:

    • Fixed a bug where set_env() with VecNormalize would result in an error with off-policy algorithms (thanks @cleversonahum)
    • FPS calculation is now performed based on number of steps performed during last learn call, even when reset_num_timesteps is set to False (@kachayev)
    • Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
    • Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
    • The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
    • Fixed a bug in VecFrameStack with channel first image envs, where the terminal observation would be wrongly created.

    Others:

    • Added a warning in the env checker when not using np.float32 for continuous actions
    • Improved test coverage and error message when checking shape of observation
    • Added newline="\n" when opening CSV monitor files so that each line ends with \r\n instead of \r\r\n on Windows while Linux environments are not affected (@hsuehch)
    • Fixed device argument inconsistency (@qgallouedec)

    Documentation:

    • Add drivergym to projects page (@theDebugger811)
    • Add highway-env to projects page (@eleurent)
    • Add tactile-gym to projects page (@ac-93)
    • Fix indentation in the RL tips page (@cove9988)
    • Update GAE computation docstring
    • Add documentation on exporting to TFLite/Coral
    • Added JMLR paper and updated citation
    • Added link to RL Tips and Tricks video
    • Updated BaseAlgorithm.load docstring (@Demetrio92)
    • Added a note on load behavior in the examples (@Demetrio92)
    • Updated SB3 Contrib doc
    • Fixed A2C and migration guide guidance on how to set epsilon with RMSpropTFLike (@thomasgubler)
    • Fixed custom policy documentation (@IperGiove)
    • Added doc on Weights & Biases integration
    Source code(tar.gz)
    Source code(zip)
  • v1.3.0(Oct 23, 2021)

    WARNING: This version will be the last one supporting Python 3.6 (end of life in Dec 2021). We highly recommend that you upgrade to Python >= 3.7.

    SB3-Contrib changelog: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/releases/tag/v1.3.0

    Breaking Changes:

    • sde_net_arch argument in policies is deprecated and will be removed in a future version.

    • _get_latent (ActorCriticPolicy) was removed

    • All logging keys now use underscores instead of spaces (@timokau). Concretely this changes:

      • time/total timesteps to time/total_timesteps for off-policy algorithms (PPO and A2C) and the eval callback (on-policy algorithms already used the underscored version),
      • rollout/exploration rate to rollout/exploration_rate and
      • rollout/success rate to rollout/success_rate.

    New Features:

    • Added methods get_distribution and predict_values for ActorCriticPolicy for A2C/PPO/TRPO (@cyprienc)
    • Added methods forward_actor and forward_critic for MlpExtractor
    • Added sb3.get_system_info() helper function to gather version information relevant to SB3 (e.g., Python and PyTorch version)
    • Saved models now store system information where agent was trained, and load functions have print_system_info parameter to help debugging load issues.

    Bug Fixes:

    • Fixed dtype of observations for SimpleMultiObsEnv
    • Allow VecNormalize to wrap discrete-observation environments to normalize reward when observation normalization is disabled.
    • Fixed a bug where DQN would throw an error when using Discrete observation and stochastic actions
    • Fixed a bug where sub-classed observation spaces could not be used
    • Added force_reset argument to load() and set_env() in order to be able to call learn(reset_num_timesteps=False) with a new environment

    Others:

    • Cap gym max version to 0.19 to avoid issues with atari-py and other breaking changes
    • Improved error message when using dict observation with the wrong policy
    • Improved error message when using EvalCallback with two envs not wrapped the same way.
    • Added additional infos about supported python version for PyPi in setup.py

    Documentation:

    • Add Rocket League Gym to list of supported projects (@AechPro)
    • Added gym-electric-motor to project page (@wkirgsn)
    • Added policy-distillation-baselines to project page (@CUN-bjy)
    • Added ONNX export instructions (@batu)
    • Update read the doc env (fixed docutils issue)
    • Fix PPO environment name (@IljaAvadiev)
    • Fix custom env doc and add env registration example
    • Update algorithms from SB3 Contrib
    • Use underscores for numeric literals in examples to improve clarity
    Source code(tar.gz)
    Source code(zip)
  • v1.2.0(Sep 8, 2021)

    Breaking Changes:

    • SB3 now requires PyTorch >= 1.8.1
    • VecNormalize ret attribute was renamed to returns

    Bug Fixes:

    • Hotfix for VecNormalize where the observation filter was not updated at reset (thanks @vwxyzjn)
    • Fixed model predictions when using batch normalization and dropout layers by calling train() and eval() (@davidblom603)
    • Fixed model training for DQN, TD3 and SAC so that their target nets always remain in evaluation mode (@ayeright)
    • Passing gradient_steps=0 to an off-policy algorithm will result in no gradient steps being taken (vs as many gradient steps as steps done in the environment during the rollout in previous versions)

    Others:

    • Enabled Python 3.9 in GitHub CI
    • Fixed type annotations
    • Refactored predict() by moving the preprocessing to obs_to_tensor() method

    Documentation:

    • Updated multiprocessing example
    • Added example of VecEnvWrapper
    • Added a note about logging to tensorboard more often
    • Added warning about simplicity of examples and link to RL zoo (@MihaiAnca13)
    Source code(tar.gz)
    Source code(zip)
  • v1.1.0(Jul 2, 2021)

    Breaking Changes

    • All customs environments (e.g. the BitFlippingEnv or IdentityEnv) were moved to stable_baselines3.common.envs folder
    • Refactored HER which is now the HerReplayBuffer class that can be passed to any off-policy algorithm
    • Handle timeout termination properly for off-policy algorithms (when using TimeLimit)
    • Renamed _last_dones and dones to _last_episode_starts and episode_starts in RolloutBuffer.
    • Removed ObsDictWrapper as Dict observation spaces are now supported
      # env is assumed to be a goal-conditioned environment with a Dict observation space
      from stable_baselines3 import SAC, HerReplayBuffer

      her_kwargs = dict(n_sampled_goal=2, goal_selection_strategy="future", online_sampling=True)
      # SB3 < 1.1.0
      # model = HER("MlpPolicy", env, model_class=SAC, **her_kwargs)
      # SB3 >= 1.1.0:
      model = SAC("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=her_kwargs)
    
    • Updated the KL Divergence estimator in the PPO algorithm to be positive definite and have lower variance (@09tangriro)
    • Updated the KL Divergence check in the PPO algorithm to be before the gradient update step rather than after end of epoch (@09tangriro)
    • Removed parameter channels_last from is_image_space as it can be inferred.
    • The logger object is now an attribute model.logger that can be set by the user using model.set_logger()
    • Changed the signature of logger.configure and utils.configure_logger, which now return a Logger object (see the sketch after this list)
    • Removed Logger.CURRENT and Logger.DEFAULT
    • Moved warn(), debug(), log(), info(), dump() methods to the Logger class
    • .learn() now throws an import error when the user tries to log to tensorboard but the package is not installed
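
    A minimal sketch of the new logger API under these changes; the log folder and output formats are arbitrary examples:

      from stable_baselines3 import PPO
      from stable_baselines3.common.logger import configure

      # configure() now returns a Logger object instead of setting a global one
      new_logger = configure("/tmp/sb3_log/", ["stdout", "csv"])

      model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
      # Attach the logger to the model; it is then available as model.logger
      model.set_logger(new_logger)
      model.learn(total_timesteps=5_000)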

    New Features

    • Added support for single-level Dict observation space (@JadenTravnik)
    • Added DictRolloutBuffer and DictReplayBuffer to support dictionary observations (@JadenTravnik)
    • Added StackedObservations and StackedDictObservations that are used within VecFrameStack
    • Added simple 4x4 room Dict test environments
    • HerReplayBuffer now supports VecNormalize when online_sampling=False
    • Added VecMonitor and VecExtractDictObs wrappers to handle gym3-style vectorized environments (@vwxyzjn)
    • Ignored the terminal observation if it is not provided by the environment, as is the case for gym3-style vectorized environments. (@vwxyzjn)
    • Added policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro)
    • Added support for image observation when using HER
    • Added replay_buffer_class and replay_buffer_kwargs arguments to off-policy algorithms
    • Added kl_divergence helper for Distribution classes (@09tangriro)
    • Added support for vector environments with num_envs > 1 (@benblack769)
    • Added wrapper_kwargs argument to make_vec_env (@amy12xx); see the example after this list
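
    A sketch of wrapper_kwargs being forwarded to the wrapper class when each environment is created; the TimeLimit wrapper and its argument are only an illustration:

      from gym.wrappers import TimeLimit
      from stable_baselines3.common.env_util import make_vec_env

      # wrapper_kwargs are passed to wrapper_class for every created env
      vec_env = make_vec_env(
          "CartPole-v1",
          n_envs=4,
          wrapper_class=TimeLimit,
          wrapper_kwargs=dict(max_episode_steps=200),
      )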

    Bug Fixes

    • Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
    • Fixed loading of ent_coef for SAC and TQC, which was no longer being optimized after loading (thanks @Atlis)
    • Fixed saving of A2C and PPO policy when using gSDE (thanks @liusida)
    • Fixed a bug where no output would be shown even if verbose>=1 after passing verbose=0 once
    • Fixed observation buffers dtype in DictReplayBuffer (@c-rizz)
    • Fixed EvalCallback tensorboard logs being logged with the incorrect timestep. They are now written with the timestep at which they were recorded. (@skandermoalla)

    Others

    • Added flake8-bugbear to tests dependencies to find likely bugs
    • Updated env_checker to reflect support of dict observation spaces
    • Added Code of Conduct
    • Added tests for GAE and lambda return computation
    • Updated distribution entropy test (thanks @09tangriro)
    • Added sanity check batch_size > 1 in PPO to avoid NaN in advantage normalization

    Documentation:

    • Added gym pybullet drones project (@JacopoPan)
    • Added link to SuperSuit in projects (@justinkterry)
    • Fixed DQN example (thanks @ltbd78)
    • Clarified channel-first/channel-last recommendation
    • Update sphinx environment installation instructions (@tom-doerr)
    • Clarified pip installation in Zsh (@tom-doerr)
    • Clarified return computation for on-policy algorithms (TD(lambda) estimate was used)
    • Added example for using ProcgenEnv
    • Added note about advanced custom policy example for off-policy algorithms
    • Fixed DQN unicode checkmarks
    • Updated migration guide (@juancroldan)
    • Pinned docutils==0.16 to avoid issue with rtd theme
    • Clarified callback save_freq definition
    • Added doc on how to pass a custom logger
    • Remove recurrent policies from A2C docs (@bstee615)
  • v1.0(Mar 17, 2021)

    First Major Version

    Blog post: https://araffin.github.io/post/sb3/

    100+ pre-trained models in the zoo: https://github.com/DLR-RM/rl-baselines3-zoo

    Breaking Changes:

    • Removed stable_baselines3.common.cmd_util (already deprecated), please use env_util instead

    Warning

    A refactoring of the HER algorithm is planned, together with support for dictionary observations (see PR #243 and #351). This will be a backward-incompatible change (models trained with a previous version of HER won't work with the new version).

    New Features:

    • Added support for custom_objects when loading models

    Bug Fixes:

    • Fixed a bug with DQN predict method when using deterministic=False with image space

    Documentation:

    • Fixed examples
    • Added new project using SB3: rl_reach (@PierreExeter)
    • Added note about slow-down when switching to PyTorch
    • Add a note on continual learning and resetting environment
    • Updated RL-Zoo to reflect the fact that it is more than a collection of trained agents
    • Added images to illustrate the training loop and custom policies (created with https://excalidraw.com/)
    • Updated the custom policy section
  • v1.0rc1(Mar 6, 2021)

  • v0.11.1(Feb 27, 2021)

    Breaking Changes:

    • evaluate_policy now returns rewards/episode lengths from a Monitor wrapper if one is present; this makes it possible to return the unnormalized reward, for instance in the case of Atari games.
    • Renamed common.vec_env.is_wrapped to common.vec_env.is_vecenv_wrapped to avoid confusion with the new is_wrapped() helper
    • Renamed _get_data() to _get_constructor_parameters() for policies (this affects independent saving/loading of policies)
    • Removed n_episodes_rollout and merged it with train_freq, which now accepts a tuple (frequency, unit):
    • replay_buffer in collect_rollout is no longer optional
    
      from stable_baselines3 import SAC  # env below is any gym environment

      # SB3 < 0.11.0
      # model = SAC("MlpPolicy", env, n_episodes_rollout=1, train_freq=-1)
      # SB3 >= 0.11.0:
      model = SAC("MlpPolicy", env, train_freq=(1, "episode"))
    

    New Features:

    • Add support for VecFrameStack to stack on first or last observation dimension, along with automatic check for image spaces.
    • VecFrameStack now has a channels_order argument to specify whether observations should be stacked on the first or the last observation dimension (previously they were always stacked on the last one); see the sketch after this list.
    • Added common.env_util.is_wrapped and common.env_util.unwrap_wrapper functions for checking/unwrapping an environment for specific wrapper.
    • Added env_is_wrapped() method for VecEnv to check if its environments are wrapped with given Gym wrappers.
    • Added monitor_kwargs parameter to make_vec_env and make_atari_env
    • Wrap the environments automatically with a Monitor wrapper when possible.
    • EvalCallback now logs the success rate when available (is_success must be present in the info dict)
    • Added new wrappers to log images and matplotlib figures to tensorboard. (@zampanteymedio)
    • Add support for text records to Logger. (@lorenz-h)
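
    A sketch of the channels_order option and the new wrapper check; it assumes the Atari extra (atari-py) is installed and uses an arbitrary environment id:

      import gym
      from stable_baselines3.common.env_util import make_atari_env
      from stable_baselines3.common.vec_env import VecFrameStack

      venv = make_atari_env("BreakoutNoFrameskip-v4", n_envs=4)
      # Stack the last 4 frames on the last (channel) dimension
      venv = VecFrameStack(venv, n_stack=4, channels_order="last")

      # Check whether the underlying envs are wrapped with a given gym wrapper
      print(venv.env_is_wrapped(gym.wrappers.TimeLimit))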

    Bug Fixes:

    • Fixed a bug where VecTranspose was wrongly added to channel-first image environments (thanks @qxcv)
    • Fixed DQN predict method when using single gym.Env with deterministic=False
    • Fixed a bug where the argument order of explained_variance() in ppo.py and a2c.py was incorrect (@thisray)
    • Fixed bug where full HerReplayBuffer leads to an index error. (@megan-klaiber)
    • Fixed bug where replay buffer could not be saved if it was too big (> 4 Gb) for python<3.8 (thanks @hn2)
    • Added informative PPO construction error in edge-case scenario where n_steps * n_envs = 1 (size of rollout buffer), which otherwise causes downstream breaking errors in training (@decodyng)
    • Fixed discrete observation space support when using multiple envs with A2C/PPO (thanks @ardabbour)
    • Fixed a bug for TD3 delayed update (the update was off-by-one and not delayed when train_freq=1)
    • Fixed numpy warning (replaced np.bool with bool)
    • Fixed a bug where VecNormalize was not normalizing the terminal observation
    • Fixed a bug where VecTranspose was not transposing the terminal observation
    • Fixed a bug where the terminal observation stored in the replay buffer was not the right one for off-policy algorithms
    • Fixed a bug where action_noise was not used when using HER (thanks @ShangqunYu)
    • Fixed a bug where train_freq was not properly converted when loading a saved model

    Others:

    • Add more issue templates
    • Add signatures to callable type annotations (@ernestum)
    • Improve error message in NatureCNN
    • Added checks for supported action spaces to improve clarity of error messages for the user
    • Renamed variables in the train() method of SAC, TD3 and DQN to match SB3-Contrib.
    • Updated docker base image to Ubuntu 18.04
    • Set the tensorboard minimum version to 2.2.0 (earlier versions apparently do not work with PyTorch)
    • Added warning for PPO when n_steps * n_envs is not a multiple of batch_size (last mini-batch truncated) (@decodyng)
    • Removed some warnings in the tests

    Documentation:

    • Updated algorithm table
    • Minor docstring improvements regarding rollout (@stheid)
    • Fix migration doc for A2C (epsilon parameter)
    • Fix clip_range docstring
    • Fix duplicated parameter in EvalCallback docstring (thanks @tfederico)
    • Added example of learning rate schedule
    • Added SUMO-RL as example project (@LucasAlegre)
    • Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre)
    • Added SB3-Contrib page
    • Fix bug in the example code of DQN (@AptX395)
    • Add example on how to access the tensorboard summary writer directly. (@lorenz-h)
    • Updated migration guide
    • Updated custom policy doc (separate policy architecture recommended)
    • Added a note about OpenCV headless version
    • Corrected typo on documentation (@mschweizer)
    • Provide the environment when loading the model in the examples (@lorepieri8)
  • v0.10.0(Oct 28, 2020)

    Breaking Changes

    • Warning: Renamed common.cmd_util to common.env_util for clarity (affects make_vec_env and make_atari_env functions)

    New Features

    • Allow custom actor/critic network architectures using net_arch=dict(qf=[400, 300], pi=[64, 64]) for off-policy algorithms (SAC, TD3, DDPG); see the sketch after this list
    • Added Hindsight Experience Replay HER. (@megan-klaiber)
    • VecNormalize now supports gym.spaces.Dict observation spaces
    • Support logging videos to Tensorboard (@SwamyDev)
    • Added share_features_extractor argument to SAC and TD3 policies
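
    A sketch of the new net_arch format for off-policy algorithms; the environment id and layer sizes are only examples:

      from stable_baselines3 import SAC

      # Separate network architectures for the critic (qf) and the actor (pi)
      policy_kwargs = dict(net_arch=dict(qf=[400, 300], pi=[64, 64]))
      model = SAC("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs)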

    Bug Fixes

    • Fix GAE computation for on-policy algorithms (off-by-one for the last value) (thanks @Wovchena)
    • Fixed potential issue when loading a different environment
    • Fixed the exclude parameter being ignored when recording logs using json, csv or log as the logging format (@SwamyDev)
    • Make make_vec_env support the env_kwargs argument when using an env ID str (@ManifoldFR)
    • Fix model creation initializing CUDA even when device="cpu" is provided
    • Fix check_env not checking if the env has a Dict action space before calling _check_nan (@wmmc88)
    • Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88)
    • Fixed feature extractor bug for target network where the same net was shared instead of being separate. This bug affects SAC, DDPG and TD3 when using CnnPolicy (or custom feature extractor)
    • Fixed a bug where, when passing an environment while loading a saved model with a CnnPolicy, the passed env was not wrapped properly (the bug was introduced when implementing HER, so it should not be present in previous versions)

    Others

    • Improved typing coverage
    • Improved error messages for unsupported spaces
    • Added .vscode to the gitignore

    Documentation

    • Added first draft of migration guide
    • Added intro to imitation library (@shwang)
    • Enabled doc for CnnPolicies
    • Added advanced saving and loading example
    • Added base doc for exporting models
    • Added example for getting and setting model parameters
  • v0.9.0(Oct 4, 2020)

    Breaking Changes:

    • Removed device keyword argument of policies; use policy.to(device) instead. (@qxcv)
    • Rename BaseClass.get_torch_variables -> BaseClass._get_torch_save_params and BaseClass.excluded_save_params -> BaseClass._excluded_save_params
    • Renamed saved items tensors to pytorch_variables for clarity
    • make_atari_env, make_vec_env and set_random_seed must be imported from their submodules (and not directly from stable_baselines3.common):
    from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
    from stable_baselines3.common.utils import set_random_seed
    

    New Features:

    • Added unwrap_vec_wrapper() to common.vec_env to extract VecEnvWrapper if needed
    • Added StopTrainingOnMaxEpisodes to callback collection (@xicocaio)
    • Added device keyword argument to BaseAlgorithm.load() (@liorcohen5)
    • Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped)
    • Added get_parameters and set_parameters for accessing/setting parameters of the agent (see the sketch after this list)
    • Added actor/critic loss logging for TD3. (@mloo3)
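
    A sketch of the new parameter accessors and the device argument of load(); the environment id and model path are placeholders:

      from stable_baselines3 import TD3

      model = TD3("MlpPolicy", "Pendulum-v0")
      # Retrieve a dict of state dicts (policy and optimizers)
      params = model.get_parameters()
      # ... inspect or modify params, then write them back
      model.set_parameters(params)

      # load() now accepts a device keyword argument
      # model = TD3.load("td3_pendulum.zip", device="cpu")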

    Bug Fixes:

    • Fixed a bug where the environment was reset twice when using evaluate_policy
    • Fix logging of clip_fraction in PPO (@diditforlulz273)
    • Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., device="cuda:0" (@liorcohen5)
    • Fixed a bug when the random seed was not properly set on cuda when passing the GPU index

    Others:

    • Improve typing coverage of the VecEnv
    • Fix type annotation of make_vec_env (@ManifoldFR)
    • Removed AlreadySteppingError and NotSteppingError that were not used
    • Fixed typos in SAC and TD3
    • Reorganized functions for clarity in BaseClass (save/load functions close to each other, private functions at top)
    • Clarified docstrings on what is saved and loaded to/from files
    • Simplified save_to_zip_file function by removing duplicate code
    • Store library version along with the saved models
    • DQN loss is now logged

    Documentation:

    • Added StopTrainingOnMaxEpisodes details and example (@xicocaio)
    • Updated custom policy section (added custom feature extractor example)
    • Re-enable sphinx_autodoc_typehints
    • Updated doc style for type hints and remove duplicated type hints
  • v0.8.0(Aug 3, 2020)

    Breaking Changes:

    • AtariWrapper and other Atari wrappers were updated to match SB2 ones
    • save_replay_buffer now receives as argument the file path instead of the folder path (@tirafesi)
    • Refactored Critic class for TD3 and SAC, it is now called ContinuousCritic and has an additional parameter n_critics
    • SAC and TD3 now accept an arbitrary number of critics (e.g. policy_kwargs=dict(n_critics=3)) instead of only two previously; see the sketch after this list
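
    A sketch of configuring an arbitrary number of critics; the environment id and the number of critics are only examples:

      from stable_baselines3 import TD3

      # Use three critics instead of the default two
      model = TD3("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(n_critics=3))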

    New Features:

    • Added DQN Algorithm (@Artemis-Skade)
    • Buffer dtype is now set according to action and observation spaces for ReplayBuffer
    • Added warning when allocation of a buffer may exceed the available memory of the system when psutil is available
    • Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped)
    • Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped)
    • Added DDPG algorithm as a special case of TD3.
    • Introduced BaseModel abstract parent for BasePolicy, which critics inherit from.

    Bug Fixes:

    • Fixed a bug in the close() method of SubprocVecEnv, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended)
    • Fix target for updating q values in SAC: the entropy term was not conditioned by terminal states
    • Use cloudpickle.load instead of pickle.load in CloudpickleWrapper. (@shwang)
    • Fixed a bug with orthogonal initialization when bias=False in custom policy (@rk37)
    • Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
    • Fixed DQN target network sharing feature extractor with the main network.
    • Fixed storing correct dones in on-policy algorithm rollout collection. (@andyshih12)
    • Fixed number of filters in final convolutional layer in NatureCNN to match original implementation.

    Others:

    • Refactored off-policy algorithm to share the same .learn() method
    • Split the collect_rollout() method for off-policy algorithms
    • Added _on_step() for off-policy base class
    • Optimized replay buffer size by removing the need of next_observations numpy array
    • Optimized polyak updates (1.5-1.95 speedup) through inplace operations (@PartiallyTyped)
    • Switch to black codestyle and added make format, make check-codestyle and commit-checks
    • Ignored errors from newer pytype version
    • Added a check when using gSDE
    • Removed codacy dependency from Dockerfile
    • Added common.sb2_compat.RMSpropTFLike optimizer, which corresponds more closely to the TensorFlow implementation of RMSprop (see the sketch below).
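
    A sketch of switching A2C to the SB2-style optimizer; the epsilon value and environment id are only examples:

      from stable_baselines3 import A2C
      from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike

      # Use the TF-style RMSprop to match SB2 behaviour more closely
      policy_kwargs = dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5))
      model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs)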

    Documentation:

    • Updated notebook links
    • Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake)
    • Added Unity reacher to the projects page (@koulakis)
    • Added PyBullet colab notebook
    • Fixed typo in PPO example code (@joeljosephjin)
    • Fixed typo in custom policy doc (@RaphaelWag)
  • v0.7.0(Jun 10, 2020)

    Breaking Changes:

    • render() method of VecEnvs now accepts only one argument: mode

    • Created new file common/torch_layers.py, similar to SB refactoring

      • Contains all PyTorch network layer definitions and feature extractors: MlpExtractor, create_mlp, NatureCNN
    • Renamed BaseRLModel to BaseAlgorithm (along with offpolicy and onpolicy variants)

    • Moved on-policy and off-policy base algorithms to common/on_policy_algorithm.py and common/off_policy_algorithm.py, respectively.

    • Moved PPOPolicy to ActorCriticPolicy in common/policies.py

    • Moved PPO (algorithm class) into OnPolicyAlgorithm (common/on_policy_algorithm.py), to be shared with A2C

    • Moved the following functions from BaseAlgorithm:

      • _load_from_file to load_from_zip_file (save_util.py)
      • _save_to_file_zip to save_to_zip_file (save_util.py)
      • safe_mean to safe_mean (utils.py)
      • check_env to check_for_correct_spaces (utils.py; renamed to avoid confusion with environment checker tools)
    • Moved static function _is_vectorized_observation from common/policies.py to common/utils.py under name is_vectorized_observation.

    • Removed {save,load}_running_average functions of VecNormalize in favor of load/save.

    • Removed use_gae parameter from RolloutBuffer.compute_returns_and_advantage.

    Bug Fixes:

    • Fixed render() method for VecEnvs
    • Fixed seed() method for SubprocVecEnv
    • Fixed loading on GPU for testing when using gSDE and deterministic=False
    • Fixed register_policy to allow re-registering same policy for same sub-class (i.e. assign same value to same key).
    • Fixed a bug where the gradient was passed when using gSDE with PPO/A2C, this does not affect SAC

    Others:

    • Re-enable unsafe fork start method in the tests (was causing a deadlock with tensorflow)
    • Added a test for seeding SubprocVecEnv and rendering
    • Fixed reference in NatureCNN (pointed to older version with different network architecture)
    • Fixed comments saying "CxWxH" instead of "CxHxW" (same style as in torch docs / commonly used)
    • Added further comments on registering/getting policies ("MlpPolicy", "CnnPolicy").
    • Renamed progress (value decreasing from 1 at the start of training to 0 at the end) to progress_remaining.
    • Added policies.py files for A2C/PPO, which define MlpPolicy/CnnPolicy (renamed ActorCriticPolicies).
    • Added some missing tests for VecNormalize, VecCheckNan and PPO.

    Documentation:

    • Added a paragraph on "MlpPolicy"/"CnnPolicy" and policy naming scheme under "Developer Guide"
    • Fixed second-level listing in changelog
  • v0.6.0(Jun 1, 2020)

    Breaking Changes:

    • Removed State-Dependent Exploration (SDE) support for TD3
    • Methods were renamed in the logger (see the sketch after this list):
      • logkv -> record, writekvs -> write, writeseq -> write_sequence,
      • logkvs -> record_dict, dumpkvs -> dump,
      • getkvs -> get_log_dict, logkv_mean -> record_mean
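
    A sketch of the renamed calls, using the module-level logger as it existed at that point in the API (the key and values are placeholders):

      from stable_baselines3.common import logger

      # Before: logger.logkv("train/value", 1); logger.dumpkvs()
      logger.record("train/value", 1)
      logger.record_mean("train/value_mean", 1)
      logger.dump(step=0)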

    New Features:

    • Added env checker (Sync with Stable Baselines)
    • Added VecCheckNan and VecVideoRecorder (Sync with Stable Baselines)
    • Added determinism tests
    • Added cmd_util and atari_wrappers
    • Added support for MultiDiscrete and MultiBinary observation spaces (@rolandgvc)
    • Added MultiCategorical and Bernoulli distributions for PPO/A2C (@rolandgvc)
    • Added support for logging to tensorboard (@rolandgvc)
    • Added VectorizedActionNoise for continuous vectorized environments (@PartiallyTyped); see the sketch after this list
    • Log evaluation in the EvalCallback using the logger
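
    A sketch of VectorizedActionNoise wrapping a base noise process for several environments; the dimensions are arbitrary:

      import numpy as np
      from stable_baselines3.common.noise import NormalActionNoise, VectorizedActionNoise

      n_actions, n_envs = 2, 4
      base_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
      # One independent copy of the base noise per environment
      vec_noise = VectorizedActionNoise(base_noise, n_envs=n_envs)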

    Bug Fixes:

    • Fixed a bug that prevented model trained on cpu to be loaded on gpu
    • Fixed version number that had a new line included
    • Fixed weird seg fault in docker image due to FakeImageEnv by reducing screen size
    • Fixed sde_sample_freq that was not taken into account for SAC
    • Pass the logger module to BaseCallback; otherwise callbacks cannot write to the logger used by the algorithms

    Others:

    • Renamed to Stable-Baselines3
    • Added Dockerfile
    • Sync VecEnvs with Stable-Baselines
    • Update requirement: gym>=0.17
    • Added .readthedoc.yml file
    • Added flake8 and make lint command
    • Added Github workflow
    • Added warning when passing both train_freq and n_episodes_rollout to Off-Policy Algorithms

    Documentation:

    • Added most documentation (adapted from Stable-Baselines)
    • Added link to CONTRIBUTING.md in the README (@kinalmehta)
    • Added gSDE project and update docstrings accordingly
    • Fix TD3 example code block