A Tensorflow implementation of Attend, Infer, Repeat

Overview

Attend, Infer, Repeat: Fast Scene Understanding with Generative Models

This is an unofficial Tensorflow implementation of Attend, Infer, Repeat (AIR), as presented in the following paper: S. M. Ali Eslami et al., Attend, Infer, Repeat: Fast Scene Understanding with Generative Models.

  • Author (of the implementation): Adam Kosiorek, Oxford Robotics Institute, University of Oxford
  • Email: adamk(at)robots.ox.ac.uk
  • Webpage: http://akosiorek.github.io/

I describe the implementation and the issues I ran into while working on it in this blog post.

Installation

Install Tensorflow v1.1.0rc1, Sonnet v1.1 and the following dependencies (using pip install -r requirements.txt (preferred) or pip install [package]):

  • matplotlib==1.5.3
  • numpy==1.12.1
  • attrdict==2.0.0
  • scipy==0.18.1

Sample Results

AIR learns to reconstruct objects by painting them one by one onto a blank canvas. The figure below comes from a model trained for 175k iterations; the maximum number of steps is set to 3, but there are never more than 2 objects. The first row shows the input images, and rows 2-4 are reconstructions at steps 1, 2 and 3 (with the location of the attention glimpse marked in red, if it exists). Rows 5-7 are the reconstructed image crops, and above each crop is the probability of executing 1, 2 or 3 steps. If a reconstructed crop is black and "0 with ..." is written above it, that step was not used.

AIR results
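The figure combines two ideas: a distribution over how many steps are executed, and a canvas built up one glimpse at a time. The minimal sketch below illustrates both under simplifying assumptions. It assumes each step carries a presence probability (the probability of executing that step given that all earlier steps ran), and it pastes decoded glimpses at integer offsets instead of warping them with a spatial transformer as AIR does. The helpers num_steps_distribution and paint are hypothetical and not part of this repository.

```python
import numpy as np


def num_steps_distribution(presence_probs):
    """Return P(exactly n steps) for n = 0..N.

    presence_probs[i] is the (assumed) probability that step i + 1 is executed
    given that all earlier steps were executed; N = len(presence_probs).
    """
    probs = np.asarray(presence_probs, dtype=np.float64)
    # Probability that the first n steps all ran, for n = 0..N.
    reached = np.concatenate([[1.0], np.cumprod(probs)])
    # Probability of stopping right after n steps; after step N nothing is left.
    stop = np.append(1.0 - probs, 1.0)
    return reached * stop


def paint(canvas_shape, glimpses, offsets, presence):
    """Additively paint decoded glimpses onto a blank canvas, one per step.

    Simplification: each glimpse is pasted at an integer (row, col) offset and
    must fit inside the canvas; AIR instead warps it with a spatial transformer.
    """
    canvas = np.zeros(canvas_shape)
    for glimpse, (row, col), z_pres in zip(glimpses, offsets, presence):
        if not z_pres:  # the first unused step ends the sequence
            break
        height, width = glimpse.shape
        canvas[row:row + height, col:col + width] += glimpse
    return np.clip(canvas, 0.0, 1.0)
```

For example, num_steps_distribution([0.9, 0.8, 0.1]) puts most of its mass on exactly 2 steps, because the third step is unlikely to run once the first two have. Stopping at the first unused step mirrors the black, unused crops in the figure.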

Data

Run ./scripts/create_dataset.sh. The script creates train and validation datasets of multi-digit MNIST.
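The contents of create_dataset.sh are not reproduced here. As a hedged illustration of what a multi-digit MNIST image looks like, the sketch below assumes the setup described in the AIR paper (digits scattered at random positions on a 50x50 canvas, at most two per image) and uses random arrays in place of real MNIST digits. make_multi_digit_canvas is a hypothetical helper, and merging overlapping digits with a pixel-wise maximum is an assumption, not necessarily what the script does.

```python
import numpy as np


def make_multi_digit_canvas(digits, rng, canvas_size=50, max_digits=2):
    """Scatter 0..max_digits randomly chosen digits over a blank square canvas.

    digits: array of shape (num_digits, height, width) with values in [0, 1].
    Returns the canvas and the number of digits placed (the target count).
    """
    num_objects = rng.randint(0, max_digits + 1)  # upper bound is exclusive
    canvas = np.zeros((canvas_size, canvas_size), dtype=np.float32)
    height, width = digits.shape[1:]
    for _ in range(num_objects):
        digit = digits[rng.randint(len(digits))]
        row = rng.randint(0, canvas_size - height + 1)
        col = rng.randint(0, canvas_size - width + 1)
        patch = canvas[row:row + height, col:col + width]
        # Overlaps are merged with a pixel-wise max so values stay in [0, 1].
        canvas[row:row + height, col:col + width] = np.maximum(patch, digit)
    return canvas, num_objects


# Usage with stand-in data: random 28x28 "digits" instead of real MNIST.
rng = np.random.RandomState(0)
fake_digits = rng.rand(10, 28, 28).astype(np.float32)
canvas, count = make_multi_digit_canvas(fake_digits, rng)
```

Returning the digit count alongside the image gives the ground truth that the inferred number of steps can later be compared against.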

Training

Run ./scripts/train_multi_mnist.sh. The training script runs for 300k iterations and saves model checkpoints and training-progress figures every 10k iterations in results/multi_mnist. Tensorflow summaries are stored in the same folder, so Tensorboard can be used for monitoring.

The model seems to be very sensitive to initialisation. It might be necessary to run training multiple times before achieving a step-count accuracy close to the one reported in the paper.
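One reading of step-count accuracy is the fraction of images for which the number of executed steps equals the number of digits. The sketch below computes it and picks the best of several training runs started from different seeds; count_accuracy and best_run are hypothetical helpers, not part of this repository, and the accuracies in the usage line are placeholders rather than reported results.

```python
import numpy as np


def count_accuracy(inferred_steps, true_counts):
    """Fraction of images whose number of executed steps equals the object count."""
    inferred = np.asarray(inferred_steps)
    true = np.asarray(true_counts)
    return float(np.mean(inferred == true))


def best_run(accuracy_by_seed):
    """Return the seed whose run reached the highest validation count accuracy."""
    return max(accuracy_by_seed, key=accuracy_by_seed.get)


# Usage: 3 of 4 images are counted correctly.
acc = count_accuracy([0, 1, 2, 2], [0, 1, 1, 2])
```

Selecting on validation rather than training accuracy avoids rewarding a run that merely memorised the training counts.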

Experimentation

The Jupyter notebook at attend_infer_repeat/experiment.ipynb can be used for experimentation.

Citation

If you find this repo useful in your research, please consider citing the original paper:

@incollection{Eslami2016,
    title = {Attend, Infer, Repeat: Fast Scene Understanding with Generative Models},
    author = {Eslami, S. M. Ali and Heess, Nicolas and Weber, Theophane and Tassa, Yuval and Szepesvari, David and Kavukcuoglu, Koray and Hinton, Geoffrey E},
    booktitle = {Advances in Neural Information Processing Systems 29},
    editor = {D. D. Lee and M. Sugiyama and U. V. Luxburg and I. Guyon and R. Garnett},
    pages = {3225--3233},
    year = {2016},
    publisher = {Curran Associates, Inc.},
    url = {http://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models.pdf}
}

License

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses/.

Release Notes

Version 1.0

  • Original unofficial implementation; contains the multi-digit MNIST experiment.

Comments
  • Num Steps Prior Fix and other bugfixes

    The original paper uses a geometric prior, whereas the implementation here used a stick-breaking process prior with the same form as the approximate posterior distribution. The prior is now restored to a geometric one, which, for large values of the success probability, resembles a high-entropy uniform prior at the start of training and thus encourages exploration in the number of steps.

    Other fixes:

    • MLP bug fix: every MLP had only a single hidden layer
    opened by akosiorek
  • Unable to install Sonnet v1.1

    The required Sonnet version is hard to come by. The oldest version that can be installed via pip is 1.7 but that requires a more recent Tensorflow version compared to the one used in the project. Therefore installing Sonnet v1.1 requires building it from source via Bazel. I have not found any guidance on which Bazel and Java versions are required for this. I tried both Bazel v0.4.5 (minimum version required for Sonnet) and v0.20.0 (latest Bazel version) with Java 10 without success.

    Possible solutions:

    • Could the installation instructions be updated to include more detail on how to install Sonnet v1.1?
    • Would it be possible to update the repository to support more recent versions of Sonnet and Tensorflow?
    opened by martinengelcke
  • Update requirements

    Remove numpy from requirements as tensorflow already pulls in the correct version. Add pillow to the list of dependencies as it is needed by imresize in scipy.misc in attend_infer_repeat/data/data.py.

    opened by martinengelcke
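The "Num Steps Prior Fix" comment above does not spell out how the geometric prior is parameterised. The minimal sketch below assumes that "success" means taking one more step and that the prior is truncated at the maximum number of steps and renormalised; under those assumptions it shows why a large success probability yields a nearly uniform, high-entropy prior. truncated_geometric_prior and entropy are hypothetical helpers, not functions from this repository.

```python
import numpy as np


def truncated_geometric_prior(success_prob, max_steps):
    """Geometric prior over the number of steps n = 0..max_steps.

    Assumes 'success' means taking one more step, so the geometric pmf is
    (1 - p) * p**n; it is truncated at max_steps and renormalised.
    """
    if not 0.0 < success_prob < 1.0:
        raise ValueError("success_prob must lie in (0, 1)")
    n = np.arange(max_steps + 1)
    pmf = (1.0 - success_prob) * success_prob ** n
    return pmf / pmf.sum()


def entropy(probs):
    """Shannon entropy in nats, ignoring zero-probability entries."""
    probs = np.asarray(probs)
    probs = probs[probs > 0]
    return float(-(probs * np.log(probs)).sum())


# A large success probability gives a nearly uniform prior over 0..3 steps;
# a small one puts almost all mass on 0 steps.
for p in (0.99, 0.1):
    prior = truncated_geometric_prior(p, max_steps=3)
    print(p, np.round(prior, 3), round(entropy(prior), 3))
```

After renormalisation the (1 - p) factor cancels, so the prior is proportional to p**n; as p approaches 1 every step count becomes almost equally likely, and the entropy approaches its maximum of log(max_steps + 1).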
Owner
Adam Kosiorek
I'm a PhD student at the Oxford Robotics Institute. I work on Machine Learning for perception - I'm looking into external memory and attention for RNNs.