Code for the paper: On Pathologies in KL-Regularized Reinforcement Learning from Expert Demonstrations

Overview

Non-Parametric Prior Actor-Critic (N-PPAC)

This repository contains the code for

On Pathologies in KL-Regularized Reinforcement Learning from Expert Demonstrations, Tim G. J. Rudner*, Cong Lu*, Michael A. Osborne, Yarin Gal, Yee Whye Teh. Conference on Neural Information Processing Systems (NeurIPS), 2021.

Abstract: KL-regularized reinforcement learning from expert demonstrations has proved successful in improving the sample efficiency of deep reinforcement learning algorithms, allowing them to be applied to challenging physical real-world tasks. However, we show that KL-regularized reinforcement learning with behavioral policies derived from expert demonstrations suffers from hitherto unrecognized pathological behavior that can lead to slow, unstable, and suboptimal online training. We show empirically that the pathology occurs for commonly chosen behavioral policy classes and demonstrate its impact on sample efficiency and online policy performance. Finally, we show that the pathology can be remedied by specifying non-parametric behavioral policies and that doing so allows KL-regularized RL to significantly outperform state-of-the-art approaches on a variety of challenging locomotion and dexterous hand manipulation tasks.

View on OpenReview

In particular, the code implements:

  • Scripts for estimating behavioral reference policies for a range of model classes, including non-parametric Gaussian processes, Bayesian neural networks trained via MC Dropout, deep ensembles, and Gaussian neural density models;
  • Scripts for KL-regularized online training with different behavioral expert policies.

How to use this package

We provide a Docker setup, which can be built as follows:

docker build -t torch-nppac .

To train the GP policies offline:

bash exp_scripts/paper_clone_gp.sh
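
The cloning step fits a multitask GP to the expert state-action pairs. As a rough guide to what that entails, here is a minimal GPyTorch sketch of such a model; the repository's own MultitaskGPModel in gp_models.py may differ in its exact kernel and hyperparameter choices:

import gpytorch


class MultitaskGPPolicySketch(gpytorch.models.ExactGP):
    """Minimal multitask GP mapping observations to expert actions (sketch only)."""

    def __init__(self, train_x, train_y, likelihood, num_tasks, rank=1, ard_num_dims=None):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=num_tasks)
        # Matern-1/2 kernel with optional ARD, loosely matching --kernel_type matern12 --use_ard.
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.MaternKernel(nu=0.5, ard_num_dims=ard_num_dims),
            num_tasks=num_tasks, rank=rank)

    def forward(self, x):
        return gpytorch.distributions.MultitaskMultivariateNormal(
            self.mean_module(x), self.covar_module(x))


# Usage sketch: train_x stacks the expert observations, train_y the corresponding actions.
# likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=action_dim)
# model = MultitaskGPPolicySketch(train_x, train_y, likelihood, num_tasks=action_dim)
# Hyperparameters are then fit by maximizing gpytorch.mlls.ExactMarginalLogLikelihood
# before saving model.state_dict().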

To run online training (N-PPAC):

bash exp_scripts/paper_configs.sh
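
Conceptually, the online stage optimizes the policy against a critic while penalizing its KL divergence from the behavioral reference policy. The following is a minimal, generic sketch of such a KL-regularized actor loss, not the repository's exact implementation; policy, prior_log_prob, critic, and alpha are placeholder names:

import torch


def kl_regularized_actor_loss(obs, policy, prior_log_prob, critic, alpha):
    """Sketch of a KL-regularized actor update (placeholder interfaces).

    policy(obs) is assumed to return a torch.distributions object over actions,
    prior_log_prob(obs, actions) the log-density of the behavioral reference
    policy (e.g. the GP predictive), and critic(obs, actions) a Q-value estimate.
    """
    dist = policy(obs)
    actions = dist.rsample()  # reparameterized sample so gradients reach the policy
    # Single-sample Monte Carlo estimate of KL(pi(.|s) || pi_0(.|s)).
    kl = dist.log_prob(actions) - prior_log_prob(obs, actions)
    # Maximize Q while staying close to the behavioral prior.
    return (alpha * kl - critic(obs, actions)).mean()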

Pre-trained GP policies, produced with final_clone_gp.sh, are provided in the folder nppac/trained_gps/.

By default, all data will be stored in data/.

Reference

If you found this repository useful, please cite our paper as follows:

@inproceedings{
    rudner2021pathologies,
    title={On Pathologies in {KL}-Regularized Reinforcement Learning from Expert Demonstrations},
    author={Tim G. J. Rudner and Cong Lu and Michael A. Osborne and Yarin Gal and Yee Whye Teh},
    booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
    year={2021},
    url={https://openreview.net/forum?id=sS8rRmgAatA}
}

License

This repository is based on RLkit, which may contain further useful scripts. The corresponding license is included under the rlkit/ folder.


Comments
  • How to improve the evaluation efficiency?

    I used the following code to evaluate the pretrained model, but found that evaluation is very slow (about 23 minutes per episode). Is there anything wrong with the code, or could you please provide a standard evaluation script? Thanks for your help!

    import mj_envs
    import gym
    import numpy as np
    import torch
    import gpytorch
    from gp_models import MultitaskGPModel
    from rlkit.torch.pytorch_util import set_gpu_mode
    from tqdm import tqdm
    import copy
    import time
    device = torch.device('cuda:1')
    
    def rollout(
            env,
            agent,
            max_path_length=np.inf,
            render=False,
            render_kwargs=None,
            preprocess_obs_for_policy_fn=None,
            get_action_kwargs=None,
            return_dict_obs=False,
            full_o_postprocess_func=None,
            reset_callback=None,
    ):
        if render_kwargs is None:
            render_kwargs = {}
        if get_action_kwargs is None:
            get_action_kwargs = {}
        if preprocess_obs_for_policy_fn is None:
            preprocess_obs_for_policy_fn = lambda x: x
        raw_obs = []
        raw_next_obs = []
        observations = []
        actions = []
        rewards = []
        terminals = []
        dones = []
        agent_infos = []
        env_infos = []
        next_observations = []
        path_length = 0
        # agent.reset()
        o = env.reset()
        if reset_callback:
            reset_callback(env, agent, o)
        if render:
            # todo: debug
            env.mj_render()
            # env.render(**render_kwargs)
        while path_length < max_path_length:
            print('path_length:', path_length)
            raw_obs.append(o)
            # todo: debug
    
            # o_for_agent = torch.from_numpy(o).cuda().float().unsqueeze(0)
    
            o_torch = torch.from_numpy(np.array([o])).float().to(device)
            output = model(o_torch)
            observed_pred = likelihood(output)
            a = observed_pred.mean.data.cpu().numpy()
    
            if len(a) == 1:
                a = a[0]
    
            # # o_for_agent = o
            # # a = agent.get_action(o_for_agent, **get_action_kwargs)
            # a, *_ = agent(o_for_agent, **get_action_kwargs)
            # a = a.detach().cpu().numpy()
            # # a = agent.get_action(o_for_agent, **get_action_kwargs)[0][0]
            agent_info = None
            if full_o_postprocess_func:
                full_o_postprocess_func(env, agent, o)
    
            next_o, r, done, env_info = env.step(copy.deepcopy(a))
            if render:
                # todo: debug
                env.mj_render()
    
                # env.render(**render_kwargs)
            observations.append(o)
            rewards.append(r)
            terminal = False
            if done:
                # terminal=False if TimeLimit caused termination
                if not env_info.pop('TimeLimit.truncated', False):
                    terminal = True
            terminals.append(terminal)
            dones.append(done)
            actions.append(a)
            next_observations.append(next_o)
            raw_next_obs.append(next_o)
            agent_infos.append(agent_info)
            env_infos.append(env_info)
            path_length += 1
            if done:
                break
            o = next_o
        actions = np.array(actions)
        if len(actions.shape) == 1:
            actions = np.expand_dims(actions, 1)
        observations = np.array(observations)
        next_observations = np.array(next_observations)
        if return_dict_obs:
            observations = raw_obs
            next_observations = raw_next_obs
        rewards = np.array(rewards)
        if len(rewards.shape) == 1:
            rewards = rewards.reshape(-1, 1)
        return dict(
            observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=next_observations,
            terminals=np.array(terminals).reshape(-1, 1),
            dones=np.array(dones).reshape(-1, 1),
            agent_infos=agent_infos,
            env_infos=env_infos,
            full_observations=raw_obs,
            full_next_observations=raw_obs,
        )
    
    
    def simulate_policy(env, policy, T=100, H=200, gpu=True, render=False):
        if gpu:
            set_gpu_mode(True)
            # policy.cuda()
            policy.to(device)
            print('use GPU')
        # policy = MakeDeterministic(policy)
        episode = 0
        success_time = 0
        env.seed(1)
        for episode in tqdm(range(0, T)):
            print('episode:{}'.format(episode))
            path = rollout(
                env,
                policy,
                max_path_length=H,
                render=render,
            )
            if path['env_infos'][-1]['goal_achieved'] is True:
                success_time += 1
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics([path])
            time.sleep(0.02)
        success_time /= T  # average over the number of evaluation episodes
        return success_time
    
    
    
    env = gym.make('door-binary-v0')
    
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    data_set = '../d4rl_model/offpolicy_hand_data/door2_sparse.npy'
    model_path = '../nppac/nppac/door/gp_door_multitask_1000.pt'
    
    data = np.load(data_set, allow_pickle=True)
    keep_num = 1000
    use_ard = True
    gp_type = 'multitask'
    gp_rank = 1
    kernel_type = 'matern12'
    # Ablation to randomly filter the dataset, not active by default.
    if keep_num < len(data):
        print(f'Keeping {keep_num} trajectories.')
        data = np.random.choice(data, keep_num, replace=False)
    
    if type(data[0]['observations'][0]) is dict:
        # Convert to just the states
        for traj in data:
            traj['observations'] = [t['state_observation'] for t in traj['observations']]
    
    train_x = torch.from_numpy(np.array([j for i in [traj['observations'] for traj in data] for j in i])).float().to(
        device)
    train_y = torch.from_numpy(np.array([j for i in [traj['actions'] for traj in data] for j in i])).float().to(
        device)
    
    print('Data Loaded!')
    
    # Initialize likelihood and model
    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=action_dim).to(device)
    likelihood.eval()
    ard_num_dims = obs_dim if use_ard else None
    
    model = MultitaskGPModel(train_x, train_y, likelihood, num_tasks=action_dim, rank=gp_rank,
                             ard_num_dims=ard_num_dims, kernel_type=kernel_type).to(device)
    
    model_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(model_dict)
    model.eval()
    
    
    success_rate = simulate_policy(env, model, render=False, T=100)
    print('success rate is :', success_rate)
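
    A likely contributor to the slow evaluation above is that each GP prediction is computed with autograd enabled and without GPyTorch's fast predictive variances. A minimal sketch of wrapping the prediction step, using standard PyTorch/GPyTorch context managers rather than a script from this repository:

    # Disable autograd and enable fast (LOVE) predictive variances; the first
    # call builds a prediction cache that subsequent calls reuse.
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        observed_pred = likelihood(model(o_torch))
        a = observed_pred.mean.cpu().numpy()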
    
    opened by nuomizai 2
  • How much GPU memory did nppac require during pretrain-stage?

    When I tried to run paper_clone_gp.sh locally, bypassing exp_scripts/run-gpy.sh, I got a 'CUDA out of memory' error. My command is as follows:

    gpu_id=0
    env_name=door
    data_set='expert_demonstration_data/door2_sparse.npy'
    kt=matern12
    python nppac/clone_from_dataset_gp.py --name $env_name --save_policies --data_set $data_set \
        --use_gpu --use_ard --gp_rank 1 --kernel_type $kt --save_dir $env_name
    

    Then I got an out-of-memory error as:

    RuntimeError: CUDA out of memory. Tried to allocate 131.26 GiB (GPU 0; 23.70 GiB total capacity; 1.87 GiB already allocated; 14.61 GiB free; 2.56 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
    

    The error occurs at line 50 of clone_from_dataset_gp.py, at observed_pred = likelihood(model(o_torch)), even though o_torch.size() is only [1, 39].

    I'd like to know whether this script can run locally on a 20 GB GPU, or whether something is wrong with the settings in the script. Thank you!

    opened by nuomizai 2
  • How to get the smooth curve in Fig.5

    Hi, I'm very interested in N-PPAC and would like to use it as a comparison method. However, I'm stuck on the evaluation metric: how do I get a smooth curve like the one in Fig. 5? I can only get a non-smooth curve with rllab.

    opened by nuomizai 1
  • How to visualize the result in Fig.5

    I'd like to know how to visualize the results shown in Fig. 5 of your paper. I have used rllab to visualize the results, but I don't know which metric (Y-Axis Attribute) to plot.

    opened by nuomizai 1
Owner
Cong Lu
DPhil Student in Autonomous Intelligent Machines and Systems @ Oxford