Very Deep VAEs in JAX/Flax

Overview

Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images using JAX and Flax, ported from the official OpenAI PyTorch implementation.

I have tried to keep this implementation as close as possible to the original. I was able to re-use a large proportion of the code, including the data input pipeline, which still uses PyTorch. I recommend installing a CPU-only version of PyTorch for this.
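
For reference, a CPU-only build of the PyTorch version listed below can usually be installed with something like the following (the exact command depends on your platform; check the PyTorch site for alternatives):

pip install torch==1.7.1+cpu -f https://download.pytorch.org/whl/torch_stable.html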

Tested with JAX 0.2.10, Flax 0.3.0, PyTorch 1.7.1 and NumPy 1.19.2. I also ran training to convergence on cifar10 and reproduced the test ELBO value of 2.87 from the paper, using --conv_precision=highest (see below). If anyone would like trained cifar10 checkpoints, I'll be happy to upload them.

From the paper, some model samples and a visualization of how it generates them:

image

Setup

As well as JAX, Flax, NumPy and PyTorch, this implementation depends on Pillow and scikit-learn:

pip install pillow
pip install scikit-learn

Also, you'll have to download the data for whichever dataset you want to train on:

./setup_cifar10.sh
./setup_imagenet.sh imagenet32
./setup_imagenet.sh imagenet64
./setup_ffhq256.sh
./setup_ffhq1024.sh /path/to/images1024x1024  # requires first downloading the `images1024x1024` folder from https://github.com/NVlabs/ffhq-dataset yourself and running `pip install torchvision`

Training models

Hyperparameters all reside in hps.py.

python train.py --hps cifar10
python train.py --hps imagenet32
python train.py --hps imagenet64
python train.py --hps ffhq256
python train.py --hps ffhq1024
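
A base config can also be combined with individual flag overrides; for example, to train on cifar10 with a smaller per-device batch size and full-precision convolutions (both flags are described in the sections below):

python train.py --hps cifar10 --n_batch 8 --conv_precision=highest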

TODOs

  • Implement support for 5-bit images, which were used in the paper's FFHQ-256 experiments (a rough sketch of the preprocessing is shown below).
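
As a rough, hypothetical sketch (not code from this repository, and not necessarily the exact scheme used in the paper), 5-bit preprocessing of 8-bit inputs could look like:

import jax.numpy as jnp

def to_5bit(x_uint8):
    # Drop the 3 low bits so that only 32 grey levels remain,
    # then rescale back towards the original [0, 255] range.
    x = jnp.floor(x_uint8.astype(jnp.float32) / 8.0)  # values in [0, 31]
    return x * 8.0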

Known differences from the original

  • Instead of using the PyTorch default layer initializers we use the Flax defaults (see the sketch after this list).
  • Renamed rate/distortion to kl/loglikelihood.
  • In multihost configurations, checkpoints are saved to disk on all hosts.
  • Slight changes to the DMOL (discretized mixture of logistics) loss.
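
For illustration, here is a sketch of what overriding the Flax default initializer in a convolution might look like. This repository simply keeps the Flax defaults; the variance_scaling settings below only approximate PyTorch's default kernel initializer (Kaiming-uniform with a=sqrt(5)) and ignore the difference in bias initialization:

import jax
from flax import linen as nn

# Flax's default kernel init for Conv/Dense is lecun_normal(); a roughly
# PyTorch-like alternative can be passed explicitly via kernel_init.
pytorch_like_init = jax.nn.initializers.variance_scaling(
    scale=1.0 / 3.0, mode="fan_in", distribution="uniform")

conv = nn.Conv(features=384, kernel_size=(3, 3), kernel_init=pytorch_like_init)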

Things to watch out for

We tried to keep this implementation as close as possible to the author's original PyTorch implementation, and there are two potentially confusing things we chose to preserve. Firstly, the --n_batch command line argument specifies the per-device batch size, so the effective global batch size depends on the number of GPUs/TPUs and hosts; take this into account when comparing runs across configurations. Secondly, some of the default hyperparameter settings in hps.py do not match the settings used for the paper's experiments, which are listed on page 15 of the paper.
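
Concretely, the effective global batch size is the per-device batch size times the total number of devices across hosts; a quick way to check it on a given machine (not part of this repository) is:

import jax

n_batch = 16  # per-device batch size, as passed via --n_batch
print(n_batch * jax.local_device_count())  # batch size on this host
print(n_batch * jax.device_count())        # global batch size across all hosts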

In order to reproduce results from the paper on TPU, it may be necessary to set --conv_precision=highest, which simulates GPU-like float32 precision on the TPU. Note that this can result in slower runtime. In my experiments on cifar10 I've found that this setting has about a 1% effect on the final ELBO value and was necessary to reproduce the value 2.87 reported in the paper.
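
Under the hood this flag corresponds to the precision argument of JAX/Flax convolutions; as a minimal illustration of the two settings (not the repository's actual code), assuming Flax's linen API:

import jax
from flax import linen as nn

# On TPU the default convolution precision uses lower-precision (bfloat16)
# internals; Precision.HIGHEST forces full float32 accumulation, which is
# slower but matches GPU float32 results more closely.
conv_default = nn.Conv(features=384, kernel_size=(3, 3),
                       precision=jax.lax.Precision.DEFAULT)
conv_highest = nn.Conv(features=384, kernel_size=(3, 3),
                       precision=jax.lax.Precision.HIGHEST)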

Acknowledgements

This code is very closely based on Rewon Child's implementation, thanks to him for writing that. Thanks to Julius Kunze for tidying the code and fixing some bugs.


Comments
  • The training stops in the middle

    Hi,

    It seems I am still unable to run it. I tried both cifar10 and ffhq256; both stop in the middle, although at different stages. I am running on a machine with two 3090s, so I think at least cifar10 should be fine:

    CIFAR 10

    (jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$ python train.py --hps cifar10
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: adam_beta1, value: 0.90000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: adam_beta2, value: 0.90000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: axis_visualize, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: bottleneck_multiple, value: 0.25000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: conv_precision, value: default
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: custom_width_str, value:
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: data_root, value: ./
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: dataset, value: cifar10
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: dec_blocks, value: 1x1,4m1,4x2,8m4,8x5,16m8,16x10,32m16,32x21
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: desc, value: test
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: device_count, value: 2
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: ema_rate, value: 0.99990
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: enc_blocks, value: 32x11,32d2,16x6,16d2,8x6,8d2,4x3,4d4,1x3
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: epochs_per_eval, value: 10
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: grad_clip, value: 200.00000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: host_count, value: 1
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: host_id, value: 0
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: hps, value: cifar10
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: image_channels, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: image_size, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: iters_per_ckpt, value: 25000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: iters_per_images, value: 10000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: iters_per_print, value: 1000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: iters_per_save, value: 10000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: log_wandb, value: False
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: logdir, value: ./saved_models/test/log
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: lr, value: 0.00020
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: n_batch, value: 16
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: no_bias_above, value: 64
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_depths_visualize, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_epochs, value: 10000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_images_visualize, value: 8
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_mixtures, value: 10
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_temperatures_visualize, value: 3
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_variables_visualize, value: 6
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: restore_path, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: save_dir, value: ./saved_models/test
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed, value: 0
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed_eval, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed_init, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed_sample, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed_train, value: None
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: skip_threshold, value: 400.00000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: test_eval, value: False
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: warmup_iters, value: 100
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: wd, value: 0.01000
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: width, value: 384
    time: Sat Mar 13 14:21:09 2021, type: hparam, key: zdim, value: 16
    time: Sat Mar 13 14:21:09 2021, message: training model test on cifar10
    time: Sat Mar 13 14:21:47 2021, total_params: 39145700, readable: 39,145,700
    time: Sat Mar 13 14:30:08 2021, model: test, type: train_loss, lr: 0.00000, epoch: 0, step: 0, elbo: -48.51558, elbo_filtered: -48.51558, grad_norm: 207.70775, kl: 37.04228, kl_nans: 0, log_likelihood: -11.47330, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 498.63470
    time: Sat Mar 13 14:31:20 2021, message: printing samples to ./saved_models/test/samples-0.png
    time: Sat Mar 13 14:31:21 2021, model: test, type: train_loss, lr: 0.00000, epoch: 0, step: 1, elbo: -48.48836, elbo_filtered: -48.48836, grad_norm: 207.70775, kl: 36.87710, kl_nans: 0, log_likelihood: -11.61126, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.42617
    time: Sat Mar 13 14:31:55 2021, message: printing samples to ./saved_models/test/samples-1.png
    time: Sat Mar 13 14:31:59 2021, model: test, type: train_loss, lr: 0.00002, epoch: 0, step: 8, elbo: -42.48382, elbo_filtered: -42.48382, grad_norm: 207.70775, kl: 30.93100, kl_nans: 0, log_likelihood: -11.55281, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.77924
    time: Sat Mar 13 14:32:33 2021, message: printing samples to ./saved_models/test/samples-8.png
    time: Sat Mar 13 14:32:37 2021, model: test, type: train_loss, lr: 0.00003, epoch: 0, step: 16, elbo: -34.54240, elbo_filtered: -34.54240, grad_norm: 207.70775, kl: 23.11177, kl_nans: 0, log_likelihood: -11.43063, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.78281
    time: Sat Mar 13 14:33:12 2021, message: printing samples to ./saved_models/test/samples-16.png
    time: Sat Mar 13 14:33:19 2021, model: test, type: train_loss, lr: 0.00006, epoch: 0, step: 32, elbo: -25.47084, elbo_filtered: -25.47084, grad_norm: 207.70775, kl: 14.20382, kl_nans: 0, log_likelihood: -11.26701, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.40239
    time: Sat Mar 13 14:33:54 2021, message: printing samples to ./saved_models/test/samples-32.png
    time: Sat Mar 13 14:34:09 2021, model: test, type: train_loss, lr: 0.00013, epoch: 0, step: 64, elbo: -18.77669, elbo_filtered: -18.77669, grad_norm: 207.70775, kl: 7.75593, kl_nans: 0, log_likelihood: -11.02076, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.40384
    time: Sat Mar 13 14:34:44 2021, message: printing samples to ./saved_models/test/samples-64.png
    time: Sat Mar 13 14:35:13 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 128, elbo: -14.41902, elbo_filtered: -14.41902, grad_norm: 207.70775, kl: 3.98318, kl_nans: 0, log_likelihood: -10.43585, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.39954
    time: Sat Mar 13 14:35:48 2021, message: printing samples to ./saved_models/test/samples-128.png
    time: Sat Mar 13 14:36:47 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 256, elbo: -11.45558, elbo_filtered: -11.45558, grad_norm: 207.70775, kl: 2.06421, kl_nans: 0, log_likelihood: -9.39137, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.40935
    time: Sat Mar 13 14:37:23 2021, message: printing samples to ./saved_models/test/samples-256.png
    time: Sat Mar 13 14:39:21 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 512, elbo: -9.25702, elbo_filtered: -9.25702, grad_norm: 207.70775, kl: 1.13132, kl_nans: 0, log_likelihood: -8.12570, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.41006
    time: Sat Mar 13 14:39:57 2021, message: printing samples to ./saved_models/test/samples-512.png
    time: Sat Mar 13 14:43:43 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 1000, elbo: -7.79192, elbo_filtered: -7.79192, grad_norm: 179.25816, kl: 0.68122, kl_nans: 0, log_likelihood: -7.11070, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.45444
    time: Sat Mar 13 14:43:55 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 1024, elbo: -7.26243, elbo_filtered: -7.26243, grad_norm: 179.25816, kl: 0.28492, kl_nans: 0, log_likelihood: -6.97751, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.45458
    time: Sat Mar 13 14:44:30 2021, message: printing samples to ./saved_models/test/samples-1024.png
    2021-03-13 14:48:14.080987: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 1 failed: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage
    2021-03-13 14:48:14.081137: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 0 failed: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage
    Traceback (most recent call last):
      File "train.py", line 213, in <module>
        main()
      File "train.py", line 208, in main
        train_loop(H, data_train, data_valid_or_test, preprocess_fn, optimizer,
      File "train.py", line 132, in train_loop
        optimizer, ema = p_synchronize((optimizer, ema))
    jax._src.traceback_util.FilteredStackTrace: RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
    
    The stack trace above excludes JAX-internal frames.
    The following is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "train.py", line 213, in <module>
        main()
      File "train.py", line 208, in main
        train_loop(H, data_train, data_valid_or_test, preprocess_fn, optimizer,
      File "train.py", line 132, in train_loop
        optimizer, ema = p_synchronize((optimizer, ema))
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/api.py", line 1582, in f_pmapped
        out = pxla.xla_pmap(
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1453, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1385, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1456, in process
        return trace.process_map(self, fun, tracers, params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 625, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 621, in xla_pmap_impl
        return compiled_fun(*args)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1168, in execute_replicated
        out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
    RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
    

    FFHQ256

    (jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$ python train.py --hps ffhq256
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: adam_beta1, value: 0.90000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: adam_beta2, value: 0.90000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: axis_visualize, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: bottleneck_multiple, value: 0.25000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: conv_precision, value: default
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: custom_width_str, value:
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: data_root, value: ./
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: dataset, value: ffhq_256
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: dec_blocks, value: 1x2,4m1,4x3,8m4,8x4,16m8,16x9,32m16,32x21,64m32,64x13,128m64,128x7,256m128
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: desc, value: test
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: device_count, value: 2
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: ema_rate, value: 0.99900
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: enc_blocks, value: 256x3,256d2,128x8,128d2,64x12,64d2,32x17,32d2,16x7,16d2,8x5,8d2,4x5,4d4,1x4
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: epochs_per_eval, value: 1
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: grad_clip, value: 130.00000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: host_count, value: 1
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: host_id, value: 0
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: hps, value: ffhq256
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: image_channels, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: image_size, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: iters_per_ckpt, value: 25000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: iters_per_images, value: 10000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: iters_per_print, value: 1000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: iters_per_save, value: 10000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: log_wandb, value: False
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: logdir, value: ./saved_models/test/log
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: lr, value: 0.00015
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: n_batch, value: 1
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: no_bias_above, value: 64
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_depths_visualize, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_epochs, value: 10000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_images_visualize, value: 2
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_mixtures, value: 10
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_temperatures_visualize, value: 1
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_variables_visualize, value: 3
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: restore_path, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: save_dir, value: ./saved_models/test
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed, value: 0
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed_eval, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed_init, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed_sample, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed_train, value: None
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: skip_threshold, value: 180.00000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: test_eval, value: False
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: warmup_iters, value: 100
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: wd, value: 0.01000
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: width, value: 512
    time: Sat Mar 13 13:23:59 2021, type: hparam, key: zdim, value: 16
    time: Sat Mar 13 13:23:59 2021, message: training model test on ffhq_256
    time: Sat Mar 13 13:25:23 2021, total_params: 114874852, readable: 114,874,852
    2021-03-13 13:48:17.142157: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
    ********************************
    Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
    Compiling module pmap_training_step.315621
    ********************************
    2021-03-13 13:49:08.604007: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 1 failed: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage
    2021-03-13 13:49:08.607849: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 0 failed: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage
    Traceback (most recent call last):
      File "train.py", line 213, in <module>
        main()
      File "train.py", line 208, in main
        train_loop(H, data_train, data_valid_or_test, preprocess_fn, optimizer,
      File "train.py", line 96, in train_loop
        optimizer, ema, training_stats = p_training_step(
    jax._src.traceback_util.FilteredStackTrace: RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
    
    The stack trace above excludes JAX-internal frames.
    The following is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "train.py", line 213, in <module>
        main()
      File "train.py", line 208, in main
        train_loop(H, data_train, data_valid_or_test, preprocess_fn, optimizer,
      File "train.py", line 96, in train_loop
        optimizer, ema, training_stats = p_training_step(
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/api.py", line 1582, in f_pmapped
        out = pxla.xla_pmap(
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1453, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1385, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1456, in process
        return trace.process_map(self, fun, tracers, params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 625, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 621, in xla_pmap_impl
        return compiled_fun(*args)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1168, in execute_replicated
        out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
    RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
    
    opened by kayhan-batmanghelich 5
  • issues with libdevice.10

    Hi,

    Thank you for sharing your code. It seems I have an issue that is probably caused by JAX, but I cannot figure it out. I set up the symlink correctly (or maybe not?) but I see this error. As you can see, XLA detects my GPU, but the code still looks for a library that is in fact there:

    [...]
       compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/_src/util.py", line 198, in wrapper
        return cached(bool(config.x64_enabled), *args, **kwargs)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/_src/util.py", line 191, in cached
        return f(*args, **kwargs)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/xla.py", line 291, in xla_primitive_callable
        compiled = backend_compile(backend, built_c, options)
      File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/xla.py", line 355, in backend_compile
        return backend.compile(built_c, compile_options=options)
    RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
    (jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$ ipython
    Python 3.8.8 (default, Feb 24 2021, 21:46:12)
    Type 'copyright', 'credits' or 'license' for more information
    IPython 7.21.0 -- An enhanced Interactive Python. Type '?' for help.
    
    In [1]: from jax.lib import xla_bridge
       ...: print(xla_bridge.get_backend().platform)
    gpu
    
    In [2]: exit
    (jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$ ls -lt /usr/local/cuda-11.1
    lrwxrwxrwx 1 root root 39 Mar 11 15:16 /usr/local/cuda-11.1 -> /usr/lib/nvidia-cuda-toolkit/libdevice/
    (jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$ ls -lt /usr/lib/nvidia-cuda-toolkit/libdevice/
    total 464
    -rw-r--r-- 1 root root 471124 Oct 16 13:42 libdevice.10.bc
    (jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$
    
    opened by kayhan-batmanghelich 2