Hi,
It seems I am still unable to run it. I tried both cifar10 and ffhq256; both runs stop partway through, although at different points. I am running on a machine with two 3090s, so I would expect that to be enough at least for cifar10.
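In case it helps narrow things down, here is a minimal repro sketch I wrote myself (it is not from the repo) that exercises the same pmap/psum all-reduce path the crash points at. If this already fails on this machine, the problem would seem to live in NCCL/jaxlib rather than in vdvae-jax:

# minimal_allreduce.py -- my own sketch, not part of vdvae-jax.
# Runs a bare all-reduce across all local GPUs via pmap + psum, which on GPU
# lowers to the same ncclAllReduce call that appears in the crashes below.
from functools import partial

import jax
import jax.numpy as jnp

print(jax.devices())  # expect two GPU devices on the dual-3090 box

@partial(jax.pmap, axis_name='i')
def allreduce(x):
    # psum across the mapped axis is the collective NCCL executes
    return jax.lax.psum(x, 'i')

n = jax.local_device_count()
out = allreduce(jnp.ones((n, 4)))
print(out)  # rows of 2.0 on two devices if the all-reduce succeeds

Full logs for both runs are below.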
CIFAR-10
(jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$ python train.py --hps cifar10
time: Sat Mar 13 14:21:09 2021, type: hparam, key: adam_beta1, value: 0.90000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: adam_beta2, value: 0.90000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: axis_visualize, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: bottleneck_multiple, value: 0.25000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: conv_precision, value: default
time: Sat Mar 13 14:21:09 2021, type: hparam, key: custom_width_str, value:
time: Sat Mar 13 14:21:09 2021, type: hparam, key: data_root, value: ./
time: Sat Mar 13 14:21:09 2021, type: hparam, key: dataset, value: cifar10
time: Sat Mar 13 14:21:09 2021, type: hparam, key: dec_blocks, value: 1x1,4m1,4x2,8m4,8x5,16m8,16x10,32m16,32x21
time: Sat Mar 13 14:21:09 2021, type: hparam, key: desc, value: test
time: Sat Mar 13 14:21:09 2021, type: hparam, key: device_count, value: 2
time: Sat Mar 13 14:21:09 2021, type: hparam, key: ema_rate, value: 0.99990
time: Sat Mar 13 14:21:09 2021, type: hparam, key: enc_blocks, value: 32x11,32d2,16x6,16d2,8x6,8d2,4x3,4d4,1x3
time: Sat Mar 13 14:21:09 2021, type: hparam, key: epochs_per_eval, value: 10
time: Sat Mar 13 14:21:09 2021, type: hparam, key: grad_clip, value: 200.00000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: host_count, value: 1
time: Sat Mar 13 14:21:09 2021, type: hparam, key: host_id, value: 0
time: Sat Mar 13 14:21:09 2021, type: hparam, key: hps, value: cifar10
time: Sat Mar 13 14:21:09 2021, type: hparam, key: image_channels, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: image_size, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: iters_per_ckpt, value: 25000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: iters_per_images, value: 10000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: iters_per_print, value: 1000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: iters_per_save, value: 10000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: log_wandb, value: False
time: Sat Mar 13 14:21:09 2021, type: hparam, key: logdir, value: ./saved_models/test/log
time: Sat Mar 13 14:21:09 2021, type: hparam, key: lr, value: 0.00020
time: Sat Mar 13 14:21:09 2021, type: hparam, key: n_batch, value: 16
time: Sat Mar 13 14:21:09 2021, type: hparam, key: no_bias_above, value: 64
time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_depths_visualize, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_epochs, value: 10000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_images_visualize, value: 8
time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_mixtures, value: 10
time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_temperatures_visualize, value: 3
time: Sat Mar 13 14:21:09 2021, type: hparam, key: num_variables_visualize, value: 6
time: Sat Mar 13 14:21:09 2021, type: hparam, key: restore_path, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: save_dir, value: ./saved_models/test
time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed, value: 0
time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed_eval, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed_init, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed_sample, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: seed_train, value: None
time: Sat Mar 13 14:21:09 2021, type: hparam, key: skip_threshold, value: 400.00000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: test_eval, value: False
time: Sat Mar 13 14:21:09 2021, type: hparam, key: warmup_iters, value: 100
time: Sat Mar 13 14:21:09 2021, type: hparam, key: wd, value: 0.01000
time: Sat Mar 13 14:21:09 2021, type: hparam, key: width, value: 384
time: Sat Mar 13 14:21:09 2021, type: hparam, key: zdim, value: 16
time: Sat Mar 13 14:21:09 2021, message: training model test on cifar10
time: Sat Mar 13 14:21:47 2021, total_params: 39145700, readable: 39,145,700
time: Sat Mar 13 14:30:08 2021, model: test, type: train_loss, lr: 0.00000, epoch: 0, step: 0, elbo: -48.51558, elbo_filtered: -48.51558, grad_norm: 207.70775, kl: 37.04228, kl_nans: 0, log_likelihood: -11.47330, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 498.63470
time: Sat Mar 13 14:31:20 2021, message: printing samples to ./saved_models/test/samples-0.png
time: Sat Mar 13 14:31:21 2021, model: test, type: train_loss, lr: 0.00000, epoch: 0, step: 1, elbo: -48.48836, elbo_filtered: -48.48836, grad_norm: 207.70775, kl: 36.87710, kl_nans: 0, log_likelihood: -11.61126, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.42617
time: Sat Mar 13 14:31:55 2021, message: printing samples to ./saved_models/test/samples-1.png
time: Sat Mar 13 14:31:59 2021, model: test, type: train_loss, lr: 0.00002, epoch: 0, step: 8, elbo: -42.48382, elbo_filtered: -42.48382, grad_norm: 207.70775, kl: 30.93100, kl_nans: 0, log_likelihood: -11.55281, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.77924
time: Sat Mar 13 14:32:33 2021, message: printing samples to ./saved_models/test/samples-8.png
time: Sat Mar 13 14:32:37 2021, model: test, type: train_loss, lr: 0.00003, epoch: 0, step: 16, elbo: -34.54240, elbo_filtered: -34.54240, grad_norm: 207.70775, kl: 23.11177, kl_nans: 0, log_likelihood: -11.43063, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.78281
time: Sat Mar 13 14:33:12 2021, message: printing samples to ./saved_models/test/samples-16.png
time: Sat Mar 13 14:33:19 2021, model: test, type: train_loss, lr: 0.00006, epoch: 0, step: 32, elbo: -25.47084, elbo_filtered: -25.47084, grad_norm: 207.70775, kl: 14.20382, kl_nans: 0, log_likelihood: -11.26701, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.40239
time: Sat Mar 13 14:33:54 2021, message: printing samples to ./saved_models/test/samples-32.png
time: Sat Mar 13 14:34:09 2021, model: test, type: train_loss, lr: 0.00013, epoch: 0, step: 64, elbo: -18.77669, elbo_filtered: -18.77669, grad_norm: 207.70775, kl: 7.75593, kl_nans: 0, log_likelihood: -11.02076, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.40384
time: Sat Mar 13 14:34:44 2021, message: printing samples to ./saved_models/test/samples-64.png
time: Sat Mar 13 14:35:13 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 128, elbo: -14.41902, elbo_filtered: -14.41902, grad_norm: 207.70775, kl: 3.98318, kl_nans: 0, log_likelihood: -10.43585, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.39954
time: Sat Mar 13 14:35:48 2021, message: printing samples to ./saved_models/test/samples-128.png
time: Sat Mar 13 14:36:47 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 256, elbo: -11.45558, elbo_filtered: -11.45558, grad_norm: 207.70775, kl: 2.06421, kl_nans: 0, log_likelihood: -9.39137, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.40935
time: Sat Mar 13 14:37:23 2021, message: printing samples to ./saved_models/test/samples-256.png
time: Sat Mar 13 14:39:21 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 512, elbo: -9.25702, elbo_filtered: -9.25702, grad_norm: 207.70775, kl: 1.13132, kl_nans: 0, log_likelihood: -8.12570, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.41006
time: Sat Mar 13 14:39:57 2021, message: printing samples to ./saved_models/test/samples-512.png
time: Sat Mar 13 14:43:43 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 1000, elbo: -7.79192, elbo_filtered: -7.79192, grad_norm: 179.25816, kl: 0.68122, kl_nans: 0, log_likelihood: -7.11070, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.45444
time: Sat Mar 13 14:43:55 2021, model: test, type: train_loss, lr: 0.00020, epoch: 0, step: 1024, elbo: -7.26243, elbo_filtered: -7.26243, grad_norm: 179.25816, kl: 0.28492, kl_nans: 0, log_likelihood: -6.97751, log_likelihood_nans: 0, skipped_updates: 0, iter_time: 0.45458
time: Sat Mar 13 14:44:30 2021, message: printing samples to ./saved_models/test/samples-1024.png
2021-03-13 14:48:14.080987: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 1 failed: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage
2021-03-13 14:48:14.081137: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 0 failed: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage
Traceback (most recent call last):
  File "train.py", line 213, in <module>
    main()
  File "train.py", line 208, in main
    train_loop(H, data_train, data_valid_or_test, preprocess_fn, optimizer,
  File "train.py", line 132, in train_loop
    optimizer, ema = p_synchronize((optimizer, ema))
jax._src.traceback_util.FilteredStackTrace: RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "train.py", line 213, in <module>
    main()
  File "train.py", line 208, in main
    train_loop(H, data_train, data_valid_or_test, preprocess_fn, optimizer,
  File "train.py", line 132, in train_loop
    optimizer, ema = p_synchronize((optimizer, ema))
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/api.py", line 1582, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1453, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1385, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1456, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 625, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 621, in xla_pmap_impl
    return compiled_fun(*args)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1168, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
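For CIFAR-10 the crash comes a few minutes into epoch 0, inside p_synchronize (train.py:132), so a bare all-reduce over the optimizer/EMA state is apparently enough to trigger it. If it helps, I can rerun with NCCL's standard debug logging enabled and attach the output, i.e.:

(jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$ NCCL_DEBUG=INFO python train.py --hps cifar10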
FFHQ-256
(jax) kayhan@lambda-dual:~/sandbox/vdvae-jax$ python train.py --hps ffhq256
time: Sat Mar 13 13:23:59 2021, type: hparam, key: adam_beta1, value: 0.90000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: adam_beta2, value: 0.90000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: axis_visualize, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: bottleneck_multiple, value: 0.25000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: conv_precision, value: default
time: Sat Mar 13 13:23:59 2021, type: hparam, key: custom_width_str, value:
time: Sat Mar 13 13:23:59 2021, type: hparam, key: data_root, value: ./
time: Sat Mar 13 13:23:59 2021, type: hparam, key: dataset, value: ffhq_256
time: Sat Mar 13 13:23:59 2021, type: hparam, key: dec_blocks, value: 1x2,4m1,4x3,8m4,8x4,16m8,16x9,32m16,32x21,64m32,64x13,128m64,128x7,256m128
time: Sat Mar 13 13:23:59 2021, type: hparam, key: desc, value: test
time: Sat Mar 13 13:23:59 2021, type: hparam, key: device_count, value: 2
time: Sat Mar 13 13:23:59 2021, type: hparam, key: ema_rate, value: 0.99900
time: Sat Mar 13 13:23:59 2021, type: hparam, key: enc_blocks, value: 256x3,256d2,128x8,128d2,64x12,64d2,32x17,32d2,16x7,16d2,8x5,8d2,4x5,4d4,1x4
time: Sat Mar 13 13:23:59 2021, type: hparam, key: epochs_per_eval, value: 1
time: Sat Mar 13 13:23:59 2021, type: hparam, key: grad_clip, value: 130.00000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: host_count, value: 1
time: Sat Mar 13 13:23:59 2021, type: hparam, key: host_id, value: 0
time: Sat Mar 13 13:23:59 2021, type: hparam, key: hps, value: ffhq256
time: Sat Mar 13 13:23:59 2021, type: hparam, key: image_channels, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: image_size, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: iters_per_ckpt, value: 25000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: iters_per_images, value: 10000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: iters_per_print, value: 1000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: iters_per_save, value: 10000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: log_wandb, value: False
time: Sat Mar 13 13:23:59 2021, type: hparam, key: logdir, value: ./saved_models/test/log
time: Sat Mar 13 13:23:59 2021, type: hparam, key: lr, value: 0.00015
time: Sat Mar 13 13:23:59 2021, type: hparam, key: n_batch, value: 1
time: Sat Mar 13 13:23:59 2021, type: hparam, key: no_bias_above, value: 64
time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_depths_visualize, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_epochs, value: 10000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_images_visualize, value: 2
time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_mixtures, value: 10
time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_temperatures_visualize, value: 1
time: Sat Mar 13 13:23:59 2021, type: hparam, key: num_variables_visualize, value: 3
time: Sat Mar 13 13:23:59 2021, type: hparam, key: restore_path, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: save_dir, value: ./saved_models/test
time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed, value: 0
time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed_eval, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed_init, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed_sample, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: seed_train, value: None
time: Sat Mar 13 13:23:59 2021, type: hparam, key: skip_threshold, value: 180.00000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: test_eval, value: False
time: Sat Mar 13 13:23:59 2021, type: hparam, key: warmup_iters, value: 100
time: Sat Mar 13 13:23:59 2021, type: hparam, key: wd, value: 0.01000
time: Sat Mar 13 13:23:59 2021, type: hparam, key: width, value: 512
time: Sat Mar 13 13:23:59 2021, type: hparam, key: zdim, value: 16
time: Sat Mar 13 13:23:59 2021, message: training model test on ffhq_256
time: Sat Mar 13 13:25:23 2021, total_params: 114874852, readable: 114,874,852
2021-03-13 13:48:17.142157: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module pmap_training_step.315621
********************************
2021-03-13 13:49:08.604007: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 1 failed: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage
2021-03-13 13:49:08.607849: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 0 failed: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage
Traceback (most recent call last):
  File "train.py", line 213, in <module>
    main()
  File "train.py", line 208, in main
    train_loop(H, data_train, data_valid_or_test, preprocess_fn, optimizer,
  File "train.py", line 96, in train_loop
    optimizer, ema, training_stats = p_training_step(
jax._src.traceback_util.FilteredStackTrace: RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "train.py", line 213, in <module>
    main()
  File "train.py", line 208, in main
    train_loop(H, data_train, data_valid_or_test, preprocess_fn, optimizer,
  File "train.py", line 96, in train_loop
    optimizer, ema, training_stats = p_training_step(
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/api.py", line 1582, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1453, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1385, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 1456, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py", line 625, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 621, in xla_pmap_impl
    return compiled_fun(*args)
  File "/home/kayhan/anaconda3/envs/jax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1168, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:162: NCCL operation ncclAllReduce(send_buffer, recv_buffer, buffer.element_count, datatype, reduce_op, comm, *cu_stream) failed: invalid usage: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
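Finally, in case versions matter, this small snippet (my own; it only touches public jax/jaxlib attributes) collects the version info that is usually relevant, and I am happy to post its output from this env:

# env_info.py -- report library versions and visible devices for this conda env
import jax
from jaxlib import version as jaxlib_version

print("jax:", jax.__version__)
print("jaxlib:", jaxlib_version.__version__)
print("devices:", jax.devices())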