TensorFlow implementation of "Improved Transformer for High-Resolution GANs" (NeurIPS 2021).

Overview

HiT-GAN Official TensorFlow Implementation

HiT-GAN is a Transformer-based generator trained within the Generative Adversarial Network (GAN) framework. It achieves state-of-the-art performance for high-resolution image synthesis. See our NeurIPS 2021 paper "Improved Transformer for High-Resolution GANs" for details.

This implementation is based on TensorFlow 2.x. We use tf.keras layers for building the model and use tf.data for our input pipeline. The model is trained using a custom training loop with tf.distribute on multiple TPUs/GPUs.

Environment setup

We recommend training the model with distributed training on TPUs and evaluating it on GPUs. The code is compatible with TensorFlow 2.x. See requirements.txt for all prerequisites; you can install them with the following command:

pip install -r requirements.txt

ImageNet

First, download ImageNet by following the tensorflow_datasets instructions in the official guide.

Train on ImageNet

To pretrain the model on ImageNet with Cloud TPUs, first check out the Google Cloud TPU tutorial for basic information on how to use Google Cloud TPUs.

Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for tensorflow_datasets, set the following environment variables:

TPU_NAME=<tpu-name>
STORAGE_BUCKET=gs://<storage-bucket>
DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>
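
As a sanity check before launching a long run, the variables above can be validated from Python. The sketch below is only illustrative: resolve_paths is a hypothetical helper that is not part of this repository, and the gs:// checks assume the Cloud TPU setup described above, where data and checkpoints live under the same storage bucket.

```python
import os


def resolve_paths(env=None):
    """Return (data_dir, model_dir) after checking the required variables.

    Hypothetical helper for illustration; it is not part of run.py.
    """
    env = os.environ if env is None else env
    required = ("TPU_NAME", "STORAGE_BUCKET", "DATA_DIR", "MODEL_DIR")
    missing = [name for name in required if not env.get(name)]
    if missing:
        raise ValueError("unset environment variables: " + ", ".join(missing))

    bucket = env["STORAGE_BUCKET"].rstrip("/")
    if not bucket.startswith("gs://"):
        raise ValueError("STORAGE_BUCKET should look like gs://<storage-bucket>")

    # Assumption: both directories are sub-paths of the bucket, as in the setup above.
    for name in ("DATA_DIR", "MODEL_DIR"):
        if not env[name].startswith(bucket + "/"):
            raise ValueError(name + " is expected to live under " + bucket)
    return env["DATA_DIR"], env["MODEL_DIR"]
```

Calling resolve_paths() with no argument reads the current process environment; passing a dictionary makes the check easy to try without exporting anything.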

The following command can be used to train a model on ImageNet (which reflects the default hyperparameters in our paper) on TPUv2 4x4:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME
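
To give a feel for the numbers in --image_crop_proportion and --image_crop_size, the sketch below computes a centered crop box. It assumes that the proportion is the fraction of each side kept by a central crop and that the crop is then resized to image_crop_size; this README does not confirm that interpretation, and the repository's actual preprocessing may differ (for example in how it handles aspect ratio). central_crop_box is a hypothetical name.

```python
def central_crop_box(height, width, crop_proportion):
    """Return (top, left, crop_height, crop_width) of a centered crop.

    Illustrative only: assumes the proportion applies to both sides.
    """
    if not 0.0 < crop_proportion <= 1.0:
        raise ValueError("crop_proportion must be in (0, 1]")
    crop_height = int(round(height * crop_proportion))
    crop_width = int(round(width * crop_proportion))
    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    return top, left, crop_height, crop_width


print(central_crop_box(256, 256, 0.875))  # prints (16, 16, 224, 224)
```

Under this assumption, a proportion of 1.0 (as used for CelebA-HQ and FFHQ below) keeps the whole image.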

To train the model on ImageNet with multiple GPUs, try the following command:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=False

Set train_batch_size according to the number of GPUs used for training. Note that storing Exponential Moving Average (EMA) models is currently not supported on GPUs (hence --use_ema_model=False), so training on GPUs leads to a slight performance drop.
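
For readers unfamiliar with EMA, the idea is to keep a slowly moving average of the generator weights alongside the weights being trained. A minimal plain-Python sketch of the update rule follows; the function name and the decay values are illustrative assumptions, not taken from this repository.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """One EMA step per weight: ema <- decay * ema + (1 - decay) * weight."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# Toy example with two scalar "weights" and an exaggerated decay of 0.5.
ema = [0.0, 1.0]
for step_weights in ([1.0, 1.0], [1.0, 1.0]):
    ema = ema_update(ema, step_weights, decay=0.5)
print(ema)  # prints [0.75, 1.0]
```

With a decay close to 1, the averaged copy changes slowly and smooths out step-to-step fluctuations of the trained weights.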

Evaluate on ImageNet

Run the following command to evaluate the model on GPUs:

python run.py --mode=eval --dataset=imagenet2012 \
  --eval_batch_size=128 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --latent_dim=256 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

This command evaluates the model on 8 P100 GPUs. Set eval_batch_size according to the number of GPUs used for evaluation. Note that train_steps and use_ema_model must match the values used for training.
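
As a worked example of matching the batch size to the device count, the sketch below splits a global batch evenly across replicas. It assumes that eval_batch_size (like train_batch_size) is a global batch size that is divided among the devices, which this README does not state explicitly; per_replica_batch_size is a hypothetical helper.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Split a global batch evenly across replicas (GPUs or TPU cores)."""
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size must be divisible by num_replicas")
    return global_batch_size // num_replicas


print(per_replica_batch_size(128, 8))  # prints 16
```

Under this assumption, the command above places 16 images on each of the 8 GPUs; with fewer GPUs, a smaller eval_batch_size keeps the per-device load the same.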

CelebA-HQ

First, download CelebA-HQ by following the tensorflow_datasets instructions in the official guide.

Train on CelebA-HQ

The following command can be used to train a model on CelebA-HQ (which reflects the default hyperparameters used for the resolution of 256 in our paper) on TPUv2 4x4:

python run.py --mode=train --dataset=celeb_a_hq/256 \
  --train_batch_size=256 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

Evaluate on CelebA-HQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=celeb_a_hq/256 \
  --eval_batch_size=128 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

FFHQ

First, download the FFHQ tfrecords from the official site and put them into $DATA_DIR.

Train on FFHQ

The following command can be used to train a model on FFHQ (which reflects the default hyperparameters used for the resolution of 256 in our paper) on TPUv2 4x4:

python run.py --mode=train --dataset=ffhq/256 \
  --train_batch_size=256 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

Evaluate on FFHQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=ffhq/256 \
  --eval_batch_size=128 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True
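
Because the evaluation flags have to mirror the training run, it can help to derive the evaluation command from the training flags instead of typing it twice. The sketch below does this with a hypothetical helper, eval_args_from_train; only the flag names and values come from the commands in this README, and the helper itself is an illustration rather than part of the codebase.

```python
# Flags that the evaluation commands in this README share with training.
SHARED_FLAGS = (
    "dataset", "train_steps", "image_crop_size", "image_crop_proportion",
    "latent_dim", "channel_multiplier", "data_dir", "model_dir",
)


def eval_args_from_train(train_flags, eval_batch_size=128, use_ema_model=True):
    """Build the argument list of an evaluation run from training flags."""
    args = ["python", "run.py", "--mode=eval"]
    args += ["--{}={}".format(name, train_flags[name]) for name in SHARED_FLAGS]
    args += [
        "--eval_batch_size={}".format(eval_batch_size),
        "--use_tpu=False",
        "--use_ema_model={}".format(use_ema_model),
    ]
    return args


# Values copied from the FFHQ training command above.
ffhq_train = {
    "dataset": "ffhq/256", "train_steps": 500000,
    "image_crop_size": 256, "image_crop_proportion": 1.0,
    "latent_dim": 512, "channel_multiplier": 2,
    "data_dir": "$DATA_DIR", "model_dir": "$MODEL_DIR",
}
print(" ".join(eval_args_from_train(ffhq_train)))
```

For a model trained on GPUs with --use_ema_model=False, pass use_ema_model=False so that the evaluation matches the training setting.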

Cite

@inproceedings{zhao2021improved,
  title = {Improved Transformer for High-Resolution {GANs}},
  author = {Long Zhao and Zizhao Zhang and Ting Chen and Dimitris Metaxas and Han Zhang},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year = {2021}
}

Disclaimer

This is not an officially supported Google product.

Comments
  • RuntimeError: `merge_call` called while defining a new graph or a tf.function.

    Thanks for your great work! I was trying to train a HiT-GAN model on a single A100, but I am not very familiar with the TensorFlow framework. After I ran the command below, I got the error shown in the log. Did I make a mistake in any of the steps? Could you please offer some advice?

    python run.py --mode=train --dataset=ffhq/256 \
      --train_batch_size=2 --train_steps=500000 \
      --image_crop_size=256 --image_crop_proportion=1.0 \
      --save_every_n_steps=1000 \
      --latent_dim=512 --generator_lr=0.00005 \
      --discriminator_lr=0.00005 --channel_multiplier=2 \
      --use_consistency_regularization=True \
      --data_dir ./ffhq_tfrecords --model_dir ./output \
      --use_tpu=False

    Log:

    WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/Reshape:0' shape=(2, 4, 16, 512) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:19.474379 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/Reshape:0' shape=(2, 4, 16, 512) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/Reshape_2:0' shape=(2, 16, 16, 512) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:20.280509 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/Reshape_2:0' shape=(2, 16, 16, 512) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. 
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/Reshape_4:0' shape=(2, 16, 64, 256) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:20.650825 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/Reshape_4:0' shape=(2, 16, 64, 256) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/Reshape_6:0' shape=(2, 64, 64, 128) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:21.114238 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/Reshape_6:0' shape=(2, 64, 64, 128) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. 
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/position_embedding_5/add:0' shape=(2, 128, 128, 64) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:21.439471 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/position_embedding_5/add:0' shape=(2, 128, 128, 64) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/position_embedding_6/add:0' shape=(2, 256, 256, 64) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:21.661504 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/hi_t_generator/position_embedding_6/add:0' shape=(2, 256, 256, 64) dtype=float32>, <tf.Tensor 'while/hi_t_generator/sequential_1/batch_normalization/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. 
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/Reshape_3:0' shape=(2, 4, 16, 512) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:27.076365 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/Reshape_3:0' shape=(2, 4, 16, 512) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/Reshape_5:0' shape=(2, 16, 16, 512) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:27.416526 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/Reshape_5:0' shape=(2, 16, 16, 512) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/Reshape_7:0' shape=(2, 16, 64, 256) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. 
W0228 20:50:27.912703 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/Reshape_7:0' shape=(2, 16, 64, 256) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/Reshape_9:0' shape=(2, 64, 64, 128) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:28.204938 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/Reshape_9:0' shape=(2, 64, 64, 128) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/position_embedding_12/add:0' shape=(2, 128, 128, 64) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:28.492619 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/position_embedding_12/add:0' shape=(2, 128, 128, 64) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. 
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/position_embedding_13/add:0' shape=(2, 256, 256, 64) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. W0228 20:50:28.661950 140426522629888 sequential.py:366] Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'while/position_embedding_13/add:0' shape=(2, 256, 256, 64) dtype=float32>, <tf.Tensor 'while/sequential_13/batch_normalization_3/FusedBatchNormV3:0' shape=(2, 1, 16, 512) dtype=float32>) Consider rewriting this model with the Functional API. INFO:tensorflow:Error reported to Coordinator: in user code:

    File "/home/platform/sq/pro2022/hit-gan-main/trainers/gan_trainer.py", line 222, in _train_one_step  *
        self._update_ema_model()
    File "/home/platform/sq/pro2022/hit-gan-main/trainers/gan_trainer.py", line 131, in _update_ema_model  *
        moving_averages.update_ema_variables(self.ema_generator.variables,
    File "/home/platform/sq/pro2022/hit-gan-main/utils/moving_averages.py", line 73, in update_ema_variables  *
        replica_context.merge_call(_update_all_in_cross_replica_context_fn,
    

    RuntimeError: in user code:

    File "/home/platform/sq/pro2022/hit-gan-main/trainers/gan_trainer.py", line 222, in _train_one_step  *
        self._update_ema_model()
    File "/home/platform/sq/pro2022/hit-gan-main/trainers/gan_trainer.py", line 131, in _update_ema_model  *
        moving_averages.update_ema_variables(self.ema_generator.variables,
    File "/home/platform/sq/pro2022/hit-gan-main/utils/moving_averages.py", line 73, in update_ema_variables  *
        replica_context.merge_call(_update_all_in_cross_replica_context_fn,
    
    RuntimeError: `merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid decorating overridden methods `test_step` and `train_step` in `tf.function`.
    

    Traceback (most recent call last):
      File "run.py", line 190, in <module>
        app.run(main)
      File "/home/platform/anaconda3/envs/GPEN/lib/python3.7/site-packages/absl/app.py", line 312, in run
        _run_main(main, args)
      File "/home/platform/anaconda3/envs/GPEN/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
        sys.exit(main(argv))
      File "run.py", line 179, in main
        trainer.train()
      File "/home/platform/sq/pro2022/hit-gan-main/trainers/base_trainer.py", line 210, in train
        train_multiple_steps(iterator)
      File "/home/platform/anaconda3/envs/GPEN/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
        raise e.with_traceback(filtered_tb) from None
      File "/home/platform/anaconda3/envs/GPEN/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 1129, in autograph_handler
        raise e.ag_error_metadata.to_exception(e)
    RuntimeError: in user code:

    File "/home/platform/sq/pro2022/hit-gan-main/trainers/base_trainer.py", line 201, in train_multiple_steps  *
        self.strategy.run(self._train_one_step, args=(next(iterator),))
    File "/home/platform/sq/pro2022/hit-gan-main/trainers/gan_trainer.py", line 222, in _train_one_step  *
        self._update_ema_model()
    File "/home/platform/sq/pro2022/hit-gan-main/trainers/gan_trainer.py", line 131, in _update_ema_model  *
        moving_averages.update_ema_variables(self.ema_generator.variables,
    File "/home/platform/sq/pro2022/hit-gan-main/utils/moving_averages.py", line 73, in update_ema_variables  *
        replica_context.merge_call(_update_all_in_cross_replica_context_fn,
    
    RuntimeError: `merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid decorating overridden methods `test_step` and `train_step` in `tf.function`.
    
    opened by sq2001 2
  • How do latent embeddings work?

    Your work is wonderful! However, I still feel uncertain about some parts.

    I see that you use the latent embedding, which is generated by embedding the sampled input code, as the key and value in the cross-attention. How does this work? In my view, self-attention or a compressed self-attention would help more, since the latent embedding only contains low-level information.

    It would be my pleasure to hear your opinion! Thanks~

    opened by ZahraFan 1