Code for "Unsupervised State Representation Learning in Atari"

Overview

Unsupervised State Representation Learning in Atari

Ankesh Anand*, Evan Racah*, Sherjil Ozair*, Yoshua Bengio, Marc-Alexandre Côté, R Devon Hjelm

This repo provides code for the benchmark and techniques introduced in the paper Unsupervised State Representation Learning in Atari.

Install

AtariARI Wrapper

You can do a minimal install to get just the AtariARI (Atari Annotated RAM Interface) wrapper by doing:

pip install 'gym[atari]'
pip install git+https://github.com/mila-iqia/atari-representation-learning.git

This requires only gym[atari] and gives you the AtariARI wrapper to play around with. If you want to use the code for training representation learning methods and probing them, you will need the full installation:

Full installation (AtariARI Wrapper + Training & Probing Code)

# PyTorch and scikit-learn
conda install pytorch torchvision -c pytorch
conda install scikit-learn

# Baselines for Atari preprocessing
# TensorFlow is a dependency, but you don't need to install the GPU version
conda install tensorflow
pip install git+https://github.com/openai/baselines

# pytorch-a2c-ppo-acktr for RL utils
pip install git+https://github.com/ankeshanand/pytorch-a2c-ppo-acktr-gail

# Clone and install our package
pip install -r requirements.txt
pip install git+https://github.com/mila-iqia/atari-representation-learning.git

Usage

Atari Annotated RAM Interface (AtariARI):

AtariARI exposes the ground-truth labels for different state variables for each observation. We have made AtariARI available as a Gym wrapper; to use it, simply wrap an Atari gym env with AtariARIWrapper.

import gym
from atariari.benchmark.wrapper import AtariARIWrapper
env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
obs = env.reset()
obs, reward, done, info = env.step(1)

Now, info is a dictionary of the form:

{'ale.lives': 3,
 'labels': {'enemy_sue_x': 88,
  'enemy_inky_x': 88,
  'enemy_pinky_x': 88,
  'enemy_blinky_x': 88,
  'enemy_sue_y': 80,
  'enemy_inky_y': 80,
  'enemy_pinky_y': 80,
  'enemy_blinky_y': 50,
  'player_x': 88,
  'player_y': 98,
  'fruit_x': 0,
  'fruit_y': 0,
  'ghosts_count': 3,
  'player_direction': 3,
  'dots_eaten_count': 0,
  'player_score': 0,
  'num_lives': 2}}
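
For example, a minimal rollout that records the label dict at every step might look like this (random actions, purely illustrative):

import gym
from atariari.benchmark.wrapper import AtariARIWrapper

# Collect the AtariARI label dict at every step of a short random rollout.
env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
env.reset()
label_trajectory = []
for _ in range(100):
    obs, reward, done, info = env.step(env.action_space.sample())
    label_trajectory.append(info["labels"])
    if done:
        env.reset()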

Note: In our experiments, we use additional preprocessing for Atari environments, mainly following Mnih et al., 2015. See atariari/benchmark/envs.py for more info!

If you want the raw RAM annotations (which parts of RAM correspond to each state variable), check out atariari/benchmark/ram_annotations.py
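
For a quick look, something like the following should work (a sketch assuming the module exposes a game-keyed dict named atari_dict mapping each state variable to its RAM byte index; the exact key names are an assumption):

from atariari.benchmark.ram_annotations import atari_dict

# The game key name is an assumption; check the file for the exact spelling.
print(atari_dict["mspacman"]["player_x"])  # RAM byte index for the player's x position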

Probing


⚠️ Important ⚠️: The RAM labels are meant for full-sized Atari observations (210 × 160). Probing results won't be accurate if you downsample the observations.
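
A quick guard against accidentally probing resized frames (assuming standard full-sized ALE observations of shape (210, 160, 3)):

# Fail fast if observations have been downsampled.
obs = env.reset()
assert obs.shape[:2] == (210, 160), "RAM labels assume full-sized frames"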

We provide an interface for the included probing tasks.

First, get episodes for train, val, and test:

from atariari.benchmark.episodes import get_episodes

tr_episodes, val_episodes,\
tr_labels, val_labels,\
test_episodes, test_labels = get_episodes(env_name="PitfallNoFrameskip-v4", 
                                     steps=50000, 
                                     collect_mode="random_agent")
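
A quick way to sanity-check what came back (the exact frame type and shape depend on the preprocessing options, so the sketch below assumes only the list-of-episodes structure):

# Episodes are lists (one per episode) of per-frame observations; labels
# mirror that structure with one dict of state variables per frame.
print(len(tr_episodes), "train episodes")
print(type(tr_episodes[0][0]))  # a single frame
print(tr_labels[0][0])          # the label dict for that frame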

Then probe them using ProbeTrainer and your encoder (my_encoder):

from atariari.benchmark.probe import ProbeTrainer

probe_trainer = ProbeTrainer(my_encoder, representation_len=my_encoder.feature_size)
probe_trainer.train(tr_episodes, val_episodes,
                     tr_labels, val_labels,)
final_accuracies, final_f1_scores = probe_trainer.test(test_episodes, test_labels)
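
The returned metrics appear to be reported per state variable; a hedged sketch for summarizing them (assuming both return values are dicts keyed by variable name):

import numpy as np

# Print per-variable F1 scores and a single mean summary number.
print({k: round(v, 3) for k, v in final_f1_scores.items()})
print("mean F1:", np.mean(list(final_f1_scores.values())))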

To see how we use ProbeTrainer, check out scripts/run_probe.py

Here is an example of my_encoder:

# get your encoder
import torch.nn as nn
import torch
class MyEncoder(nn.Module):
    def __init__(self, input_channels, feature_size):
        super().__init__()
        self.feature_size = feature_size
        self.input_channels = input_channels
        self.final_conv_size = 64 * 9 * 6
        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Linear(self.final_conv_size, self.feature_size)

    def forward(self, inputs):
        x = self.cnn(inputs)
        x = x.view(x.size(0), -1)
        return self.fc(x)
        

my_encoder = MyEncoder(input_channels=1, feature_size=256)
# load in weights
my_encoder.load_state_dict(torch.load("path/to/my/weights.pt"))
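
A quick sanity check that the hard-coded final_conv_size matches full-sized Atari input (a 210 × 160 grayscale frame flattens to 64 * 9 * 6 after the four conv layers):

# Forward a dummy frame through the encoder; expect a (1, 256) feature vector.
dummy_obs = torch.zeros(1, 1, 210, 160)
with torch.no_grad():
    features = my_encoder(dummy_obs)
print(features.shape)  # torch.Size([1, 256])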

Spatio-Temporal DeepInfoMax:

atariari/methods/ contains implementations of several representation learning methods, including ST-DIM. Here's a sample usage:

python -m scripts.run_probe --method infonce-stdim --env-name {env_name}

where env_name is of the form {game}NoFrameskip-v4, such as PongNoFrameskip-v4.

Citation

@article{anand2019unsupervised,
  title={Unsupervised State Representation Learning in Atari},
  author={Anand, Ankesh and Racah, Evan and Ozair, Sherjil and Bengio, Yoshua and C{\^o}t{\'e}, Marc-Alexandre and Hjelm, R Devon},
  journal={arXiv preprint arXiv:1906.08226},
  year={2019}
}
Comments
  • Pillow version problem


    python -m scripts.run_probe --method infonce-stdim --env-name Pong-v0
    

Traceback (most recent call last):
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/runpy.py", line 183, in _run_module_as_main
    mod_name, mod_spec, code = _get_module_details(mod_name, _Error)
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/runpy.py", line 109, in _get_module_details
    __import__(pkg_name)
  File "/home/duane/PycharmProjects/atari-representation-learning/scripts/__init__.py", line 1, in <module>
    from .run_contrastive import train_encoder
  File "/home/duane/PycharmProjects/atari-representation-learning/scripts/run_contrastive.py", line 8, in <module>
    from atariari.methods.dim_baseline import DIMTrainer
  File "/home/duane/PycharmProjects/atari-representation-learning/atariari/methods/dim_baseline.py", line 13, in <module>
    from torchvision import transforms
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/__init__.py", line 4, in <module>
    from torchvision import datasets
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/datasets/__init__.py", line 9, in <module>
    from .fakedata import FakeData
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/datasets/fakedata.py", line 3, in <module>
    from .. import transforms
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/transforms/__init__.py", line 1, in <module>
    from .transforms import *
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 17, in <module>
    from . import functional as F
  File "/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 5, in <module>
    from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
ImportError: cannot import name 'PILLOW_VERSION' from 'PIL' (/home/duane/anaconda3/envs/atari-representation-learning/lib/python3.7/site-packages/PIL/__init__.py)
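
[Note: PILLOW_VERSION was removed in Pillow 7.0, so pinning an earlier release (e.g. pip install 'Pillow<7') or upgrading torchvision should work around this import error.]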

    opened by DuaneNielsen 9
  • Reproducing published score


Thank you very much for sharing this implementation! There is a minor issue: as far as I understand, here it should be steps=args.pretraining_steps.

I tried to run the experiment with python -m scripts.run_probe --method infonce-stdim --env-name MsPacmanNoFrameskip-v4 --pretraining-steps 100000 a couple of times; the mean F1 is ≈0.65, while the paper reports 0.7. Other methods also show slightly lower scores, although the "supervised" one matches exactly. Am I missing some hyperparameter, or is this just typical score fluctuation?

    opened by htdt 6
  • Any tips on extracting RAM locations?


    I'd like to see another game or two in the list - do you have any pointers to source code for the games / tips based on previous games about how to find the RAM locations for useful information in the game? (I was initially worried, given the limited memory of the Atari 2600, that multiple pieces of information might be stored in separate bits of the same byte, but from your list, that doesn't seem to be the case, so perhaps it isn't too complicated?).
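
[A rough sketch of one common approach, not tooling from this repo: log the 128-byte ALE RAM every frame and rank bytes by how much they vary, then correlate the most active ones with on-screen quantities you control.]

import gym
import numpy as np

# Log RAM over a random rollout, then rank addresses by variance.
env = gym.make('MsPacmanNoFrameskip-v4')
env.reset()
ram_log = []
for _ in range(1000):
    _, _, done, _ = env.step(env.action_space.sample())
    ram_log.append(env.unwrapped.ale.getRAM().copy())
    if done:
        env.reset()
ram_log = np.stack(ram_log)                   # shape: (1000, 128)
active = np.argsort(ram_log.var(axis=0))[::-1]
print(active[:10])                            # the most active RAM addresses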

    opened by neighthan 5
  • Intuition on cross entropy.


    https://github.com/mila-iqia/atari-representation-learning/blob/59f3e3b94c4a0a61a6cb63acf247dd5a8662fadd/atariari/methods/global_local_infonce.py#L110

    This line is a little confusing to me. How does the range over N samples correspond to our cross-entropy targets? All help is highly appreciated!

    opened by biggzlar 4
  • Add tile colors and game score annotations for Qbert


    Adds tile color annotations to Qbert, since flipping the tile colors is the central objective of the game. Also adds score information, because the tile colors change in later levels, and eventually even require multiple flips. These additional annotations make the game much more playable from ARI features alone.

    opened by camall3n 3
  • Ram annotation of (x, y) are not aligned


    Hi,

I've tried using the provided AtariARIWrapper to get the 'labels' for MsPacmanNoFrameskip-v4. However, when I draw the locations of player, enemy_sue, enemy_inky, etc., using ('player_x', 'player_y'), ('enemy_sue_x', 'enemy_sue_y'), there always seems to be a certain offset from the real locations.


How can I get the aligned locations? Thanks!

    opened by happywu 3
  • Benchmark the pretrained-rl-agent method


    I tried to reproduce the result in Table 10, with

    python -m scripts.run_probe --method pretrained-rl-agent
    

    but got

    -------Collecting samples----------
    Deleting room_number for being too low in entropy! Sorry, dood!
    Deleting enemy_skull_y for being too low in entropy! Sorry, dood!
    Deleting key_monster_x for being too low in entropy! Sorry, dood!
    Deleting key_monster_y for being too low in entropy! Sorry, dood!
    Deleting level for being too low in entropy! Sorry, dood!
    Deleting score_0 for being too low in entropy! Sorry, dood!
    Deleting score_1 for being too low in entropy! Sorry, dood!
    Deleting score_2 for being too low in entropy! Sorry, dood!
    Duplicates: 98, Test Len: 8011
    got episodes!
    Total Steps: 27411
    Traceback (most recent call last):
      File "/home/liuyuezhangadam/anaconda3/envs/pytorch/lib/python3.7/runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/home/liuyuezhangadam/anaconda3/envs/pytorch/lib/python3.7/runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "/home/liuyuezhangadam/Git/atari-representation-learning/scripts/run_probe.py", line 87, in <module>
        run_probe(args)
      File "/home/liuyuezhangadam/Git/atari-representation-learning/scripts/run_probe.py", line 71, in run_probe
        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
      File "/home/liuyuezhangadam/Git/atari-representation-learning/atariari/benchmark/probe.py", line 197, in train
        epoch_loss, accuracy = self.do_one_epoch(tr_eps, tr_labels)
      File "/home/liuyuezhangadam/Git/atari-representation-learning/atariari/benchmark/probe.py", line 143, in do_one_epoch
        preds = self.probe(x, k)
      File "/home/liuyuezhangadam/Git/atari-representation-learning/atariari/benchmark/probe.py", line 117,
    wandb: Waiting for W&B process to finish, PID 30113
     in probe
        assert len(f.squeeze().shape) == 2, "if input is a batch of vectors you must specify an encoder!"
    AssertionError: if input is a batch of vectors you must specify an encoder!
    

It seems the encoder is simply defined as None in the code:

    https://github.com/mila-iqia/atari-representation-learning/blob/017f9260fb53c45d34c80c5080a2b46148e4cfd6/scripts/run_probe.py#L36-L37

Any help fixing this to reproduce the result in Table 10? Or did I make a mistake? Thanks, @ankeshanand

    opened by liuyuezhang 3
  • Fix bucketing issues


Right now, different labels can take on different ranges of values. These values (if the labels are positions) may not make sense with downsampled frames, so we need to fix this.

    opened by eracah 3
  • Generating Enough Episodes for Tennis


I've been trying to generate the same number of frames described in the paper for the probing evaluations (35,000 train; 5,000 validation; 10,000 test), but for Tennis I am unable to do so because of the large number of duplicates in the collected episodes. May I ask how you managed this for the paper?

I've tried collecting episodes with different random seeds and up to 400,000 steps (to try to account for duplicates), but so far to no avail. I've succeeded in generating episodes for other games, but not Tennis. I believe this might be because the agent spends the majority of episodes refusing to serve the ball (an issue I've come across before).

    tr_episodes, val_episodes,\
    tr_labels, val_labels,\
    test_episodes, test_labels = get_episodes(env_name='TennisNoFrameskip-v4', 
        steps=50000, 
        collect_mode="pretrained_ppo",
        seed=seed)
    
    opened by adamtupper 2
  • unused naff_fc_size parameter


https://github.com/mila-iqia/atari-representation-learning/blob/aa37bf5858ca267ea0a0615e1a727308d5628e0d/atariari/methods/no_action_feedforward_predictor.py#L28-L30 In NaFFPredictor, the fc layer uses the encoder's feature_size as the layer width, so the naff_fc_size argument is left unused. Also, nn.Sequential is used here, which suggests multiple layers were intended.
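
[A speculative sketch of how the parameter may have been intended, purely illustrative; the names below are not the repo's actual code.]

import torch.nn as nn

def make_naff_head(feature_size, naff_fc_size):
    # Hypothetical wiring: use naff_fc_size as a hidden width between
    # the encoder features and the predicted features.
    return nn.Sequential(
        nn.Linear(feature_size, naff_fc_size),
        nn.ReLU(),
        nn.Linear(naff_fc_size, feature_size),
    )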

    opened by AnxietyYoungPoet 2
  • Should the weights and biases account be configurable?


    python -m scripts.run_probe --method infonce-stdim --env-name PongNoFrameskip-v4
    

    Requires a weights and biases account. OK. Account created.

After setup, it seems that the account name (entity) is hardcoded to "curl-atari":

    wandb.init(project=args.wandb_proj, entity="curl-atari", tags=tags)
    

    perhaps it should be something like

wandb.init(project=args.wandb_proj, entity=args.wandb_entity, tags=tags)
    

    Or am I mistaken?

    opened by DuaneNielsen 2
  • How to get ball direction/velocity in Atari Pong?


Hello, I was checking the RAM annotations for Pong, and it seems like only the score and the positions of the players and ball are known from the RAM. My thinking was that the game must also store the ball's direction in order to move it frame by frame.

Are those annotations just missing from that page, or is direction handled some other way that doesn't use the RAM?

    opened by pedrohpf 0
  • RAM annotation for Blue Ghosts in MsPacman


    (The following is my observation through trial and error. So it may not be 100% correct.)

In MsPacman, by default, ram[116] == 0. When MsPacman eats a Power Pellet and the ghosts turn blue, the lower 6 bits of ram[116] show the remaining time, and the higher 2 bits show the number of killed ghosts. For example, after two ghosts are killed, ram[116] == (remaining time) + (2 << 6).
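
An illustrative decoding consistent with this observation (unverified, hypothetical helper):

def decode_ram_116(value):
    # Lower 6 bits: remaining "blue" time; upper 2 bits: ghosts killed.
    remaining_time = value & 0x3F
    ghosts_killed = (value >> 6) & 0x03
    return remaining_time, ghosts_killed

print(decode_ram_116(0b10000101))  # -> (5, 2)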

    opened by MasWag 0
  • Can't reproduce experimental results


    I tried to run probing tasks for different Atari environments, using the following command:

    python -m scripts.run_probe --method infonce-stdim --env-name {env_name}

I did not change any code, just tried different games, including PongNoFrameskip-v4, BowlingNoFrameskip-v4, BreakoutNoFrameskip-v4, and HeroNoFrameskip-v4.

However, only the F1 score for Pong matches the score reported in the paper. The F1 scores of the other three games are far worse than the scores reported in the paper (for Bowling, I got 0.22).

I checked the training loss logged in wandb; it seems that training has not converged at all. See the figure below.

[figure: training loss curve logged in wandb]

How can I get the F1 scores reported in the paper? Am I missing something?

    opened by Alxead 5
  • robot x/y coordinate mismatch in Berzerk game dict


    https://github.com/mila-iqia/atari-representation-learning/blob/f5e1080fe2077356b409cd324d6847b03ed35308/atariari/benchmark/ram_annotations.py#L68

    In the Berzerk game dictionary there are 8 indices corresponding to robot_x coordinates and 9 indices corresponding to robot_y coordinates.

    enemy_robots_x=range(65, 73),
    enemy_robots_y=range(56, 65),
    

    Is this a bug/typo?

    opened by zacharyhorvitz 0
  • Incorrect/ambiguous features in Seaquest


I think the features extracted for Seaquest are incorrect/ambiguous. In the attached image:

1. There are 4 enemies, one of them at a different position than the others, but the enemy_obstacle feature is the same for all 4 of them, i.e. 96.
2. Also, as we can see, there is no diver in the current frame; rather, one of the enemies has fired a missile. The extractor mislabels that missile as a diver ('diver_x_0': 45).


    {'labels': {'player_y': 13, 'oxygen_meter_value': 64, 'num_lives': 3, 'missile_direction': 0, 'diver_x_1': 0, 'player_direction': 0, 'diver_x_0': 45, 'player_x': 76, 'enemy_obstacle_x_3': 96, 'diver_x_3': 0, 'missile_x': 0, 'score_0': 0, 'diver_x_2': 0, 'enemy_obstacle_x_1': 96, 'enemy_obstacle_x_0': 96, 'score_1': 0, 'enemy_obstacle_x_2': 96, 'divers_collected_count': 0}, 'ale.lives': 4}

Please look into this, or let us know if we are misinterpreting anything here.

    Thanks, Vaibhav

    opened by damnOblivious 3