Contrastive Learning of Structured World Models

This repository contains the official PyTorch implementation of:

Contrastive Learning of Structured World Models.
Thomas Kipf, Elise van der Pol, Max Welling.
http://arxiv.org/abs/1911.12247

Abstract: A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.
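To make the idea of an object-factored state with a graph-neural-network transition concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the repository's model. Each of the K object slots holds a D-dimensional embedding, edges connect every ordered pair of distinct slots, and the action is given as a per-object one-hot vector. Single random linear maps with a ReLU stand in for the learned MLPs, and the names `edge_w`, `node_w` and `gnn_transition` are hypothetical. As in the paper, the transition predicts an update that is added to the current state, z_t + T(z_t, a_t).

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def gnn_transition(z, a, edge_w, node_w):
    """One round of message passing over a fully connected object graph.

    z:      (K, D) object embeddings, one row per object slot
    a:      (K, A) per-object action encoding
    edge_w: (2 * D, H) weights of a hypothetical edge function
    node_w: (D + A + H, D) weights of a hypothetical node function
    Returns the predicted state update T(z, a) with shape (K, D).
    """
    K = z.shape[0]
    hidden = edge_w.shape[1]
    messages = np.zeros((K, hidden))
    for i in range(K):
        for j in range(K):
            if i == j:
                continue  # no self-edges
            # The edge function sees the (receiver, sender) pair of embeddings.
            pair = np.concatenate([z[i], z[j]])
            messages[i] += relu(pair @ edge_w)
    # The node function combines each object's own embedding, its action
    # and the sum of its incoming messages.
    node_in = np.concatenate([z, a, messages], axis=1)
    return node_in @ node_w


rng = np.random.default_rng(0)
K, D, A, H = 3, 4, 4, 8
z_t = rng.normal(size=(K, D))
a_t = np.zeros((K, A))
a_t[1, 2] = 1.0  # illustrative: act on object 1 with action index 2
edge_w = rng.normal(size=(2 * D, H))
node_w = rng.normal(size=(D + A + H, D))

z_next_pred = z_t + gnn_transition(z_t, a_t, edge_w, node_w)
print(z_next_pred.shape)  # (3, 4)
```

Because messages are summed over all other slots and the node function is applied row by row, reordering the object slots reorders the prediction in the same way. This permutation equivariance is what lets one set of weights handle any object order.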

Requirements

  • Python 3.6 or 3.7
  • PyTorch version 1.2
  • OpenAI Gym version 0.12.0: pip install gym==0.12.0
  • OpenAI atari-py version 0.1.4: pip install atari-py==0.1.4
  • scikit-image version 0.15.0: pip install scikit-image==0.15.0
  • Matplotlib version 3.0.2: pip install matplotlib==3.0.2
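Because the versions above are pinned exactly, it can help to check the active environment before generating data. The helper below is a hypothetical sketch and is not part of the repository. It compares installed distributions against the pins listed above using `importlib.metadata`, which requires Python 3.8+. On the listed Python 3.6/3.7 interpreters, the `importlib_metadata` backport from PyPI provides the same functions. The PyPI distribution names are assumed to be `torch`, `gym`, `atari-py`, `scikit-image` and `matplotlib`.

```python
import sys

try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python < 3.8: pip install importlib-metadata
    from importlib_metadata import PackageNotFoundError, version

# Pins copied from the requirements list above.
PINS = {
    "torch": "1.2",
    "gym": "0.12.0",
    "atari-py": "0.1.4",
    "scikit-image": "0.15.0",
    "matplotlib": "3.0.2",
}


def matches_pin(installed, pinned):
    """True if `installed` agrees with every component given in `pinned`.

    A local build tag such as '+cu92' is ignored, so '1.2.0+cu92'
    satisfies the pin '1.2'.
    """
    parts = installed.split("+")[0].split(".")
    wanted = pinned.split(".")
    return parts[:len(wanted)] == wanted


def installed_versions(names):
    """Map each distribution name to its installed version, or None."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def check_pins(pins, installed):
    """Return {name: problem} for every pin that is missing or differs."""
    problems = {}
    for name, pinned in pins.items():
        have = installed.get(name)
        if have is None:
            problems[name] = "not installed"
        elif not matches_pin(have, pinned):
            problems[name] = "found {}, expected {}".format(have, pinned)
    return problems


if __name__ == "__main__":
    if sys.version_info[:2] not in [(3, 6), (3, 7)]:
        print("Python {}.{} is not one of the listed versions".format(*sys.version_info[:2]))
    report = check_pins(PINS, installed_versions(PINS))
    for name, problem in sorted(report.items()):
        print("{}: {}".format(name, problem))
    if not report:
        print("all pinned packages match")
```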

Generate datasets

2D Shapes:

python data_gen/env.py --env_id ShapesTrain-v0 --fname data/shapes_train.h5 --num_episodes 1000 --seed 1
python data_gen/env.py --env_id ShapesEval-v0 --fname data/shapes_eval.h5 --num_episodes 10000 --seed 2

3D Cubes:

python data_gen/env.py --env_id CubesTrain-v0 --fname data/cubes_train.h5 --num_episodes 1000 --seed 3
python data_gen/env.py --env_id CubesEval-v0 --fname data/cubes_eval.h5 --num_episodes 10000 --seed 4

Atari Pong:

python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_eval.h5 --num_episodes 100 --atari --seed 2

Space Invaders:

python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_eval.h5 --num_episodes 100 --atari --seed 2

3-Body Gravitational Physics:

python data_gen/physics.py --num-episodes 5000 --fname data/balls_train.h5 --seed 1
python data_gen/physics.py --num-episodes 1000 --fname data/balls_eval.h5 --eval --seed 2
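Each command above writes episodes of environment transitions to an HDF5 file. The sketch below gives a rough picture of what such a dataset contains. It collects episodes with a uniformly random policy and stores per-episode lists of `obs`, `action` and `next_obs`. The random policy, the key names and the toy `LineWorld` environment (a stand-in so the example runs without Gym) are assumptions for illustration only. Check `data_gen/env.py` for what the repository actually records. The environment follows the classic Gym 0.12 interface, where `reset()` returns an observation and `step()` returns `(obs, reward, done, info)`.

```python
import random


class LineWorld:
    """Hypothetical toy environment: one object on a line of `size` cells."""

    def __init__(self, size=5, max_steps=10, seed=None):
        self.size = size
        self.max_steps = max_steps
        self.rng = random.Random(seed)
        self.num_actions = 2  # 0 = left, 1 = right

    def reset(self):
        self.pos = self.rng.randrange(self.size)
        self.t = 0
        return self.pos

    def step(self, action):
        delta = -1 if action == 0 else 1
        # Moves that would leave the line are blocked.
        self.pos = min(max(self.pos + delta, 0), self.size - 1)
        self.t += 1
        done = self.t >= self.max_steps
        return self.pos, 0.0, done, {}


def collect_episodes(env, num_episodes, seed=None):
    """Roll out a uniformly random policy and record every transition."""
    rng = random.Random(seed)
    episodes = []
    for _ in range(num_episodes):
        record = {"obs": [], "action": [], "next_obs": []}
        obs, done = env.reset(), False
        while not done:
            action = rng.randrange(env.num_actions)
            next_obs, _, done, _ = env.step(action)
            record["obs"].append(obs)
            record["action"].append(action)
            record["next_obs"].append(next_obs)
            obs = next_obs
        episodes.append(record)
    return episodes


episodes = collect_episodes(LineWorld(seed=1), num_episodes=3, seed=1)
print(len(episodes), len(episodes[0]["obs"]))  # 3 10
```

Fixing the seed, as the `--seed` flags above do, makes the collected episodes reproducible. Using different seeds for the train and eval files keeps the two splits from being identical rollouts.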

Run model training and evaluation

2D Shapes:

python train.py --dataset data/shapes_train.h5 --encoder small --name shapes
python eval.py --dataset data/shapes_eval.h5 --save-folder checkpoints/shapes --num-steps 1

3D Cubes:

python train.py --dataset data/cubes_train.h5 --encoder large --name cubes
python eval.py --dataset data/cubes_eval.h5 --save-folder checkpoints/cubes --num-steps 1

Atari Pong:

python train.py --dataset data/pong_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name pong
python eval.py --dataset data/pong_eval.h5 --save-folder checkpoints/pong --num-steps 1

Space Invaders:

python train.py --dataset data/spaceinvaders_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name spaceinvaders
python eval.py --dataset data/spaceinvaders_eval.h5 --save-folder checkpoints/spaceinvaders --num-steps 1

3-Body Gravitational Physics:

python train.py --dataset data/balls_train.h5 --encoder medium --embedding-dim 4 --num-objects 3 --ignore-action --name balls
python eval.py --dataset data/balls_eval.h5 --save-folder checkpoints/balls --num-steps 1
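`eval.py` takes a `--num-steps` argument, which appears to set the prediction horizon. The paper reports ranking metrics in latent space: Hits at rank 1 (H@1) and mean reciprocal rank (MRR). The NumPy sketch below shows how such metrics can be computed. It is not the repository's evaluation code, and the function name `ranking_metrics` is hypothetical. The sketch makes three assumptions. Each predicted state is ranked against the encoded targets of all samples in the batch. The distance is the squared Euclidean distance over all object slots. Ties are resolved in favour of the true target.

```python
import numpy as np


def ranking_metrics(pred, target, k=1):
    """Hits@k and MRR for a batch of predicted object-factored states.

    pred, target: (N, K, D) arrays. Row n of `target` is the true next
    state for row n of `pred`; the other N - 1 rows act as distractors.
    """
    n = pred.shape[0]
    flat_pred = pred.reshape(n, -1)
    flat_target = target.reshape(n, -1)
    # dist[i, j] = squared distance between prediction i and candidate j.
    diff = flat_pred[:, None, :] - flat_target[None, :, :]
    dist = (diff ** 2).sum(axis=-1)
    true_dist = np.diag(dist)
    # Rank 1 means no candidate is strictly closer than the true target.
    ranks = 1 + (dist < true_dist[:, None]).sum(axis=1)
    hits = float(np.mean(ranks <= k))
    mrr = float(np.mean(1.0 / ranks))
    return hits, mrr


target = np.arange(4.0).reshape(4, 1, 1)  # 4 samples, 1 object, 1 dim
pred = target.copy()
pred[0] = target[1]  # prediction 0 collapses onto sample 1
print(ranking_metrics(pred, target))  # (0.75, 0.875)
```

In this example, prediction 0 ranks its true target second, so H@1 drops to 3/4 and MRR to (1/2 + 1 + 1 + 1) / 4.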

Cite

If you make use of this code in your own work, please cite our paper:

@article{kipf2019contrastive,
  title={Contrastive Learning of Structured World Models}, 
  author={Kipf, Thomas and van der Pol, Elise and Welling, Max}, 
  journal={arXiv preprint arXiv:1911.12247}, 
  year={2019} 
}
Comments
  • Data Loading Error

    When I try to run the code provided here, I hit a RuntimeError at https://github.com/tkipf/c-swm/blob/e944b24bcaa42d9ee847f30163437a50f0237aa0/train.py#L104 when running the Shapes-2D training command: python train.py --dataset data/shapes_train.h5 --encoder small --name shapes

    Specifically, the error is: RuntimeError: DataLoader worker is killed by signal: Killed. It seems to come from the DataLoader having trouble loading the training set with multiple worker processes, but when I look through your utils files, I don't see why this error would occur.

    The same error occurs on both a Windows machine and a Linux machine.

    Here's the full trace:

    Traceback (most recent call last):
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 724, in _try_get_data
        data = self.data_queue.get(timeout=timeout)
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/multiprocessing/queues.py", line 104, in get
        if not self._poll(timeout):
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/multiprocessing/connection.py", line 257, in poll
        return self._poll(timeout)
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/multiprocessing/connection.py", line 414, in _poll
        r = wait([self], timeout)
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/multiprocessing/connection.py", line 921, in wait
        ready = selector.select(timeout)
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/selectors.py", line 415, in select
        fd_event_list = self._selector.poll(timeout)
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
        _error_if_any_worker_fails()
    RuntimeError: DataLoader worker (pid 19014) is killed by signal: Killed. 
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "train.py", line 104, in <module>
        obs = train_loader.__iter__().next()[0]
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 804, in __next__
        idx, data = self._get_data()
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 771, in _get_data
        success, data = self._try_get_data()
      File "/home/aadharna/miniconda3/envs/cswm/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 737, in _try_get_data
        raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
    RuntimeError: DataLoader worker (pid(s) 19012, 19014) exited unexpectedly
    (cswm) aadharna@penguin:~/PycharmProjects/c-swm$ 
    

    I'll continue poking around and update when I find the root cause.

    opened by aadharna
  • What are the "interactions" in 2d-BoxPushing and 3d-BoxPushing?

    Hi Kipf, I really appreciate your work. I am wondering whether the GNN is modeling interactions between objects.

    After going over the code, I found that when object A collides with object B, object A does not bounce back and object B is not pushed forward. So what are the interactions in the 2d-boxpushing and 3d-boxpushing experiments? The objects in those two experiments seem largely independent of each other.

    For 3-body physics, I understand the interactions are the forces among spheres.

    Best regards.

    opened by Luodian
  • Add to path in physics.py

    You need to copy the following lines from data_gen/env.py into data_gen/physics.py:

    # Get env directory
    import sys
    from pathlib import Path
    if str(Path.cwd()) not in sys.path:
        sys.path.insert(0, str(Path.cwd()))
    
    opened by csquires
  • Implementation of hinge loss does not match the paper

    Hello Kipf, I found a discrepancy between the loss described in the paper and the one implemented in the code.

    According to Eq. (5) in the paper, for negative samples you calculate the Euclidean distance between the negative state sample at timestep t and the state at timestep t+1.

    However, in the code below, state and neg_state are both at timestep t.

    self.neg_loss = torch.max(
        zeros, self.hinge - self.energy(
            state, action, neg_state, no_trans=True)).mean()
    

    I noticed that the same question was also asked here.

    I want to know if this is a bug. Does the discrepancy affect the final performance?

    opened by Alxead
  • What do you use for baseline implementation?

    Do you use your own code for the baseline World Models implementations (with AE and VAE), or a publicly accessible library? Can you point me toward the library/code you used? It would be useful for reproducing the results. (@abaheti95 you'll be interested too)

    opened by balloch
  • Any scripts for representation visualization?

    Hello, thank you for sharing the code. I really like Figures 3 and 4 in your paper and wonder whether there are any scripts for reproducing these results. Thanks!

    opened by NagisaZj
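The hinge-loss issue above ("Implementation of hinge loss does not match the paper") asks which state the negative sample should be compared with. The NumPy sketch below computes both variants side by side to make the difference concrete. It rests on several assumptions. The energy is the squared Euclidean distance over the embedding dimension, averaged over object slots, with any `sigma` scaling omitted. Negatives are drawn by permuting the batch, and the margin is called `hinge`. The function names are hypothetical, and this is not the repository's code. With `compare_to="next"`, the sketch follows the paper's Eq. (5) as described in the issue (negative at t versus state at t+1). With `compare_to="current"`, it mirrors the quoted `energy(state, action, neg_state, no_trans=True)` call (negative versus state at t).

```python
import numpy as np


def energy(a, b):
    """Squared Euclidean distance per sample, averaged over object slots.

    a, b: (N, K, D) batches of object-factored states.
    """
    return ((a - b) ** 2).sum(axis=2).mean(axis=1)


def negative_term(state, next_state, perm, hinge=1.0, compare_to="next"):
    """Mean hinge penalty max(0, hinge - d(neg, ref)) over the batch.

    The negative for row i is state[perm[i]], another sample's state at
    time t. `ref` is next_state ("next", the paper's variant as described
    in the issue) or state ("current", the quoted code's variant).
    """
    neg_state = state[perm]
    ref = next_state if compare_to == "next" else state
    return np.maximum(0.0, hinge - energy(neg_state, ref)).mean()


state = np.array([0.0, 0.5]).reshape(2, 1, 1)       # z_t for 2 samples
next_state = np.array([0.5, 0.5]).reshape(2, 1, 1)  # z_{t+1}
perm = np.array([1, 0])                             # swap the two samples
print(negative_term(state, next_state, perm, compare_to="next"))     # 0.875
print(negative_term(state, next_state, perm, compare_to="current"))  # 0.75
```

In the paper's variant, a negative that lands near z_{t+1} is pushed away from z_{t+1}. In the code's variant, it is pushed away from z_t instead. Whether that changes final performance is the open question in the issue, and this toy example cannot settle it.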
Owner
Thomas Kipf