Official Pytorch and JAX implementation of "Efficient-VDVAE: Less is more"

Overview

The official Pytorch and JAX implementation of the arXiv preprint "Efficient-VDVAE: Less is more".

Louay Hazami   ·   Rayhane Mama   ·   Ragavan Thurairatnam


MIT license

Efficient-VDVAE is a memory- and compute-efficient very deep hierarchical VAE. It converges faster and is more stable than current hierarchical VAE models. It also achieves state-of-the-art (SOTA) likelihood-based performance on several image datasets.

Pre-trained model checkpoints

We provide checkpoints of pre-trained models on MNIST, CIFAR-10, Imagenet 32x32, Imagenet 64x64, CelebA 64x64, CelebAHQ 256x256 (5-bits and 8-bits), FFHQ 256x256 (5-bits and 8-bits), CelebAHQ 1024x1024 and FFHQ 1024x1024, linked in the table below. All provided models are the ones trained for Table 4 of the paper.

Dataset                   | Pytorch logs | Pytorch checkpoints | JAX logs | JAX checkpoints | Negative ELBO
MNIST                     | link         | link                | link     | link            | 79.09 nats
CIFAR-10                  | Queued       | Queued              | link     | link            | 2.87 bits/dim
Imagenet 32x32            | link         | link                | link     | link            | 3.58 bits/dim
Imagenet 64x64            | link         | link                | link     | link            | 3.30 bits/dim
CelebA 64x64              | link         | link                | link     | link            | 1.83 bits/dim
CelebAHQ 256x256 (5-bits) | link         | link                | link     | link            | 0.51 bits/dim
CelebAHQ 256x256 (8-bits) | link         | link                | link     | link            | 1.35 bits/dim
FFHQ 256x256 (5-bits)     | link         | link                | link     | link            | 0.53 bits/dim
FFHQ 256x256 (8-bits)     | link         | link                | link     | link            | 2.17 bits/dim
CelebAHQ 1024x1024        | link         | link                | link     | link            | 1.01 bits/dim
FFHQ 1024x1024            | link         | link                | link     | link            | 2.30 bits/dim
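The table mixes units: MNIST is reported in nats (presumably per image, following the usual MNIST convention) while the other datasets are reported in bits per dimension. As a hedged illustration of how the two relate, the sketch below converts between them by dividing by the number of dimensions (height × width × channels) and by ln 2. The helper names are hypothetical and not part of the codebase.

```python
import math


def nats_to_bits_per_dim(nll_nats: float, height: int, width: int, channels: int) -> float:
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    num_dims = height * width * channels
    # 1 nat = 1 / ln(2) bits, then average over every sub-pixel dimension.
    return nll_nats / (num_dims * math.log(2))


def bits_per_dim_to_nats(bpd: float, height: int, width: int, channels: int) -> float:
    """Inverse conversion: bits per dimension back to nats per image."""
    num_dims = height * width * channels
    return bpd * num_dims * math.log(2)


if __name__ == "__main__":
    # Assuming unpadded 28x28 grayscale MNIST inputs (the codebase may pad differently),
    # 79.09 nats is roughly 0.146 bits/dim.
    print(round(nats_to_bits_per_dim(79.09, 28, 28, 1), 3))
    # CIFAR-10 images are 32x32x3, so 2.87 bits/dim is roughly 6111 nats per image.
    print(round(bits_per_dim_to_nats(2.87, 32, 32, 3)))
```

Note that bits/dim values are only directly comparable between runs that use the same input bit depth.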

Notes:

  • Downloading from the "Checkpoints" link gives the minimal files required to resume training or run inference: the model checkpoint file and the saved hyper-parameters of the run (explained further below).
  • Downloading from the "Logs" link gives additional training logs, such as Tensorboard files or images saved during training. "Logs" also holds the saved hyper-parameters of the run.
  • Downloaded "Logs" and/or "Checkpoints" should always be unzipped in their implementation folder (efficient_vdvae_torch for Pytorch checkpoints and efficient_vdvae_jax for JAX checkpoints).
  • Some model checkpoints are currently missing in either Pytorch or JAX. We will add them soon.
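A minimal sketch of unzipping a downloaded archive into the matching implementation folder, assuming the script runs from the repository root. The function name and the "torch"/"jax" keys are hypothetical; only the two folder names come from the notes above.

```python
import zipfile
from pathlib import Path

# Folder names taken from the notes above; the dictionary keys are illustrative.
IMPLEMENTATION_DIRS = {"torch": "efficient_vdvae_torch", "jax": "efficient_vdvae_jax"}


def extract_download(zip_path, framework, repo_root="."):
    """Unzip a "Logs" or "Checkpoints" archive into its implementation folder.

    Returns the sorted list of archive members that were extracted.
    """
    target = Path(repo_root) / IMPLEMENTATION_DIRS[framework]
    if not target.is_dir():
        # Fail loudly instead of creating a stray folder in the wrong place.
        raise FileNotFoundError(f"{target} does not exist; run from the repository root")
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target)
        return sorted(archive.namelist())
```

For example, extract_download("downloaded_checkpoint.zip", "torch") (a hypothetical file name) would place the archive contents under efficient_vdvae_torch.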

Pre-requisites

To run this codebase, you need:

  • A machine running a Linux-based OS (tested on Ubuntu 20.04 LTS)
  • GPUs (preferably with more than 16 GB of memory)
  • Docker
  • Python 3.7 or higher
  • CUDA 11.1 or higher (can be installed from here)

We recommend running all the code below inside a Linux screen session or another terminal multiplexer, since some commands can take hours or days to finish and would be killed if you close your terminal.

Note:

  • If you plan to run the JAX implementation, the installed JAX build must match the installed CUDA and cuDNN versions exactly. Our default Dockerfile assumes the code runs with CUDA 11.4 or newer and should be changed otherwise. For more details, refer to the JAX installation instructions.
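Because this README quotes two CUDA minimums (11.1 in the prerequisite list, 11.4 for the default JAX Dockerfile), a small pre-flight check can make the distinction explicit. This is a hypothetical helper, not part of the repository; it expects you to supply the CUDA version string yourself (for example, as reported by nvcc --version).

```python
import sys

# Minimum CUDA versions quoted in this README: 11.1 in general,
# 11.4 for the default JAX Dockerfile.
MIN_CUDA = {"torch": (11, 1), "jax": (11, 4)}


def parse_version(text):
    """Turn a version string such as '11.4' or '11.4.2' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def check_prerequisites(cuda_version, framework):
    """Return a list of human-readable problems (empty if everything looks fine)."""
    problems = []
    if sys.version_info < (3, 7):
        problems.append(f"Python 3.7+ required, found {sys.version.split()[0]}")
    required = MIN_CUDA[framework]
    # Tuple comparison orders (11, 4, 2) after (11, 4) and (12, 0) after both.
    if parse_version(cuda_version) < required:
        wanted = ".".join(map(str, required))
        problems.append(f"CUDA {wanted}+ expected for {framework}, found {cuda_version}")
    return problems
```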

Installation

To create the docker image used in both the Pytorch and JAX implementations:

cd build  
docker build -t efficient_vdvae_image .  

Note:

  • If you use JAX on Ampere-architecture GPUs, training on multiple GPUs may randomly hang (issue). In that case, we provide an alternative docker image with an older version of JAX that bypasses the issue until a solution is found.

All code executions should be done within a docker container. To start the docker container, we provide a utility script:

sh docker_run.sh  # Starts the container and attaches terminal
cd /workspace/Efficient-VDVAE  # Inside docker container

Setup datasets

All datasets can be automatically downloaded and pre-processed from the convenience script we provide:

cd data_scripts
sh download_and_preprocess.sh <dataset_name>

Notes:

  • <dataset_name> can be one of (imagenet32, imagenet64, celeba, celebahq, ffhq). MNIST and CIFAR-10 are downloaded automatically later, when training the model, and do not require any dataset setup.
  • For the celeba dataset, a manual download of img_align_celeba.zip and list_eval_partition.txt files is necessary. Both files should be placed under <project_path>/dataset_dumps/.
  • img_align_celeba.zip download link.
  • list_eval_partition.txt download link.
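The dataset rules above can be checked before launching a long download. The sketch below is hypothetical: it validates <dataset_name>, verifies the manual CelebA files under <project_path>/dataset_dumps/, and assumes data_scripts sits directly under the project root.

```python
import subprocess
from pathlib import Path

# Dataset names accepted by download_and_preprocess.sh, per the notes above.
SCRIPTED_DATASETS = ("imagenet32", "imagenet64", "celeba", "celebahq", "ffhq")
# Files that must be downloaded manually for CelebA.
CELEBA_MANUAL_FILES = ("img_align_celeba.zip", "list_eval_partition.txt")


def missing_celeba_files(project_path):
    """Names of the manual CelebA files not yet present in dataset_dumps/."""
    dumps = Path(project_path) / "dataset_dumps"
    return [name for name in CELEBA_MANUAL_FILES if not (dumps / name).is_file()]


def preprocess_command(dataset_name, project_path="."):
    """Validate the dataset name and return the command to run from data_scripts/."""
    if dataset_name not in SCRIPTED_DATASETS:
        # MNIST and CIFAR-10 are downloaded automatically at training time.
        raise ValueError(f"unsupported dataset {dataset_name!r}; expected one of {SCRIPTED_DATASETS}")
    if dataset_name == "celeba":
        missing = missing_celeba_files(project_path)
        if missing:
            raise FileNotFoundError(f"place {missing} under {Path(project_path) / 'dataset_dumps'}")
    return ["sh", "download_and_preprocess.sh", dataset_name]


def run_preprocessing(dataset_name, project_path="."):
    """Run the convenience script from inside data_scripts/ (assumed location)."""
    cmd = preprocess_command(dataset_name, project_path)
    subprocess.run(cmd, cwd=Path(project_path) / "data_scripts", check=True)
```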

Setting the hyper-parameters

In this repository, we use the hparams library (already included in the Dockerfile) for hyper-parameter management:

  • Specify all run parameters (number of GPUs, model parameters, etc) in one .cfg file
  • Hparams evaluates any expression used as a "value" in the .cfg file. A "value" can be any basic Python object (floats, strings, lists, etc.) or any basic Python expression (1/2, max(3, 7), etc.), as long as its evaluation neither requires importing a library nor relies on other values from the .cfg.
  • Hparams saves the configuration of previous runs for reproducibility, resuming training, etc.
  • All hparams are saved by name, and re-using the same name will recall the old run instead of making a new one.
  • The .cfg file is split into sections for readability, and all parameters in the file are accessible as class attributes in the codebase for convenience.
  • The HParams object keeps a global state throughout all the scripts in the code.
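To make these evaluation rules concrete, here is a toy loader that mimics the described behaviour with Python's standard configparser: each section becomes an attribute, and each value is evaluated as a Python expression with no imports and no access to other values. It is a hypothetical illustration, not the hparams library's actual implementation or API, and the restricted eval is not a security sandbox.

```python
import configparser
from types import SimpleNamespace

# Builtins a "basic python expression" may use; everything else is unavailable.
SAFE_BUILTINS = {"max": max, "min": min, "abs": abs, "round": round,
                 "int": int, "float": float, "len": len, "range": range,
                 "list": list, "tuple": tuple}


def load_cfg(text):
    """Toy loader: sections become attributes, values are evaluated Python expressions."""
    parser = configparser.ConfigParser(interpolation=None)
    parser.optionxform = str  # keep key case as written
    parser.read_string(text)
    sections = {}
    for section in parser.sections():
        values = {}
        for key, raw in parser.items(section):
            # Empty locals: a value cannot reference any other value of the .cfg.
            values[key] = eval(raw, {"__builtins__": SAFE_BUILTINS}, {})
        sections[section] = SimpleNamespace(**values)
    return SimpleNamespace(**sections)
```

With a made-up section such as [model] containing ratio = 1/2 and depth = max(3, 7), load_cfg(...).model.ratio evaluates to 0.5 and .model.depth to 7, while a value like y = x + 1 fails with NameError because it depends on another value.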

We highly recommend having a deeper look into how this library works by reading the hparams library documentation, the parameters description and figures 4 and 5 in the paper before trying to run Efficient-VDVAE.

We have heavily tested the robustness and stability of our approach, so changing the model/optimization hyper-parameters to reduce memory load should not introduce drastic instabilities that would make the model untrainable, as long as the changes don't negate the important stability points we describe in the paper.

Training the Efficient-VDVAE

To run Efficient-VDVAE in Torch:

cd efficient_vdvae_torch  
# Set the hyper-parameters in "hparams.cfg" file  
# Set "NUM_GPUS_PER_NODE" in "train.sh" file  
sh train.sh  

To run Efficient-VDVAE in JAX:

cd efficient_vdvae_jax  
# Set the hyper-parameters in "hparams.cfg" file  
python train.py  

If you want to run the model on fewer GPUs than are available on the hardware, for example 2 GPUs out of 8:

CUDA_VISIBLE_DEVICES=0,1 sh train.sh  # For torch  
CUDA_VISIBLE_DEVICES=0,1 python train.py  # For JAX  
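If you launch training from a Python driver script rather than the shell, the same GPU restriction can be applied by setting CUDA_VISIBLE_DEVICES in the child process environment. A hedged sketch; the helper names are hypothetical:

```python
import os
import subprocess


def visible_devices_env(gpu_ids, base_env=None):
    """Copy of the environment that exposes only the listed GPU indices."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in gpu_ids)
    return env


def launch(command, gpu_ids, cwd):
    """Run e.g. ["sh", "train.sh"] or ["python", "train.py"] on a subset of GPUs."""
    return subprocess.run(command, cwd=cwd, env=visible_devices_env(gpu_ids), check=True)
```

For example, launch(["sh", "train.sh"], [0, 1], cwd="efficient_vdvae_torch"). Inside the launched process the selected GPUs are renumbered from 0, so GPUs 2 and 3 of the machine would appear as devices 0 and 1.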

Models automatically create checkpoints during training. To resume a model from its last checkpoint, set its <run.name> in hparams.cfg file and re-run the same training commands.

Training commands save the hparams of the run defined in the .cfg file. To restart a pre-existing run from scratch (by re-using its name in hparams.cfg), we provide a convenience script for resetting saved runs:

cd efficient_vdvae_torch  # or cd efficient_vdvae_jax  
sh reset.sh <run.name>  # <run.name> is the first field in hparams.cfg  

Note:

  • To make things easier for new users, we provide example hparams.cfg files that can be used under the egs folder. Detailed description of the role of each parameter is also inside hparams.cfg.
  • Hparams in egs are only meant as guiding examples; they do not exactly reproduce the pre-trained checkpoints or the experiments in the paper.
  • While the example hparams following the ..._baseline.cfg naming convention are not exactly the hparams of the C2 models in the paper (pre-trained checkpoints), they define models that are easier to design, achieve the same performance, and can be treated as equivalents to C2 models.
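A small convenience sketch for starting from one of the example configs: it copies egs/<name>.cfg over hparams.cfg and keeps a backup of the previous file. It assumes the egs folder sits inside the implementation folder (efficient_vdvae_torch or efficient_vdvae_jax); adjust the path if your checkout differs. The function name is hypothetical.

```python
import shutil
from pathlib import Path


def use_example_config(example_name, implementation_dir="."):
    """Copy egs/<example_name>.cfg over hparams.cfg, backing up the old file first."""
    impl = Path(implementation_dir)
    source = impl / "egs" / f"{example_name}.cfg"  # assumed location of the examples
    target = impl / "hparams.cfg"
    if not source.is_file():
        raise FileNotFoundError(source)
    if target.exists():
        shutil.copy2(target, impl / "hparams.cfg.bak")  # keep the previous config
    shutil.copy2(source, target)
    return target
```

After copying, give the run a fresh <run.name>, since re-using an existing name recalls the old run.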

Monitoring the training process

While writing this codebase, we put extra emphasis on verbosity and logging. Aside from the printed logs on terminal (during training), you can monitor the training progress and keep track of useful metrics using Tensorboard:

# While outside efficient_vdvae_torch or efficient_vdvae_jax  
# Run outside the docker container
tensorboard --logdir . --port <port_id> --reload_multifile True  

In the browser, navigate to localhost:<port_id> to visualize all saved metrics.

If Tensorboard is not installed (outside the docker container):

pip install --upgrade tensorboard

Inference with the Efficient-VDVAE

Efficient-VDVAE supports multiple inference modes:

  • "reconstruction": Encodes then decodes the test set images and computes test NLL and SSIM.
  • "generation": Generates random images from the prior distribution. Randomness is controlled by the run.seed parameter.
  • "div_stats": Pre-computes the average KL divergence stats used to determine turned-off variates (refer to section 7 of the paper). Note: This mode needs to be run before "encoding" mode and before trying to do masked "reconstruction" (Refer to hparams.cfg for a detailed description).
  • "encoding": Extracts the latent distribution from the inference model, pruned to the quantile defined by synthesis.variates_masks_quantile parameter. This latent distribution is usable in downstream tasks.
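The interplay between "div_stats" and "encoding" can be pictured with the sketch below: given the average KL per latent variate, variates whose average KL falls at or below the quantile set by synthesis.variates_masks_quantile are treated as turned off and dropped. This is one plausible reading of the description, written with NumPy; the function names are hypothetical and the codebase's exact thresholding rule may differ (see section 7 of the paper and hparams.cfg).

```python
import numpy as np


def active_variate_mask(avg_kl, quantile):
    """Boolean mask of variates whose average KL lies strictly above the given quantile."""
    avg_kl = np.asarray(avg_kl, dtype=np.float64)
    threshold = np.quantile(avg_kl, quantile)  # linear interpolation between samples
    return avg_kl > threshold


def prune_latents(latents, mask):
    """Keep only the active variates along the last axis of a latent array."""
    return np.asarray(latents)[..., mask]
```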

To run the inference:

cd efficient_vdvae_torch  # or cd efficient_vdvae_jax  
# Set the inference mode in "logs-<run.name>/hparams-<run.name>.cfg"  
# Set the same <run.name> in "hparams.cfg"  
python synthesize.py  

Notes:

  • Training a model named <run.name> saves its configuration under logs-<run.name>/hparams-<run.name>.cfg for reproducibility and error reduction. Any changes you want to make at inference time must therefore be applied to this saved hparams file (logs-<run.name>/hparams-<run.name>.cfg) instead of the main hparams.cfg file.
  • The torch implementation currently doesn't support multi-GPU inference. The JAX implementation does.
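Because inference reads the saved copy of the configuration, a small helper that edits logs-<run.name>/hparams-<run.name>.cfg in place can save some manual work. The sketch below is hypothetical: it rewrites a single key = value line inside a given section while keeping comments intact (a configparser write-back would drop them). The key name synthesis_mode used in the tests is an assumption; check hparams.cfg for the real parameter name.

```python
import re
from pathlib import Path


def saved_hparams_path(run_name, root="."):
    """Path of the configuration saved at training time for <run.name>."""
    return Path(root) / f"logs-{run_name}" / f"hparams-{run_name}.cfg"


def set_saved_hparam(run_name, section, key, value_expr, root="."):
    """Replace `key = ...` inside `[section]` of the saved cfg, keeping comments intact.

    `value_expr` is written verbatim, so strings need their own quotes, e.g. "'generation'".
    """
    path = saved_hparams_path(run_name, root)
    lines = path.read_text().splitlines(keepends=True)
    key_pattern = re.compile(rf"^([ \t]*{re.escape(key)}[ \t]*=[ \t]*)")
    current_section = None
    for index, line in enumerate(lines):
        header = re.match(r"^\s*\[([^\]]+)\]", line)
        if header:
            current_section = header.group(1).strip()
            continue
        match = key_pattern.match(line)
        if current_section == section and match:
            newline = "\n" if line.endswith("\n") else ""
            lines[index] = f"{match.group(1)}{value_expr}{newline}"
            path.write_text("".join(lines))
            return path
    raise KeyError(f"{section}.{key} not found in {path}")
```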

Potential TODOs

  • Make data loaders Out-Of-Core (OOC) in Pytorch
  • Make data loaders Out-Of-Core (OOC) in JAX
  • Update pre-trained model checkpoints
  • Add Fréchet-Inception Distance (FID) and Inception Score (IS) as measures for sample quality performance.
  • Improve the format of the encoded dataset used in downstream tasks (output of encoding mode, if there is a need)
  • Write a decoding mode API (if needed).

Bibtex

If you happen to use this codebase, please cite our paper:

@article{hazami2022efficient,
  title={Efficient-VDVAE: Less is more},
  author={Hazami, Louay and Mama, Rayhane and Thurairatnam, Ragavan},
  journal={arXiv preprint arXiv:2203.13751},
  year={2022}
}
Comments
  • Compact models, configs, and checkpoints

    Hi!

    I've been looking further into the code to understand the hyperparameter choices for the different configurations discussed in the paper. Do I understand correctly that C1 corresponds to the configs ..._compact whereas C2 corresponds to the configs ..._baseline? I've been trying to understand the C1 configs with varying width in more detail but noticed that the _compact config files seem to specify models with fewer parameters than specified in Table 3 of the paper. For example, I've tried to compute the rough number of parameters for the imagenet32_baseline.cfg and get the 156M parameters for model C2 specified in the table. However, imagenet32_compact.cfg yields ~20M parameters in total, while Table 3 specifies ~52M parameters for C1.

    Therefore, I just wanted to ask how the configs you provide map to the models presented in the paper and whether you could provide the configs for C1 and/or a checkpoint file for one of the compact models?

    Thank you very much, I really appreciate it! Best wishes, Matthias

    question 
    opened by msbauer 4
  • Mismatch between config and logs (for Cifar10)?

    Thank you for providing the code! I was looking at your configs for more details about the architectures used and noticed a mismatch to the logs provided for Cifar10 (I haven't checked it for any of the other experiments). In the config you specify the number of steps as 800k but the tensorboard logs seem to have 1.1M steps. Could you update the config to match the logs or clarify what I'm misunderstanding? Thank you very much!

    bug question 
    opened by msbauer 3
  • comparison between jax vs torch

    Thanks for the great work!!!!

    I am curious about the performance comparison between Jax and torch implementation, specifically, the training speed and NLL. Do you use jax or torch for the results in the paper?

    question 
    opened by qsh-zh 2
  • Workaround for loading hparams without tf 1.x ?

    Colab no longer supports tensorflow 1.x

    Is there a workaround for loading the hparams cfg?

    I tried hparams pypi package but it didn't work.

    Also the following didn't work:

    # https://github.com/tensorflow/community/issues/148
    from tensorboard.plugins.hparams import api as hp
    hparams = hp.HParams('.', name="efficient_vdvae")
    
    
    
    opened by turian 1
  • Some questions about checkpoints

    When I download the Pytorch checkpoint, I get a zip file. After unzipping it, I finally get a directory like ---archive --data --data.pkl --version --logs-celebAHQ1024_baseline --hparams-celebahq1024_baseline.cfg

    But there is no checkpoint file in it.

    When I download the JAX checkpoint, I get a checkpoint file (it has no .pth extension), but torch.load(xxx) raises an error: _pickle.UnpicklingError: unpickling stack underflow

    opened by miaoYuanyuan 0
  • Simple L2 reconstruction loss?

    Thx for the excellent work! Currently, it seems that we are using some obscure DiscMixLogistic reconstruction loss. Is there any guide on using simple L2 reconstruction loss? Do I need to change the model architecture for that?

    question 
    opened by SilenceMonk 5
  • Unconditional generation without dataset

    I want to do unconditional generation from checkpoints and logs WITHOUT needing to download the original dataset. How?

    
    Traceback (most recent call last):
      File "synthesize.py", line 345, in <module>
        main()
      File "synthesize.py", line 329, in main
        data_loader = synth_data()
      File "synthesize.py", line 271, in synth_data
        return synth_generic_data()
      File "/content/Efficient-VDVAE/efficient_vdvae_torch/data/generic_data_loader.py", line 143, in synth_generic_data
        synth_images, synth_filenames = create_filenames_list(hparams.data.synthesis_data_path)
      File "/content/Efficient-VDVAE/efficient_vdvae_torch/data/generic_data_loader.py", line 68, in create_filenames_list
        filenames = sorted(os.listdir(path))
    FileNotFoundError: [Errno 2] No such file or directory: '../datasets/celebAHQ/val_data/'
    
    question 
    opened by turian 1
  • Incorrect Imagenet dataset

    I downloaded the Imagenet dataset linked in this repo, and I think the dataset for it (50000 test images with labels, box downsampling) doesn't match the official Imagenet 32x32/64x64 versions used for NLL benchmarks (https://github.com/openai/vdvae/blob/main/setup_imagenet.sh, 49999 test images with no labels, can download from https://academictorrents.com/details/96816a530ee002254d29bf7a61c0c158d3dedc3b). Difference in downsampling method used during pre-processing will make the NLL's not comparable.

    bug 
    opened by prafullasd 1
  • Any plan to compute FID&IS scores?

    Hi! It is a great pleasure to meet an incredible work! I appreciate your dedication to this community.

    I have checked the paper and github, but it seems that there is no sample generation performance, such as Frechet Inception Distance (FID) or Inception Score (IS). Is there any plan to compute FID or IS in near future?

    Thank you, Dongjun Kim.

    question 
    opened by Kim-Dongjun 5
  • Discussion starter

    This umbrella issue tracks our work's current state and discusses the priority of potential TODOs. It is also a good place to ask any questions about the work.

    Goals

    • [x] Enable very deep VAE based models to train faster and with less compute, while only applying the simplest modifications.
    • [ ] Provide all 22 pre-trained models in both Pytorch and JAX. (Will be added soon)

    Potential TODOs (based on need)

    • [x] Make data loaders Out-Of-Core (OOC) in Pytorch (For RAM efficiency on large datasets)
    • [x] Make data loaders Out-Of-Core (OOC) in JAX (For RAM efficiency on large datasets)
    • [ ] Add Fréchet-Inception Distance (FID) and Inception Score (IS) as measures for sample quality performance.
    • [ ] Improve the format of the encoded dataset used in downstream tasks (output of encoding mode, if there is a need)
    • [ ] Write a decoding mode API (if needed).

    Notes:

    • Any feedback or questions on code, documentation or paper are most welcome.
    • Any suggestions to improve this repository and any requests for additional useful features are also welcome.
    • There are no plans to implement Efficient-VDVAE in tensorflow (TF), as we faced graph scalability limits on TF since the models are very deep.
    • We have heavily tested the robustness and stability of our approach, so changing the model/optimization hyper-parameters for memory load reduction should not introduce any drastic instabilities as to make the model untrainable. That is of course as long as the changes don't negate the important stability points we describe in the paper.

    Thank you for considering our work. Please feel free to reach out! :)

    Open discussion 
    opened by Rayhane-mamah 1