Pretraining Representations For Data-Efficient Reinforcement Learning

Max Schwarzer, Nitarshan Rajkumar, Michael Noukhovitch, Ankesh Anand, Laurent Charlin, Devon Hjelm, Philip Bachman & Aaron Courville

This repo provides code for implementing SGI.

Install

To install the requirements, follow these steps:

export LANG=C.UTF-8
# Install PyTorch
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# Install the remaining requirements
pip install -r requirements.txt

# Finally, install the project
pip install --user -e .
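
Once the install completes, an optional sanity check (a minimal sketch, not part of the repo) confirms that the CUDA build of PyTorch is visible:

# Optional sanity check, run from a Python shell: confirms the CUDA build of PyTorch installed correctly.
import torch
print(torch.__version__)          # expect 1.8.1+cu111
print(torch.cuda.is_available())  # expect True on a CUDA-capable machine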

Usage:

The default branch for the latest and stable changes is release.

  • To run SGI:
  1. Download the DQN replay dataset from https://research.google/tools/datasets/dqn-replay/
    • Or substitute your own pre-training data! The codebase expects a series of .gz files, one each for observations, actions, and terminals (a short loading sketch is shown after the commands below).
  2. To pretrain with SGI:
python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \
    env.game=pong seed=1 offline_model_save={your model name} \
    offline.runner.epochs=10 offline.runner.dataloader.games=[Pong] \
    offline.runner.no_eval=1 \
    +offline.algo.goal_weight=1 \
    +offline.algo.inverse_model_weight=1 \
    +offline.algo.spr_weight=1 \
    +offline.algo.target_update_tau=0.01 \
    +offline.agent.model_kwargs.momentum_tau=0.01 \
    do_online=False \
    algo.batch_size=256 \
    +offline.agent.model_kwargs.noisy_nets_std=0 \
    offline.runner.dataloader.dataset_on_disk=True \
    offline.runner.dataloader.samples=1000000 \
    offline.runner.dataloader.checkpoints='{your checkpoints}' \
    offline.runner.dataloader.num_workers=2 \
    offline.runner.dataloader.data_path={your data dir} \
    offline.runner.dataloader.tmp_data_path=./ 
  3. To fine-tune with SGI:
python -m scripts.run public=True env.game=pong seed=1 num_logs=10  \
    model_load={your_model_name} model_folder=./ \
    algo.encoder_lr=0.000001 algo.q_l1_lr=0.00003 algo.clip_grad_norm=-1 algo.clip_model_grad_norm=-1
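
As a companion to step 1, here is a minimal sketch of how the pre-training .gz files can be read. It assumes the DQN Replay dataset's usual naming ($store$_<field>_ckpt.<i>.gz, gzipped NumPy arrays); the directory path and checkpoint index below are placeholders, and this is not the repo's own dataloader:

import gzip
import numpy as np

def load_replay_checkpoint(replay_dir, ckpt_index):
    """Load one observations/actions/terminals checkpoint from a DQN Replay-style directory."""
    arrays = {}
    for field in ("observation", "action", "terminal"):
        path = f"{replay_dir}/$store$_{field}_ckpt.{ckpt_index}.gz"
        with gzip.open(path, "rb") as f:
            arrays[field] = np.load(f)
    return arrays

# Example with a placeholder path:
# data = load_replay_checkpoint("/data/Pong/1/replay_logs", 0)
# For Atari, data["observation"] is typically a uint8 array of 84x84 frames.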

When reporting scores, we average across 10 fine-tuning seeds.

./scripts/experiments contains a number of example configurations, including for SGI-M, SGI-M/L and SGI-W, for both pre-training and fine-tuning. Each of these scripts can be launched by providing a game and seed, e.g., ./scripts/experiments/sgim_pretrain.sh pong 1. These scripts are provided primarily to illustrate the hyperparameters used for different experiments; you will likely need to modify the arguments in these scripts to point to your data and model directories.

Data for SGI-R and SGI-E is not included due to its size, but can be re-generated locally. Contact us for details.

What does each file do?

.
├── scripts
│   ├── run.py                # The main runner script to launch jobs.
│   ├── config.yaml           # The hydra configuration file, listing hyperparameters and options.
│   └── experiments           # Configurations for various experiments done by SGI.
│
├── src                     
│   ├── agent.py              # Implements the Agent API for action selection 
│   ├── algos.py              # Distributional RL loss and optimization
│   ├── models.py             # Forward passes, network initialization.
│   ├── networks.py           # Network architecture and forward passes.
│   ├── offline_dataset.py    # Dataloader for offline data.
│   ├── gcrl.py               # Utils for SGI's goal-conditioned RL objective.
│   ├── rlpyt_atari_env.py    # Slightly modified Atari env from rlpyt
│   ├── rlpyt_utils.py        # Utility methods that we use to extend rlpyt's functionality
│   └── utils.py              # Command line arguments and helper functions 
│
└── requirements.txt          # Dependencies
Comments
  • some detail regarding dataloader

    Hi,

    Thanks for sharing the code. I have some question about the way you use frame and action to make predictions.

    The input to the collate function has the following dimensions

    observation.shape = [256, 20, 84, 84], action.shape = [256, 20]

    I assume that for a given index, the observation corresponds to the current observation, and the action corresponds to the current action. For example, observation[0][0] and action[0][0] correspond to the (observation, action) for batch 0, index 0.

    If I understand correctly, the collate function stacks the frames and returns the following format for the indices (simplified, with the batch dimension ignored; a small sketch of this layout appears after the comments below):

    observation: [[frame1, frame2, frame3, frame4], [frame2, frame3, frame4, frame5], ...]
    action: [action for frame4, action for frame5, ...]
    reward: [reward for frame4, reward for frame5, ...]

    However, in your code for next latent prediction, you seem to be using the current frame and the next action to predict the next frame.

    For example, I think you are using [frame1, frame2, frame3, frame4] and the action for frame 5 to compute the latent for [frame2, frame3, frame4, frame5].

    Shouldn't you be using the current action as opposed to the next action in conjunction with the current frame? Or did I misunderstand something?

    Thanks for the clarification, Kevin

    opened by kevinghst 3
  • sticky action used during finetuning?

    Hi!

    Quick question - did you use sticky actions during finetuning?

    The paper doesn't say. Also, looking at your code, it seems that sticky actions are disabled: https://github.com/mila-iqia/SGI/blob/master/src/rlpyt_atari_env.py#L75

    Thanks for your clarification, Kevin

    opened by kevinghst 1
  • Problem w/ installing the project

    Hi,

    Thanks for sharing this repo.

    I encountered an issue installing the dependencies.

    In particular, when I execute pip install --user -e ., I get the following error: ERROR: File "setup.py" or "setup.cfg" not found. Directory cannot be installed in editable mode: /scratch/wz1232/SGI

    Any idea how to fix this? Does the repo need to include the file setup.py as mentioned in the error? Thanks, Kevin

    opened by kevinghst 0
  • mismatch between paper and repo hyperparameters

    Hi,

    There seem to be at least two mismatches between the paper and the repo hyperparameters.

    1. SPR_weight: In the paper: "We set λ_SPR = 2 and λ_IM = 1 during pre-training. Unless otherwise noted, all settings match SPR during fine-tuning, including batch size, replay ratio, target network update period, and λ_SPR" (in the SPR paper λ_SPR = 2 as well).

    However, in /sgim_pretrain.sh the SPR weight is set to 1. On the other hand, the SPR weight is not set in /sgiml_finetune.sh, which means it uses the default in the config, which is 5.

    2. momentum_tau during fine-tuning: The paper says the fine-tuning hyperparameters are the same as those used in SPR. In SPR, no EMA is used (tau = 0) when augmentation is used. However, tau is not set in /sgiml_finetune.sh, so it defaults to 0.01 in config.yaml.

    Am I correct to assume that the repo versions are incorrect?

    These two are the only ones I can find, but I worry that I may have missed some. I would really appreciate it if any of you could take a second look, so that others like myself can reliably reproduce your results!

    Thanks, Kevin

    opened by kevinghst 1
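
For readers following the first comment above, here is a small illustrative sketch of the collate layout it describes (a stack of 4 frames per index, with the action and reward aligned to the last frame). Names and shapes are purely illustrative; this is not the repo's actual collate function:

import numpy as np

def stack_frames(frames, actions, rewards, frame_stack=4):
    """Build [frame_{t-3}, ..., frame_t] stacks with action/reward aligned to the last frame."""
    obs, act, rew = [], [], []
    for t in range(frame_stack - 1, len(frames)):
        obs.append(frames[t - frame_stack + 1 : t + 1])  # 4 consecutive frames ending at t
        act.append(actions[t])                           # action for the last frame in the stack
        rew.append(rewards[t])                           # reward for the last frame in the stack
    return np.stack(obs), np.array(act), np.array(rew)

frames = np.zeros((20, 84, 84), dtype=np.uint8)   # one 20-step slice, as in the comment
actions = np.zeros(20, dtype=np.int64)
rewards = np.zeros(20, dtype=np.float32)
obs, act, rew = stack_frames(frames, actions, rewards)
print(obs.shape, act.shape, rew.shape)            # (17, 4, 84, 84) (17,) (17,)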