Clockwork VAEs in JAX/Flax

Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported from the official TensorFlow implementation.

Running on a single TPU v3, training is 10x faster than reported in the paper (60h -> 6h on minerl).

Method

Clockwork VAEs are deep generative models that learn long-term dependencies in video by leveraging hierarchies of representations that progress at different clock speeds. In contrast to prior video prediction methods, which typically focus on predicting sharp but short sequences into the future, Clockwork VAEs can accurately predict high-level content, such as object positions and identities, for 1000 frames.

Clockwork VAEs build upon the Recurrent State Space Model (RSSM), so each state contains a deterministic component for long-term memory and a stochastic component for sampling diverse plausible futures. Clockwork VAEs are trained end-to-end to optimize the evidence lower bound (ELBO) that consists of a reconstruction term for each image and a KL regularizer for each stochastic variable in the model.
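
As a rough illustration, here is a minimal sketch of this objective in JAX. This is not the repo's actual code: the unit-variance Gaussian likelihood, the dict layout of the per-level parameters, and all names are simplifying assumptions.

    import jax.numpy as jnp

    def gaussian_kl(post_mean, post_std, prior_mean, prior_std):
        # KL(posterior || prior) between diagonal Gaussians, summed over state dims.
        var_ratio = (post_std / prior_std) ** 2
        return 0.5 * jnp.sum(
            var_ratio + ((post_mean - prior_mean) / prior_std) ** 2
            - 1.0 - jnp.log(var_ratio), axis=-1)

    def elbo(frames, recon, levels):
        # frames, recon: (T, H, W, C). levels: one dict per hierarchy level holding
        # posterior/prior Gaussian parameters of shape (T_level, state_dim); slower
        # levels tick less often, so T_level shrinks as the level index grows.
        recon_ll = -0.5 * jnp.sum((frames - recon) ** 2)  # unit-variance Gaussian
        kl = sum(jnp.sum(gaussian_kl(l['post_mean'], l['post_std'],
                                     l['prior_mean'], l['prior_std']))
                 for l in levels)
        return recon_ll - kl  # training maximizes this lower bound

Note that the KL is summed over every stochastic variable at every level, while the reconstruction term is summed over every frame, matching the structure described above.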

Instructions

This repository contains the code for training the Clockwork VAE model on the datasets minerl, mazes, and mmnist.

The datasets will automatically be downloaded into the --datadir directory.

python3 train.py --logdir /path/to/logdir --datadir /path/to/datasets --config configs/<dataset>.yml 

The evaluation script writes open-loop video predictions in both PNG and NPZ format, along with PSNR and SSIM plots, to the data directory.

python3 eval.py --logdir /path/to/logdir

Known differences from the original

  • Flax's default kernel initializer, layer precision, and GRU implementation (which avoids redundant biases) are used.
  • For some configuration parameters, only the defaults are implemented.
  • Training metrics and videos are logged with wandb.
  • The base configuration is in config.py.

Added features:

  • This implementation runs on TPU out-of-the-box.
  • Apart from the config file, configuration can be done via command line and wandb.
  • Re-running with the seed of a previous run will repeat it exactly.

Things to watch out for

Replication of the paper's results on the mazes dataset has not yet been confirmed.

Computing evaluation metrics during training is a memory bottleneck due to the large eval_seq_len. If you run out of device memory, consider lowering it during training, for example to 100. Remember to pass the original value to eval.py to get unchanged results.
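
For example, assuming config keys can be overridden as command-line flags (see Added features above; the exact flag syntax is defined by config.py), a memory-friendly training run followed by an unchanged evaluation might look like:

    python3 train.py --logdir /path/to/logdir --datadir /path/to/datasets --config configs/minerl.yml --eval_seq_len 100
    python3 eval.py --logdir /path/to/logdir --eval_seq_len <original value>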

Acknowledgements

Thanks to Vaibhav Saxena and Danijar Hafner for helpful discussions and to Jamie Townsend for reviewing code.

Comments
  • potential bug in the encoder

    for level in range(1, self.c.levels):
        for _ in range(self.c.enc_dense_layers - 1):
            x = nn.relu(nn.Dense(self.c.enc_dense_embed_size)(x))
        if self.c.enc_dense_layers > 0:
            x = nn.Dense(feat_size)(x)
        layer = x

    From line 39 onwards in the cnn.py Encoder(), the depth of these layers increases with the level, because the hidden variable x is overwritten rather than reset. At large levels and enc_dense_layers this results in a very deep network mapping from the observation embedding to the latent space. Not sure it's intentional; it doesn't seem to have a purpose, i.e. is there a reason the higher latent spaces need a deeper function to map from the embedding?

    Same issue in the original tensorflow version https://github.com/vaibhavsaxena11/cwvae/issues/2
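
    One possible fix (an untested sketch; embed is a hypothetical name for the observation embedding computed before this loop) would be to restart each level's dense stack from the embedding, so that depth no longer accumulates across levels:

        for level in range(1, self.c.levels):
            x = embed  # hypothetical: reset to the shared observation embedding
            for _ in range(self.c.enc_dense_layers - 1):
                x = nn.relu(nn.Dense(self.c.enc_dense_embed_size)(x))
            if self.c.enc_dense_layers > 0:
                x = nn.Dense(feat_size)(x)
            layer = x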

    opened by xmax1
  • File encoding error

    I get an ffmpeg-related error while running train.py.

    I get this error even though ffmpeg is already installed in the environment. Any idea how I can solve this issue? Thanks.

    opened by sb93