Adaptive Attention Span for Reinforcement Learning

Overview

Official implementation of Adaptive Transformers in RL.

In this work we replicate several results from Stabilizing Transformers for RL (Parisotto et al., 2019) on both Pong and rooms_select_nonmatching_object from DMLab-30.

We also extend the Stable Transformer architecture with Adaptive Attention Span in a partially observable (POMDP) reinforcement learning setting. To our knowledge, this is one of the first attempts to stabilize and explore Adaptive Attention Span in an RL domain.
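For readers new to the mechanism, the sketch below illustrates the core idea of Adaptive Attention Span (Sukhbaatar et al., 2019): each attention head learns a span parameter that softly masks attention beyond a learned distance, and an L1 penalty on the spans (the --adapt_span_loss coefficient in the snippets below) pushes heads toward shorter spans. This is a simplified, hypothetical PyTorch sketch, not the module from this repository.

import torch
import torch.nn as nn

class AdaptiveSpanMask(nn.Module):
    # Simplified sketch of the soft span mask from "Adaptive Attention Span
    # in Transformers" (Sukhbaatar et al., 2019); names and shapes are
    # illustrative, not this repo's actual module.
    def __init__(self, n_heads, max_span, ramp=32):
        super().__init__()
        self.max_span, self.ramp = max_span, ramp
        # One learnable span fraction per head, initialised to the shortest span.
        self.span_frac = nn.Parameter(torch.zeros(n_heads, 1, 1))

    def forward(self, attn):
        # attn: (n_heads, num_queries, span) attention weights, where the last
        # dimension indexes relative distance with the most recent key last.
        span_len = attn.size(-1)
        distance = torch.arange(span_len - 1, -1, -1,
                                device=attn.device, dtype=attn.dtype)
        span = self.span_frac.clamp(0, 1) * self.max_span
        # Soft mask: 1 inside the learned span, linear ramp down to 0 beyond it.
        mask = ((self.ramp + span - distance) / self.ramp).clamp(0, 1)
        masked = attn * mask
        # Renormalise so each attention row still sums to 1.
        return masked / (masked.sum(-1, keepdim=True) + 1e-8)

    def span_penalty(self):
        # L1 term added to the training loss, weighted by --adapt_span_loss.
        return self.span_frac.clamp(0, 1).mean() * self.max_span

In the adaptive-span snippet below, --attn_span 400 corresponds to the maximum span and --adapt_span_loss 0.025 to the penalty weight; shorter learned spans mean less memory to attend over, which --adapt_span_cache presumably exploits by trimming the cached memories to the learned span.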

Steps to replicate what we did on your own machine

  1. Downloading DMLab: DeepMind Lab – https://github.com/deepmind/lab

  2. Downloading Atari: Getting Started with Gym – http://gym.openai.com/docs/#getting-started-with-gym

  3. Execution notes:

  • The experiments take around 4 hours on 32 vCPUs and 2 P100 GPUs for 6 million environment interactions. To run without a GPU, use the flag --disable_cuda.
  • For more details on the other flags, see the top of train.py, which has a description of each.
  • All experiments use a slightly revised version of IMPALA from torchbeast.

Snippets

Best performing adaptive attention span model on “rooms_select_nonmatching_object”:

python train.py --total_steps 20000000 \
--learning_rate 0.0001 --unroll_length 299 --num_buffers 40 --n_layer 3 \
--d_inner 1024 --xpid row85 --chunk_size 100 --action_repeat 1 \
--num_actors 32 --num_learner_threads 1 --sleep_length 20 \
--level_name rooms_select_nonmatching_object --use_adaptive \
--attn_span 400 --adapt_span_loss 0.025 --adapt_span_cache

Best performing Stable Transformer on Pong:

python train.py --total_steps 10000000 \
--learning_rate 0.0004 --unroll_length 239 --num_buffers 40 \
--n_layer 3 --d_inner 1024 --xpid row82 --chunk_size 80 \
--action_repeat 1 --num_actors 32 --num_learner_threads 1 \
--sleep_length 5 --atari True

Best performing Stable Transformer on “rooms_select_nonmatching_object”:

python train.py --total_steps 20000000 \
--learning_rate 0.0001 --unroll_length 299 \
--num_buffers 40 --n_layer 3 --d_inner 1024 \
--xpid row79 --chunk_size 100 --action_repeat 1 \
--num_actors 32 --num_learner_threads 1 --sleep_length 20 \
--level_name rooms_select_nonmatching_object  --mem_len 200

Reference

If you find this repository useful, please cite it as:

@article{kumar2020adaptive,
    title={Adaptive Transformers in RL},
    author={Shakti Kumar and Jerrod Parker and Panteha Naderian},
    year={2020},
    eprint={2004.03761},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
Comments
  • Get Steps 0 @ 0.0 SPS. Loss inf. Stats

    Hello, I tried your command but with 16 num_actors: python train.py --total_steps 10000000 --learning_rate 0.0004 --unroll_length 239 --num_buffers 40 --n_layer 3 --d_inner 1024 --xpid row82 --chunk_size 80 --action_repeat 1 --num_actors 32 --num_learner_threads 1 --sleep_length 5 --atari True

    But I got: Steps 0 @ 0.0 SPS. Loss inf. Stats

    opened by ghost 8
  • train.py on colab, using snippets

    Hi, I have a question about the snippets in the README. I plan to use your code in Colab. To test the code, I executed train.py with your snippets. However, train.py is not working well: the loss is infinite and the stats are empty. Could you tell me how to solve the problem?

    opened by yhg8423 4
  • Error reproducing results

    There seem to be a lot of bugs in the code, especially in the padding and end-of-sequence indexing. Can you please update the repo with the bug-free code used to reproduce the results in the paper? Also, can you share the exact requirements.txt or a Dockerfile? Thanks.

    opened by doltonfernandes 0
  • Isn't param "--use_gate" important for Pong?

    Hey, in the paper Stabilizing Transformers for RL there is a gating unit in the GTrXL module, but in your default params --use_gate is False. Why?

    opened by weihongwei0586 0
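    For context on the gating the issue above refers to: GTrXL (Parisotto et al., 2019) replaces the residual connections with a GRU-style gate. The sketch below is a hypothetical PyTorch rendering of that gate from the paper's equations, not this repository's --use_gate code path.

    import torch
    import torch.nn as nn

    class GRUGate(nn.Module):
        # GRU-style gating layer from GTrXL (Parisotto et al., 2019), used in
        # place of a residual connection; illustrative only.
        def __init__(self, d_model, gate_bias=2.0):
            super().__init__()
            self.w_r = nn.Linear(d_model, d_model, bias=False)
            self.u_r = nn.Linear(d_model, d_model, bias=False)
            self.w_z = nn.Linear(d_model, d_model, bias=False)
            self.u_z = nn.Linear(d_model, d_model, bias=False)
            self.w_g = nn.Linear(d_model, d_model, bias=False)
            self.u_g = nn.Linear(d_model, d_model, bias=False)
            # Bias that keeps the gate near the identity map at initialisation.
            self.bias = nn.Parameter(torch.full((d_model,), gate_bias))

        def forward(self, x, y):
            # x: input to the sublayer (what a residual connection would carry);
            # y: output of the sublayer (attention or feed-forward block).
            r = torch.sigmoid(self.w_r(y) + self.u_r(x))
            z = torch.sigmoid(self.w_z(y) + self.u_z(x) - self.bias)
            h = torch.tanh(self.w_g(y) + self.u_g(r * x))
            return (1 - z) * x + z * h

    With the bias pushing z towards 0 at initialisation, the layer initially passes x through almost unchanged, which is the property the GTrXL paper credits for stable early training.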
  • Regarding logic for first done indexes

    Hi, thanks for the code and the paper on using adaptive attention span in RL. In train.py, I haven't understood the logic for calculating ind_first_done in the following line:
    https://github.com/jerrodparker20/adaptive-transformers-in-rl/blob/6f75366b78998fb1d8755acd2d851c461c82ee75/train.py#L1240

    After going through the loss calculations and the learn function where ind_first_done is used, I feel the line https://github.com/jerrodparker20/adaptive-transformers-in-rl/blob/6f75366b78998fb1d8755acd2d851c461c82ee75/train.py#L1240 should instead be: ind_first_done = padding_mask.long().argmax(0) + 1. I feel so because, from the comments, ind_first_done denotes the final index in each trajectory.

    Could you kindly explain the logic used for the mentioned snippet?

    opened by victor-psiori 1
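    As background for the indexing question above, here is a small, self-contained PyTorch example (hypothetical tensors, not the repo's actual padding_mask layout) showing what .long().argmax(0) returns on a boolean mask and how the +1 turns the index into a length:

    import torch

    # Hypothetical (time, batch) padding mask that is True from the first
    # "done" step of each trajectory onwards. Layout is illustrative only.
    padding_mask = torch.tensor([
        [False, False],
        [False, False],
        [True,  False],   # trajectory 0 finishes at t = 2
        [True,  True],    # trajectory 1 finishes at t = 3
    ])

    # argmax over time returns the index of the first 1 in each column,
    # i.e. the first "done" step of each trajectory.
    first_done = padding_mask.long().argmax(0)   # tensor([2, 3])

    # Adding 1 turns that index into a number of valid steps, which is the
    # interpretation the comment above argues for.
    print(first_done, first_done + 1)            # tensor([2, 3]) tensor([3, 4])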
  • Stable Transformer on Pong

    Hello,

    I am currently unable to recreate the results of the stable transformer on the Pong environment. I believe from the paper the last 100 episode returns should be ~17.62 for this model and environment.

    I am running the train program with the arguments specified in the README for the best performing Stable Transformer on Pong.

    In train.py line 731 I changed ctx = mp.get_context("fork") to ctx = mp.get_context("spawn")

    The final results I obtained on one run:

    [INFO:17181 train:962 2020-12-01 19:35:33,350] Steps 10001513 @ 668.5 SPS. Loss -15.672254. Return per episode: -12.7. Stats:
    {'baseline_loss': 11.395485877990723,
     'entropy_loss': -18.699639002482098,
     'episode_returns': [-20.0, -18.0, -19.0],
     'last_100_episode_returns': -19.530000686645508,
     'learning_rate': 8.657589688233862e-05,
     'len_max_traj': 239,
     'max_return_achieved': '-14.0 at step 5366379',
     'mean_episode_return': -12.666666666666666,
     'num_unpadded_steps': 3346,
     'pg_loss': -8.368099212646484,
     'total_loss': -15.672253926595053}
    [INFO:17181 train:969 2020-12-01 19:35:33,350] Learning finished after 10001513 steps.
    

    Results from another run:

    [INFO:15271 train:962 2020-12-04 19:47:48,776] Steps 10001156 @ 661.4 SPS. Loss -9.595014. Return per episode: -19.7. Stats:
    {'baseline_loss': 14.119840621948242,
     'entropy_loss': -18.633128484090168,
     'episode_returns': [-21.0, -19.0, -20.0, -19.0],
     'last_100_episode_returns': -19.540000915527344,
     'learning_rate': 9.02709105067138e-05,
     'len_max_traj': 239,
     'max_return_achieved': '-14.0 at step 7824133',
     'mean_episode_return': -19.666666666666668,
     'num_unpadded_steps': 3309,
     'pg_loss': -5.081725597381592,
     'total_loss': -9.595013936360678}
    [INFO:15271 train:969 2020-12-04 19:47:48,776] Learning finished after 10001156 steps.
    

    I am on Ubuntu 18.04.4, using Cuda 10.2, cudnn 7, torch 1.6.0.

    Thanks in advance for any help.

    Best, Sean

    opened by furmans 2
  • Is this algorithm suitable for off-policy learning?

    I just finished reading your paper, and I notice that it is an on-policy method.
    I am wondering if anyone has tested it with an RL method that uses a replay buffer.
    As far as I know, for off-policy methods with a recurrent structure (like an LSTM, GRU, attention, or a transformer), if the hidden state is stored with a sample (s, a, r, s'), it becomes stale after long training. Is this issue overcome with the adaptive transformer?

    opened by dbsxdbsx 1