Deep reinforcement learning library built on top of Neural Network Libraries

Overview

NNablaRL is a deep reinforcement learning library built on top of Neural Network Libraries that is intended to be used for research, development and production.

Installation

Installing NNablaRL is easy!

$ pip install nnabla-rl

NNablaRL only supports Python version >= 3.6 and NNabla version >= 1.17.

Enabling GPU acceleration (Optional)

NNablaRL algorithms run on CPU by default. To run the algorithm on GPU, first install nnabla-ext-cuda as follows. (Replace [cuda-version] depending on the CUDA version installed on your machine.)

$ pip install nnabla-ext-cuda[cuda-version]
# Example installation. Supposing CUDA 11.0 is installed on your machine.
$ pip install nnabla-ext-cuda110

After installing nnabla-ext-cuda, set the GPU id to run on through the algorithm's configuration.

import nnabla_rl.algorithms as A

config = A.DQNConfig(gpu_id=0) # Use gpu 0. If negative, will run on CPU.
dqn = A.DQN(env, config=config)
...
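
If the same script has to run on machines with and without the CUDA extension, one option is to choose the GPU id at runtime. The sketch below is only an illustration: `pick_gpu_id` is a hypothetical helper, not part of NNablaRL. It checks whether the `nnabla_ext.cuda` module can be located and otherwise falls back to -1, which the configuration above treats as CPU. It does not verify that a GPU is actually usable.

```python
import importlib.util


def pick_gpu_id(preferred: int = 0) -> int:
    """Return `preferred` if the nnabla CUDA extension can be found, else -1 (CPU).

    Hypothetical helper, not part of NNablaRL. It only checks that the
    extension module can be located, not that a GPU is usable.
    """
    try:
        spec = importlib.util.find_spec("nnabla_ext.cuda")
    except (ImportError, ValueError):
        # The parent package `nnabla_ext` is missing or malformed: use the CPU.
        spec = None
    return preferred if spec is not None else -1


gpu_id = pick_gpu_id()
print(f"gpu_id={gpu_id}")  # pass this value as DQNConfig(gpu_id=gpu_id)
```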

Features

Friendly API

NNablaRL has a friendly Python API that lets you start training with only 3 lines of Python code.

import nnabla_rl
import nnabla_rl.algorithms as A
from nnabla_rl.utils.reproductions import build_atari_env

env = build_atari_env("BreakoutNoFrameskip-v4") # 1
dqn = A.DQN(env)  # 2
dqn.train(env)  # 3

To get more details about NNablaRL, see documentation and examples.

Many builtin algorithms

Most of the well-known and state-of-the-art deep reinforcement learning algorithms, such as DQN, SAC, BCQ, and GAIL, are implemented in NNablaRL. The implementations are carefully tested and evaluated, so you can easily start training your agent with them.

For the list of implemented algorithms see here.

You can also find the reproduction and evaluation results of each algorithm here.
Note that you may not get exactly the same results when running the reproduction code on your computer. The results may change slightly depending on your machine, the nnabla and nnabla-rl package versions, etc.

Seamless switching of online and offline training

In reinforcement learning, there are two main ways to train an agent: online and offline. Online training alternates between collecting data and updating the network. Offline training, by contrast, updates the network using only existing data. With NNablaRL, you can switch between the two seamlessly. For example, as shown below, you can train a robot's controller online in a simulated environment and then fine-tune it offline with a real-robot dataset.

import nnabla_rl
import nnabla_rl.algorithms as A

simulator = get_simulator() # This is just an example. Assuming that simulator exists
dqn = A.DQN(simulator)
# train online for 1M iterations
dqn.train_online(simulator, total_iterations=1000000)

real_data = get_real_robot_data() # This is also an example. Assuming that you have real robot data
# fine tune the agent offline for 10k iterations using real data
dqn.train_offline(real_data, total_iterations=10000)
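
To make the distinction concrete, here is a minimal, self-contained sketch of the two loops using tabular Q-learning on a toy chain environment. Nothing in it is NNablaRL code: `step`, `update`, `train_online`, and `train_offline` are illustrative names chosen for this example, and the library's own methods of the same name may be organized differently.

```python
import random

N_STATES, N_ACTIONS, GOAL = 5, 2, 4


def step(state, action):
    """Toy chain: action 1 moves right, action 0 moves left; reward 1 at the goal."""
    next_state = min(max(state + (1 if action == 1 else -1), 0), GOAL)
    done = next_state == GOAL
    return next_state, float(done), done


def update(q, transition, lr=0.5, gamma=0.9):
    """One tabular Q-learning update from a single transition."""
    s, a, r, s_next, done = transition
    target = r + (0.0 if done else gamma * max(q[s_next]))
    q[s][a] += lr * (target - q[s][a])


def train_online(q, iterations, rng, epsilon=0.3):
    """Alternate data collection and updates; return the collected transitions."""
    data, state = [], 0
    for _ in range(iterations):
        if rng.random() < epsilon:
            action = rng.randrange(N_ACTIONS)
        else:
            # Greedy action, with ties broken at random
            action = max(range(N_ACTIONS), key=lambda a: (q[state][a], rng.random()))
        next_state, reward, done = step(state, action)
        transition = (state, action, reward, next_state, done)
        data.append(transition)  # collect
        update(q, transition)    # update
        state = 0 if done else next_state
    return data


def train_offline(q, dataset, iterations, rng):
    """Update from existing transitions only; the environment is never stepped."""
    for _ in range(iterations):
        update(q, rng.choice(dataset))


rng = random.Random(0)
q_online = [[0.0] * N_ACTIONS for _ in range(N_STATES)]
dataset = train_online(q_online, iterations=500, rng=rng)

q_offline = [[0.0] * N_ACTIONS for _ in range(N_STATES)]
train_offline(q_offline, dataset, iterations=2000, rng=rng)
print(q_offline[GOAL - 1])  # action values in the state next to the goal
```

The only difference between the two functions is where transitions come from: the online loop steps the environment itself, while the offline loop samples from a fixed dataset.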

Getting started

Try the interactive demos below to get started.
You can run them directly on Colab from the links in the table below.

Title | Notebook | Target RL task
Simple reinforcement learning training to get started | Open In Colab | Pendulum
Learn how to use training algorithms | Open In Colab | Pendulum
Learn how to use customized network model for training | Open In Colab | Mountain car
Learn how to use different network solver for training | Open In Colab | Pendulum
Learn how to use different replay buffer for training | Open In Colab | Pendulum
Learn how to use your own environment for training | Open In Colab | Customized environment
Atari game training example | Open In Colab | Atari games

Documentation

Full documentation is here.

Contribution guide

Any kind of contribution to NNablaRL is welcome! See the contribution guide for details.

License

NNablaRL is provided under the Apache License, Version 2.0.

Comments
  • Update cem function interface

    Updated the interface of the cross-entropy method (CEM) functions. The argument pop_size has been renamed to sample_size. In addition, the objective function passed to the CEM function is now called with a variable x of shape (batch_size, sample_size, x_dim), which differs from the previous interface. See the function docs for details.
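
To illustrate the shape convention described above, here is a toy cross-entropy method written in plain NumPy. It is a sketch under stated assumptions, not the NNablaRL function: `cem_maximize` and its parameters other than `sample_size` are hypothetical, and the objective is assumed to return one score per candidate, i.e. an array of shape (batch_size, sample_size).

```python
import numpy as np


def objective(x: np.ndarray) -> np.ndarray:
    """Score candidates; x has shape (batch_size, sample_size, x_dim).

    Assumption: one score per candidate is returned, shape (batch_size, sample_size).
    """
    target = np.array([0.5, -1.0])
    return -np.sum((x - target) ** 2, axis=-1)


def cem_maximize(objective, init_mean, init_std, sample_size=64,
                 num_elites=8, num_iterations=20, seed=0):
    """Toy cross-entropy method (hypothetical helper, not the library's)."""
    rng = np.random.default_rng(seed)
    mean, std = init_mean.copy(), init_std.copy()  # both (batch_size, x_dim)
    batch_size, x_dim = mean.shape
    for _ in range(num_iterations):
        noise = rng.standard_normal((batch_size, sample_size, x_dim))
        x = mean[:, None, :] + std[:, None, :] * noise
        scores = objective(x)  # (batch_size, sample_size)
        # Indices of the best `num_elites` candidates in each batch entry
        elite_idx = np.argsort(scores, axis=1)[:, -num_elites:]
        elites = np.take_along_axis(x, elite_idx[:, :, None], axis=1)
        # Refit the sampling distribution to the elites
        mean, std = elites.mean(axis=1), elites.std(axis=1) + 1e-6
    return mean


best = cem_maximize(objective, init_mean=np.zeros((3, 2)), init_std=np.ones((3, 2)))
print(best.shape)
```

Evaluating all sample_size candidates for every batch entry in one call lets the objective be vectorized instead of being invoked once per candidate.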

    opened by sbsekiguchi 1
  • Add implementation for RNN support and DRQN algorithm

    Add RNN model support and DRQN algorithm.

    The following trainers support RNN models:

    • Q value-based trainers
    • Deterministic gradient and soft policy trainers

    Other trainers may support RNN models in the future, but this is not implemented in the initial release.

    See this paper for the details of the DRQN algorithm.

    opened by ishihara-y 1
  • Implement SACD

    This PR implements the SAC-D algorithm: https://arxiv.org/abs/2206.13901

    These changes have been made:

    • New environments with factored reward functions have been added
      • FactoredLunarLanderContinuousV2NNablaRL-v1
      • FactoredAntV4NNablaRL-v1
      • FactoredHopperV4NNablaRL-v1
      • FactoredHalfCheetahV4NNablaRL-v1
      • FactoredWalker2dV4NNablaRL-v1
      • FactoredHumanoidV4NNablaRL-v1
    • The SACD algorithm has been added
    • SoftQDTrainer has been added
    • _InfluenceMetricsEvaluator has been added
    • A reproduction script has been added (not benchmarked yet)

    Visualizing influence metrics:

    import gym
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    import nnabla_rl.algorithms as A
    import nnabla_rl.hooks as H
    import nnabla_rl.writers as W
    from nnabla_rl.utils.evaluator import EpisodicEvaluator
    
    env = gym.make("FactoredLunarLanderContinuousV2NNablaRL-v1")
    eval_env = gym.make("FactoredLunarLanderContinuousV2NNablaRL-v1")
    
    evaluation_hook = H.EvaluationHook(
        eval_env,
        EpisodicEvaluator(run_per_evaluation=10),
        timing=5000,
        writer=W.FileWriter(outdir="logdir", file_prefix='evaluation_result'),
    )
    iteration_num_hook = H.IterationNumHook(timing=100)
    
    config = A.SACDConfig(gpu_id=0, reward_dimension=9)
    sacd = A.SACD(env, config=config)
    sacd.set_hooks([iteration_num_hook, evaluation_hook])
    sacd.train_online(env, total_iterations=100000)
    
    influence_history = []
    
    state = env.reset()
    while True:
        action = sacd.compute_eval_action(state)
        influence = sacd.compute_influence_metrics(state, action)
        influence_history.append(influence)
        state, _, done, _ = env.step(action)
        if done:
            break
    
    influence_history = np.array(influence_history)
    for i, label in enumerate(["position", "velocity", "angle", "left_leg", "right_leg", "main_engine", "side_engine", "failure", "success"]):
        plt.plot(influence_history[:, i], label=label)
    plt.xlabel("step")
    plt.ylabel("influence metrics")
    plt.legend()
    plt.show()
    

    [Figure: influence metrics plotted against step]

    [Animation: sample]

    opened by ishihara-y 0
  • Add gmm and Update gaussian

    Added NumPy-based GMM and Gaussian models. In addition, updated the Gaussian distribution's API.

    The API changes as follows.

    Previous:

    import numpy as np
    import nnabla as nn
    import nnabla_rl.distributions as D
    
    batch_size = 10
    output_dim = 10
    input_shape = (batch_size, output_dim)
    mean = np.zeros(shape=input_shape)
    sigma = np.ones(shape=input_shape) * 5.
    ln_var = np.log(sigma) * 2.
    distribution = D.Gaussian(mean, ln_var)
    # return nn.Variable
    assert isinstance(distribution.sample(), nn.Variable)
    

    Updated:

    batch_size = 10
    output_dim = 10
    input_shape = (batch_size, output_dim)
    mean = np.zeros(shape=input_shape)
    sigma = np.ones(shape=input_shape) * 5.
    ln_var = np.log(sigma) * 2.
    # Pass nn.Variable inputs if you want every method to return nn.Variable.
    distribution = D.Gaussian(nn.Variable.from_numpy_array(mean), nn.Variable.from_numpy_array(ln_var))
    assert isinstance(distribution.sample(), nn.Variable)
    
    # If you pass np.ndarray inputs, every method returns np.ndarray.
    # Currently, only inputs without a batch dimension are supported (i.e. mean.shape = (dims,), ln_var.shape = (dims, dims)).
    distribution = D.Gaussian(mean[0], np.diag(ln_var[0]))  # without batch
    assert isinstance(distribution.sample(), np.ndarray)
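
Both snippets parameterize the distribution by the log-variance, computed as `np.log(sigma) * 2.`. The short NumPy check below uses no NNablaRL code and only illustrates why that expression equals ln(sigma^2) and how a sample can be drawn from such a parameterization.

```python
import numpy as np

sigma = np.full((3,), 5.0)
ln_var = np.log(sigma) * 2.0   # ln(sigma^2) = 2 * ln(sigma)
variance = np.exp(ln_var)      # recovers sigma^2

rng = np.random.default_rng(0)
mean = np.zeros(3)
# Draw samples as mean + standard deviation * unit Gaussian noise
samples = mean + np.sqrt(variance) * rng.standard_normal((10000, 3))
print(np.allclose(variance, sigma ** 2))
```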
    
    opened by sbsekiguchi 0
  • Support nnabla-browser

    • [x] add MonitorWriter
    • [x] save computational graph as nntxt

    example

    import gym
    
    import nnabla_rl.algorithms as A
    import nnabla_rl.hooks as H
    import nnabla_rl.writers as W
    from nnabla_rl.utils.evaluator import EpisodicEvaluator
    
    # save training computational graph
    training_graph_hook = H.TrainingGraphHook(outdir="test")
    
    # evaluation hook with nnabla's Monitor
    eval_env = gym.make("Pendulum-v0")
    evaluator = EpisodicEvaluator(run_per_evaluation=10)
    evaluation_hook = H.EvaluationHook(
        eval_env,
        evaluator,
        timing=10,
        writer=W.MonitorWriter(outdir="test", file_prefix='evaluation_result'),
    )
    
    env = gym.make("Pendulum-v0")
    sac = A.SAC(env)
    sac.set_hooks([training_graph_hook, evaluation_hook])
    
    sac.train_online(env, total_iterations=100)
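
The `timing` argument in the example controls how often a hook runs. The toy loop below sketches that idea in plain Python. It is not NNablaRL code: `RecordingHook` and `train` are hypothetical stand-ins, and it assumes a hook fires whenever the iteration number is a multiple of `timing`.

```python
class RecordingHook:
    """Toy stand-in for a hook; not an NNablaRL class."""

    def __init__(self, timing: int):
        self.timing = timing
        self.called_at = []

    def __call__(self, iteration_num: int) -> None:
        # Assumption: a hook fires when the iteration number is a multiple of `timing`.
        if iteration_num % self.timing == 0:
            self.called_at.append(iteration_num)


def train(total_iterations: int, hooks) -> None:
    """Stand-in training loop that invokes every hook after each iteration."""
    for iteration_num in range(1, total_iterations + 1):
        # ... one training update would happen here ...
        for hook in hooks:
            hook(iteration_num)


hook = RecordingHook(timing=10)
train(total_iterations=100, hooks=[hook])
print(hook.called_at)
```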
    

    opened by ishihara-y 0
  • Add iLQR and LQR

    Implementation of Linear Quadratic Regulator (LQR) and iterative LQR algorithms.

    Co-authored-by: Yu Ishihara
    Co-authored-by: Shunichi Sekiguchi
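
For readers unfamiliar with LQR, the sketch below computes finite-horizon feedback gains for a discrete-time linear system with the backward Riccati recursion. It is a generic textbook illustration in NumPy, not the implementation added by this PR: `lqr_gains` and the double-integrator system are assumptions made for the example.

```python
import numpy as np


def lqr_gains(A, B, Q, R, horizon):
    """Backward Riccati recursion; returns gains K_0..K_{T-1} for u_t = -K_t x_t."""
    P = Q.copy()  # terminal cost matrix
    gains = []
    for _ in range(horizon):
        K = np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)
        P = Q + A.T @ P @ (A - B @ K)
        gains.append(K)
    return gains[::-1]  # computed backward in time, so reverse


dt = 0.1
A = np.array([[1.0, dt], [0.0, 1.0]])  # double integrator: position, velocity
B = np.array([[0.0], [dt]])
Q = np.eye(2)                          # state cost
R = np.array([[0.1]])                  # control cost

gains = lqr_gains(A, B, Q, R, horizon=100)
x = np.array([1.0, 0.0])
for K in gains:
    x = A @ x + B @ (-K @ x)           # apply the feedback control
print(np.linalg.norm(x))               # distance from the origin after 100 steps
```

Iterative LQR applies the same recursion repeatedly to a local linear-quadratic approximation of a nonlinear problem.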

    opened by ishihara-y 0
  • Check np_random instance and use correct randint alternative

    I am not sure when this change was made, but in some environments gym.unwrapped.np_random returns a Generator instead of a RandomState.

    # With RandomState, randint is available
    gym.unwrapped.np_random.randint(...)
    # With Generator, randint does not exist; use integers instead
    gym.unwrapped.np_random.integers(...)
    

    This PR fixes the issue by choosing the correct function for sampling integers.
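
A minimal sketch of such a check, using only NumPy, could look like the following. `sample_integer` is a hypothetical name for illustration and is not necessarily how the PR implements it.

```python
import numpy as np


def sample_integer(rng, low: int, high: int) -> int:
    """Draw an integer from [low, high) with either NumPy random API."""
    if isinstance(rng, np.random.Generator):
        # New-style Generator: randint does not exist, use integers
        return int(rng.integers(low, high))
    # Legacy RandomState
    return int(rng.randint(low, high))


print(sample_integer(np.random.default_rng(0), 0, 10))
print(sample_integer(np.random.RandomState(0), 0, 10))
```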

    opened by ishihara-y 0
  • Add icra2018 qtopt

    opened by sbsekiguchi 0
Releases(v0.12.0)
Owner
Sony
Sony Group Corporation