Cross-Modal Contrastive Learning for Text-to-Image Generation

Overview

This repository hosts the open-source JAX implementation of XMC-GAN.

Setup instructions

Environment

Create and activate a virtualenv (the required libraries are installed in the steps that follow):

virtualenv venv
source venv/bin/activate

Add the XMC-GAN library to PYTHONPATH:

export PYTHONPATH=$PYTHONPATH:/home/path/to/xmcgan/root/

JAX Installation

Note: Please follow the official JAX instructions for installing a GPU compatible version of JAX.
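Once JAX is installed, a quick check like the following can confirm that the GPU build is active and show how many devices JAX sees. This check is not part of the repository. The device count matters because the training code divides work by jax.local_device_count(), as quoted in one of the comments below.

```python
import jax


def describe_devices():
    """Return JAX's default backend name and the number of local devices."""
    backend = jax.default_backend()     # "gpu" when a CUDA-enabled build is active
    n_local = jax.local_device_count()  # devices visible to this process
    return backend, n_local


if __name__ == "__main__":
    backend, n_local = describe_devices()
    print(f"backend={backend} local_devices={n_local}")
    for device in jax.devices():
        print(device)
```

If this prints a CPU backend on a GPU machine, the CPU-only JAX wheel is probably installed.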

Other Dependencies

After installing JAX, install the remaining dependencies with:

pip install -r requirements.txt

Preprocess COCO-2014

To create the training and eval data, first create a directory to hold them. By default, the scripts expect the data to be in data/ under the base directory.

mkdir data/

The TFRecords required for training and validation on COCO-2014 can be created by running a preprocessing script over the TFDS coco_captions dataset:

python preprocess_data.py

This may take a while, because the script runs a pretrained BERT model over the captions and stores the embeddings. With a GPU, it takes about 2.5 hours for the training split and 1 hour for the validation split. Once it finishes, the train and validation TFRecord files are saved in the data/ directory. The train files require around 58G of disk space, and the validation files around 29G.
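The preprocessed files need roughly 87G in total (58G for train plus 29G for validation), so it may be worth checking free space before starting. Below is a minimal standard-library sketch. The 5 GiB safety margin is an arbitrary assumption, and the "G" figures above are treated as GiB, which is the more conservative reading.

```python
import shutil

# Sizes quoted in this README for the preprocessed COCO-2014 TFRecords,
# treated as GiB (the conservative reading of "58G" and "29G").
TRAIN_GIB = 58
VAL_GIB = 29


def required_gib(margin_gib: float = 0.0) -> float:
    """Space needed for the train + validation TFRecords plus an optional margin."""
    return TRAIN_GIB + VAL_GIB + margin_gib


def has_enough_space(path: str = ".", margin_gib: float = 5.0) -> bool:
    """True if the filesystem holding `path` (e.g. the base directory) has room."""
    free_gib = shutil.disk_usage(path).free / 2**30
    return free_gib >= required_gib(margin_gib)


if __name__ == "__main__":
    print(f"need ~{required_gib():.0f} GiB; enough space: {has_enough_space('.')}")
```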

Note: If you run into an error related to TensorFlow gfile, one workaround is to edit site-packages/bert/tokenization.py and change tf.gfile.GFile to tf.io.gfile.GFile. For more details, refer to the following link.
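If you prefer to script that workaround rather than edit the file by hand, a sketch along these lines performs the same string substitution. It assumes the target is the bert/tokenization.py file mentioned above and rewrites it in place, so keeping a backup is sensible. The helper name is illustrative.

```python
import pathlib

OLD = "tf.gfile.GFile"
NEW = "tf.io.gfile.GFile"


def patch_tokenization(path: str) -> int:
    """Rewrite TF1-style tf.gfile.GFile calls as tf.io.gfile.GFile in `path`.

    Returns the number of occurrences replaced. Re-running is harmless because
    the replacement text does not contain the original pattern.
    """
    file = pathlib.Path(path)
    text = file.read_text()
    count = text.count(OLD)
    if count:
        file.write_text(text.replace(OLD, NEW))
    return count

# Usage, assuming the bert package referenced above is installed:
#   import importlib.util
#   patch_tokenization(importlib.util.find_spec("bert.tokenization").origin)
```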

If you run into a tensorflow.python.framework.errors_impl.ResourceExhaustedError about having too many open files, you may have to increase the machine's open file limits. To do so, open the limit configuration file for editing:

vi /etc/security/limits.conf

and append the following lines to the end of the file:

*         hard    nofile      500000
*         soft    nofile      500000
root      hard    nofile      500000
root      soft    nofile      500000

You may have to adjust the limit values depending on your machine. You will need to log out and log back in for these values to take effect.
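To check whether the new limits are active after logging back in, Python's standard resource module (Unix only) reports the current soft and hard limits. The helper below can also raise the soft limit up to the existing hard limit for a single process. It cannot exceed the hard limit; raising that is what the limits.conf edit is for. Both function names are illustrative.

```python
import resource


def open_file_limits():
    """Return the (soft, hard) limits on open file descriptors for this process."""
    return resource.getrlimit(resource.RLIMIT_NOFILE)


def raise_soft_limit():
    """Raise this process's soft limit to its hard limit and return the new limits."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    # Skip when the hard limit is unlimited; an unlimited soft limit may be rejected.
    if hard != resource.RLIM_INFINITY and soft < hard:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
        except (ValueError, OSError):
            pass  # some platforms cap the soft limit below the reported hard limit
    return resource.getrlimit(resource.RLIMIT_NOFILE)


if __name__ == "__main__":
    print("soft/hard before:", open_file_limits())
    print("soft/hard after: ", raise_soft_limit())
```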

Download Pretrained ResNet

To train XMC-GAN, we need a network pretrained on ImageNet to extract image features. We use a ResNet-50 for this. To download its weights, run:

gsutil cp gs://gresearch/xmcgan/resnet_pretrained.npy data/

If you would like to pretrain your own network on ImageNet, please refer to the official Flax ImageNet example.
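To sanity-check the download, the file can be opened with NumPy. The layout of resnet_pretrained.npy is not described here, so this sketch makes an assumption: the file may hold a pickled dictionary of parameter arrays, which is why it passes allow_pickle=True. count_params is a hypothetical helper for inspecting the result. Only load pickled files from sources you trust.

```python
import numpy as np


def load_pretrained(path: str):
    """Load a .npy file holding either a plain array or a pickled object.

    ASSUMPTION: if the file stores a dict of parameters, np.save wrapped it in a
    0-d object array, which .item() unwraps.
    """
    obj = np.load(path, allow_pickle=True)
    if isinstance(obj, np.ndarray) and obj.dtype == object and obj.shape == ():
        return obj.item()
    return obj


def count_params(tree) -> int:
    """Count array elements in a (possibly nested) dict of arrays."""
    if isinstance(tree, dict):
        return sum(count_params(v) for v in tree.values())
    return int(np.size(tree))

# Usage: params = load_pretrained("data/resnet_pretrained.npy"); print(count_params(params))
```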

Training

To start a training run, first edit train.sh to specify an appropriate work directory. By default, the script assumes that 8 GPUs are available and runs training on the first 7, while test.sh assumes that evaluation runs on the last GPU. After configuring the training job, start an experiment by running the script with bash:

mkdir exp
bash train.sh exp_name &> train.txt
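The default device split described above can be expressed as a small helper that produces CUDA_VISIBLE_DEVICES values for train.sh and test.sh. The function and its parameters are illustrative and not part of the repository.

```python
def split_gpus(num_gpus: int = 8, num_eval: int = 1):
    """Return CUDA_VISIBLE_DEVICES strings for training and evaluation.

    The first num_gpus - num_eval GPUs train; the remaining ones evaluate.
    """
    if not 0 < num_eval < num_gpus:
        raise ValueError("need at least one training GPU and one eval GPU")
    ids = [str(i) for i in range(num_gpus)]
    train = ",".join(ids[: num_gpus - num_eval])
    evaluate = ",".join(ids[num_gpus - num_eval:])
    return train, evaluate


train_ids, eval_ids = split_gpus()
print(train_ids)  # 0,1,2,3,4,5,6
print(eval_ids)   # 7
```

On a 4-GPU machine, split_gpus(4) gives "0,1,2" for training and "3" for evaluation.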

Checkpoints and TensorBoard logs will be saved in /path/to/exp/exp_name. By default, the configs/coco_xmc.py config is used, which runs an experiment on 128px images. This configuration fits a batch size of 8 on each GPU and achieves an FID of around 10.5 to 11.0 with the EMA weights. To reproduce the full 256px results from our paper, the full model needs to be run on a 32-core Pod slice of Google Cloud TPU v3 devices.
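The EMA weights are an exponential moving average of the model parameters. The hyperparameter dump quoted in one of the comments below lists 'polyak_decay': 0.999. Assuming the standard Polyak-style update, one step can be sketched in JAX as follows; the repository's exact update may differ.

```python
import jax
import jax.numpy as jnp


def ema_update(ema_params, params, decay=0.999):
    """One EMA step over a parameter pytree: ema <- decay * ema + (1 - decay) * params."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )


ema = {"w": jnp.zeros(3)}
params = {"w": jnp.ones(3)}
ema = ema_update(ema, params)  # each entry of ema["w"] is now about 0.001
```

Because the decay is close to 1, the EMA weights change slowly and smooth out step-to-step noise in the raw parameters. This is why they are commonly used for evaluation.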

Evaluation

To run an evaluation job, update test.sh with the same settings used in train.sh. Then start the evaluation job by running:

bash test.sh exp_name &> eval.txt

All checkpoints in the work directory will be evaluated for FID and Inception Score. If you can spare the GPUs, you can also run train.sh and test.sh in parallel, so that new checkpoints are evaluated continuously as they are saved into the work directory. Scores are written to TensorBoard and to eval.txt.
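For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. This NumPy-only sketch computes the distance from given means and covariances. It does not extract the Inception features, which the evaluation job must also do.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    d = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr((sigma1 sigma2)^(1/2)).
    For PSD covariances, sigma1 @ sigma2 has real non-negative eigenvalues, so the
    trace of its square root equals the sum of their square roots.
    """
    mu1, mu2 = np.asarray(mu1, float), np.asarray(mu2, float)
    sigma1, sigma2 = np.asarray(sigma1, float), np.asarray(sigma2, float)
    diff = mu1 - mu2
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    # Drop tiny imaginary parts and negative values caused by floating-point error.
    tr_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)
```

Identical statistics give a distance of 0. Shifting one mean by a unit vector while keeping equal covariances gives exactly 1.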

TensorBoard

To start TensorBoard for monitoring training progress, run:

tensorboard --logdir /path/to/exp/exp_name

Citation

If you find this work useful, please consider citing:

@inproceedings{zhang2021cross,
  title={Cross-Modal Contrastive Learning for Text-to-Image Generation},
  author={Zhang, Han and Koh, Jing Yu and Baldridge, Jason and Lee, Honglak and Yang, Yinfei},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2021}
}

Disclaimer

Not an official Google product.

Comments
  • Is the number of training steps related to the batch size?

    Hi,

    Thanks for your open-source code!

    I find that the number of training steps per epoch does not depend on the batch size. In line 343:

    steps_per_epoch = num_train_examples // (jax.local_device_count() * config.d_step_per_g_step)

    Maybe it should be:

    steps_per_epoch = num_train_examples // (jax.local_device_count() * config.d_step_per_g_step * config.batch_size)

    opened by fortunechen 2
  • About the package libml

    Excuse me, I'm trying to run the code but have run into a problem: libml is not in the requirements.txt file. Also, when I installed it using pip, it reported the error "cannot import name 'input_pipeline' from 'libml'". What is the problem, and what should I do?

    opened by euyy 1
  • Question about the sentence embedding.

    Excuse me: in your code, the sentence embedding is calculated by averaging the word embeddings along the word_num dimension. However, in BERT, the encoding of the '[CLS]' token can represent the whole sentence. Why not use it as the sentence embedding? Which one performs better?

    opened by SUNJIMENG 1
  • The implementation details in the code.

    Excuse me: line 216 in xmc_net.py, "x = dense_fn(self.gf_dim * 16 * 4 * 4)(z)", feeds only the noise "z" into the linear layer. However, Table 7 in the supplementary material shows that the input is the concatenation of the noise and the reshaped condition. I wonder which one is right.

    opened by SUNJIMENG 1
  • How to calculate the IoU scores without ground truth bounding box input?

    I found that you used the official code to compute the SOA scores. But OP-GAN (https://github.com/ppjh8263/semantic-object-accuracy-for-generative-text-to-image-synthesis/tree/1d07bf250aedf9e1b0c55505eb76c49d60ce0055/SOA) describes it as follows: "In order to calculate the IoU scores you need to save the "ground truth" information, i.e. the bounding boxes you give your model as input, so we can compare them with the bounding boxes from the detection network." However, there is no bounding box input in your model (XMC-GAN). Could you share the details of your SOA score calculation? Thank you.

    opened by Toneyaya 0
  • Error when training on multi-GPU

    I got the following error message when training on multiple GPUs:

    Traceback (most recent call last):
      File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/u7801832/xmcenv/xmcgan_image_generation/xmcgan/main.py", line 70, in <module>
        app.run(main)
      File "/home/u7801832/xmcenv/lib/python3.8/site-packages/absl/app.py", line 303, in run
        _run_main(main, args)
      File "/home/u7801832/xmcenv/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
        sys.exit(main(argv))
      File "/home/u7801832/xmcenv/xmcgan_image_generation/xmcgan/main.py", line 62, in main
        train_utils.train(FLAGS.config, FLAGS.workdir)
      File "/home/u7801832/xmcenv/xmcgan_image_generation/xmcgan/train_utils.py", line 421, in train
        batch = jax.tree_map(np.asarray, next(train_iter))
      File "/home/u7801832/xmcenv/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 761, in __next__
        return self._next_internal()
      File "/home/u7801832/xmcenv/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 744, in _next_internal
        ret = gen_dataset_ops.iterator_get_next(
      File "/home/u7801832/xmcenv/lib/python3.8/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 2728, in iterator_get_next
        _ops.raise_from_not_ok_status(e, name)
      File "/home/u7801832/xmcenv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 6897, in raise_from_not_ok_status
        six.raise_from(core._status_to_exception(e.code, message), None)
      File "<string>", line 3, in raise_from
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [7,16,17,768] but got [1,112,17,768]. [Op:IteratorGetNext]

    The training script is as follows:

    #!/bin/bash
    CONFIG="xmcgan/configs/coco_xmc.py"
    EXP_NAME=$1
    WORKDIR="/work/u7801832/data2/" # CHANGEME

    CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6" python -m xmcgan.main \
      --config="$CONFIG" \
      --mode="train" \
      --workdir="$WORKDIR"

    Please help.

    Thanks

    opened by justimchung 0
  • Error while turning self.return_text switch ON

    Dear Team,

    I want to save the generated images together with their captions. But when I switch self.return_text to True, it fails at the line below during conversion, saying that the array of captions is not a valid JAX type. Please help.

    https://github.com/jayati-naik/xmcgan_image_generation/blame/b03e3b5cc4885f415bc1a4b8998260412f5ae803/xmcgan/utils/eval_metrics.py#L148

    opened by jayati-naik 0
  • Reproducing the results (256px) mentioned in the paper on GPUs

    Is it not possible to reproduce the results reported in the paper for 256px images using only GPUs? It seems to me that the model does not converge for 256px images on GPUs.

    For 128px, however, it works fine and reaches the FIDs mentioned in the repo.

    opened by Chumsy0725 1
  • error: train.sh: line 24: 45523 Segmentation fault      (core dumped)

    error:
    I1216 05:03:33.303731 140638140207296 utils.py:31] Checkpoint.restore_or_initialize() ...
    I1216 05:03:33.304307 140638140207296 checkpoint.py:301] No checkpoint specified. Restore the latest checkpoint.
    I1216 05:03:33.304460 140638140207296 utils.py:31] MultihostCheckpoint.get_latest_checkpoint_to_restore_from() ...
    I1216 05:03:33.312287 140638140207296 checkpoint.py:430] Checked checkpoint base_directories: ['path/to/exp/exp_name/checkpoints-0'] - common_numbers={1} - exclusive_numbers=set()
    I1216 05:03:33.312516 140638140207296 utils.py:41] MultihostCheckpoint.get_latest_checkpoint_to_restore_from() finished after 0.01s.
    I1216 05:03:33.312650 140638140207296 checkpoint.py:307] Restoring checkpoint: path/to/exp/exp_name/checkpoints-0/ckpt-1
    2021-12-16 05:03:33.316385: W ./tensorflow/core/framework/dataset.h:550] Failed precondition: StatelessRandomGetKeyCounter is stateful.
    I1216 05:03:45.659061 140638140207296 checkpoint.py:312] Restored save_counter=1 restored_checkpoint=path/to/exp/exp_name/checkpoints-0/ckpt-1
    I1216 05:03:45.659443 140638140207296 utils.py:41] Checkpoint.restore_or_initialize() finished after 12.36s.
    I1216 05:03:47.525738 140590360545024 logging_writer.py:56] Hyperparameters: {'architecture': 'xmc_net', 'batch_norm_group_size': -1, 'batch_size': 8, 'beta1': 0.5, 'beta2': 0.999, 'checkpoint_every_steps': 5000, 'coco_version': '2014', 'cond_size': 16, 'd_lr': 0.0004, 'd_spectral_norm': True, 'd_step_per_g_step': 14, 'data_dir': 'data/', 'dataset': 'mscoco', 'df_dim': 96, 'dtype': 'bfloat16', 'eval_avg_num': 3, 'eval_batch_size': 4, 'eval_every_steps': 1000, 'eval_num': 30000, 'g_lr': 0.0001, 'g_spectral_norm': False, 'gamma_for_g': 15, 'gf_dim': 96, 'image_contrastive': True, 'image_size': 128, 'log_loss_every_steps': 1000, 'model_name': 'xmc', 'num_epochs': 500, 'num_train_steps': -1, 'polyak_decay': 0.999, 'pretrained_image_contrastive': True, 'return_filename': False, 'return_text': False, 'seed': 42, 'sentence_contrastive': True, 'show_num': 64, 'shuffle_buffer_size': 1000, 'train_shuffle': True, 'trial': 0, 'word_contrastive': True, 'z_dim': 128}
    I1216 05:03:47.528530 140638140207296 train_utils.py:404] Starting training loop at step 1.
    /root/yes/envs/py39/lib/python3.9/site-packages/jax/_src/profiler.py:166: UserWarning: StepTraceContext has been renamed to StepTraceAnnotation. This alias will eventually be removed; please update your code.
      warnings.warn(
    Fatal Python error: Segmentation fault

    Thread 0x00007fddbdffb700 (most recent call first):
      File "/root/yes/envs/py39/lib/python3.9/concurrent/futures/thread.py", line 75 in _worker
      File "/root/yes/envs/py39/lib/python3.9/threading.py", line 910 in run
      File "/root/yes/envs/py39/lib/python3.9/threading.py", line 973 in _bootstrap_inner
      File "/root/yes/envs/py39/lib/python3.9/threading.py", line 930 in _bootstrap

    Thread 0x00007fddbe7fc700 (most recent call first):
      File "/root/yes/envs/py39/lib/python3.9/concurrent/futures/thread.py", line 75 in _worker
      File "/root/yes/envs/py39/lib/python3.9/threading.py", line 910 in run
      File "/root/yes/envs/py39/lib/python3.9/threading.py", line 973 in _bootstrap_inner
      File "/root/yes/envs/py39/lib/python3.9/threading.py", line 930 in _bootstrap

    Current thread 0x00007fe8de6390c0 (most recent call first):
      File "/root/yes/envs/py39/lib/python3.9/site-packages/numpy/core/fromnumeric.py", line 1955 in shape
      File "<array_function internals>", line 5 in shape
      File "/root/yes/envs/py39/lib/python3.9/site-packages/jax/_src/api.py", line 1307 in
      File "/root/yes/envs/py39/lib/python3.9/site-packages/jax/_src/api.py", line 1307 in _mapped_axis_size
      File "/root/yes/envs/py39/lib/python3.9/site-packages/jax/_src/api.py", line 1633 in f_pmapped
      File "/root/yes/envs/py39/lib/python3.9/site-packages/jax/_src/api.py", line 1725 in f_pmapped
      File "/root/yes/envs/py39/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162 in reraise_with_filtered_traceback
      File "/xmc_gan/xmcgan/train_utils.py", line 424 in train
      File "/xmc_gan/xmcgan/main.py", line 62 in main
      File "/root/yes/envs/py39/lib/python3.9/site-packages/absl/app.py", line 251 in _run_main
      File "/root/yes/envs/py39/lib/python3.9/site-packages/absl/app.py", line 303 in run
      File "/xmc_gan/xmcgan/main.py", line 70 in <module>
      File "/root/yes/envs/py39/lib/python3.9/runpy.py", line 87 in _run_code
      File "/root/yes/envs/py39/lib/python3.9/runpy.py", line 197 in _run_module_as_main
    train.sh: line 24: 45523 Segmentation fault (core dumped) CUDA_VISIBLE_DEVICES="0,1,2,3" python -m xmcgan.main --config="$CONFIG" --mode="train" --workdir="$WORKDIR"

    Details:
    config.batch_size = 8
    config.d_step_per_g_step = 14

    Have you ever come across this error?

    opened by zzy994491827 1
  • FailedPreconditionError

    How can I solve 'tensorflow.python.framework.errors_impl.FailedPreconditionError: StatelessRandomGetKeyCounter is stateful. [Op:SerializeIterator]'?

    Below is my checkpoint directory: [image]

    opened by hyeonjinXZ 2
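On the steps-per-epoch question in the comments above, the two formulas differ by a factor of batch_size. The sketch below evaluates both with hypothetical inputs. The figure 82,783 is the number of COCO-2014 training images, but whether num_train_examples counts images or captions is an assumption, and d_step_per_g_step = 2 is illustrative. Which formula is intended depends on whether each step consumes num_devices × d_step_per_g_step examples or that many batches; the arithmetic does not settle that.

```python
def steps_per_epoch(num_examples, num_devices, d_step_per_g_step, batch_size=None):
    """Current formula by default; the proposed one when batch_size is given."""
    denom = num_devices * d_step_per_g_step
    if batch_size is not None:
        denom *= batch_size
    return num_examples // denom


# Hypothetical inputs (see the note above about what they represent).
print(steps_per_epoch(82_783, 7, 2))                # 5913
print(steps_per_epoch(82_783, 7, 2, batch_size=8))  # 739
```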
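On the sentence-embedding question above, the two options can be sketched with NumPy. Masking out padded positions before averaging is an assumption about how the average is taken. cls_pool assumes the stored embeddings still include the [CLS] position at index 0. Which option works better for XMC-GAN is not established here.

```python
import numpy as np


def mean_pool(word_emb, mask):
    """Average word embeddings over valid positions.

    word_emb: [batch, word_num, dim]; mask: [batch, word_num], 1 for real tokens.
    """
    mask = mask[..., None].astype(word_emb.dtype)
    summed = (word_emb * mask).sum(axis=1)
    counts = np.maximum(mask.sum(axis=1), 1.0)  # avoid division by zero
    return summed / counts


def cls_pool(word_emb):
    """Take the embedding at position 0, where BERT places the [CLS] token."""
    return word_emb[:, 0, :]
```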
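On the generator-input question above, the Table 7 description (noise concatenated with the reshaped condition) would correspond to something like the following. The shapes are hypothetical, chosen to match 'z_dim': 128 and 'cond_size': 16 from the hyperparameter dump quoted above. This illustrates the alternative only; it does not show which one the released code intends.

```python
import numpy as np


def generator_input(z, cond):
    """Concatenate noise with a flattened condition (one reading of Table 7).

    z:    [batch, z_dim] noise.
    cond: [batch, ...] condition, flattened to [batch, -1] before concatenation.
    The result would then be fed to the first dense layer instead of z alone.
    """
    cond_flat = cond.reshape(cond.shape[0], -1)
    return np.concatenate([z, cond_flat], axis=-1)


z = np.random.normal(size=(2, 128))
cond = np.random.normal(size=(2, 16))
print(generator_input(z, cond).shape)  # (2, 144)
```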
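On the IoU question above, the quoted OP-GAN instructions compare input boxes with detector boxes using intersection over union. The IoU itself is a short computation, sketched here for axis-aligned boxes in (x_min, y_min, x_max, y_max) form. It does not answer how the SOA scores were computed for XMC-GAN, which takes no box input.

```python
def box_iou(a, b):
    """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)  # zero if boxes do not overlap
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
```

For example, two 2×2 boxes offset by one unit in each direction overlap in a 1×1 square, giving 1 / (4 + 4 - 1) = 1/7.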
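On the multi-GPU shape error above, both shapes hold the same 112 examples of shape [17, 768], since 7 × 16 = 112. [7, 16, ...] is that batch split across 7 devices, while [1, 112, ...] assigns all of it to one device. The generic reshape below reproduces both layouts; it is not the repository's input pipeline. A plausible but unconfirmed reading is that the pipeline was built for one device count and consumed with another, which printing jax.local_device_count() can help check.

```python
import numpy as np


def shard(batch, num_devices):
    """Split the leading (global batch) axis into [num_devices, per_device, ...]."""
    if batch.shape[0] % num_devices:
        raise ValueError(f"global batch {batch.shape[0]} not divisible by {num_devices}")
    return batch.reshape(num_devices, -1, *batch.shape[1:])


x = np.zeros((112, 17, 768), dtype=np.float32)
print(shard(x, 7).shape)  # (7, 16, 17, 768)
print(shard(x, 1).shape)  # (1, 112, 17, 768)
```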
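On the return_text error above: JAX arrays cannot hold strings, so captions have to stay on the host while numeric arrays go to JAX. A common workaround is to split off non-numeric fields before any JAX conversion. The sketch below assumes a batch is a dict of arrays; the key names are hypothetical.

```python
import numpy as np


def split_text_fields(batch):
    """Split a batch dict into numeric arrays and everything else (e.g. captions)."""
    numeric, other = {}, {}
    for key, value in batch.items():
        arr = np.asarray(value)
        if arr.dtype.kind in "biufc":  # bool, signed/unsigned int, float, complex
            numeric[key] = arr
        else:                          # str, bytes, or object arrays stay on the host
            other[key] = value
    return numeric, other


batch = {"image": np.zeros((2, 64, 64, 3), np.float32),
         "text": np.array([b"a cat", b"a dog"], dtype=object)}
numeric, text = split_text_fields(batch)
# `numeric` can be passed to JAX; `text` can be kept for saving captions with images.
```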
Owner
Google Research
PyTorch code for the paper "Complementarity is the King: Multi-modal and Multi-grained Hierarchical Semantic Enhancement Network for Cross-modal Retrieval".

Complementarity is the King: Multi-modal and Multi-grained Hierarchical Semantic Enhancement Network for Cross-modal Retrieval (M2HSE) PyTorch code fo

Xinlei-Pei 6 Dec 23, 2022
CVPR 2021 Official Pytorch Code for UC2: Universal Cross-lingual Cross-modal Vision-and-Language Pre-training

UC2 UC2: Universal Cross-lingual Cross-modal Vision-and-Language Pre-training Mingyang Zhou, Luowei Zhou, Shuohang Wang, Yu Cheng, Linjie Li, Zhou Yu,

Mingyang Zhou 28 Dec 30, 2022
Code for Referring Image Segmentation via Cross-Modal Progressive Comprehension, CVPR2020.

CMPC-Refseg Code of our CVPR 2020 paper Referring Image Segmentation via Cross-Modal Progressive Comprehension. Shaofei Huang*, Tianrui Hui*, Si Liu,

spyflying 55 Dec 1, 2022
Image-generation-baseline - MUGE Text To Image Generation Baseline

MUGE Text To Image Generation Baseline Requirements and Installation More detail

null 23 Oct 17, 2022
Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal, multi-exposure and multi-focus image fusion.

U2Fusion Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal (VIS-IR, medical), multi

Han Xu 129 Dec 11, 2022
Official Implement of CVPR 2021 paper “Cross-Modal Collaborative Representation Learning and a Large-Scale RGBT Benchmark for Crowd Counting”

RGBT Crowd Counting Lingbo Liu, Jiaqi Chen, Hefeng Wu, Guanbin Li, Chenglong Li, Liang Lin. "Cross-Modal Collaborative Representation Learning and a L

null 37 Dec 8, 2022
A 1.3B text-to-image generation model trained on 14 million image-text pairs

minDALL-E on Conceptual Captions minDALL-E, named after minGPT, is a 1.3B text-to-image generation model trained on 14 million image-text pairs for no

Kakao Brain 604 Dec 14, 2022
ISBI 2022: Cross-level Contrastive Learning and Consistency Constraint for Semi-supervised Medical Image.

Cross-level Contrastive Learning and Consistency Constraint for Semi-supervised Medical Image Introduction This repository contains the PyTorch implem

null 25 Nov 9, 2022
《Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis》(2021)

Image2Reverb Image2Reverb is an end-to-end neural network that generates plausible audio impulse responses from single images of acoustic environments

Nikhil Singh 48 Nov 27, 2022
Cross-modal Deep Face Normals with Deactivable Skip Connections

Cross-modal Deep Face Normals with Deactivable Skip Connections Victoria Fernández Abrevaya*, Adnane Boukhayma*, Philip H. S. Torr, Edmond Boyer (*Equ

null 72 Nov 27, 2022
Probabilistic Cross-Modal Embedding (PCME) CVPR 2021

Probabilistic Cross-Modal Embedding (PCME) CVPR 2021 Official Pytorch implementation of PCME | Paper Sanghyuk Chun1 Seong Joon Oh1 Rafael Sampaio de R

NAVER AI 87 Dec 21, 2022
Pytorch code for ICRA'21 paper: "Hierarchical Cross-Modal Agent for Robotics Vision-and-Language Navigation"

Hierarchical Cross-Modal Agent for Robotics Vision-and-Language Navigation This repository is the pytorch implementation of our paper: Hierarchical Cr

null 43 Nov 21, 2022
X-modaler is a versatile and high-performance codebase for cross-modal analytics.

X-modaler X-modaler is a versatile and high-performance codebase for cross-modal analytics. This codebase unifies comprehensive high-quality modules i

null 910 Dec 28, 2022
ROSITA: Enhancing Vision-and-Language Semantic Alignments via Cross- and Intra-modal Knowledge Integration

ROSITA News & Updates (24/08/2021) Release the demo to perform fine-grained semantic alignments using the pretrained ROSITA model. (15/08/2021) Releas

Vision and Language Group@ MIL 48 Dec 23, 2022
A Comprehensive Empirical Study of Vision-Language Pre-trained Model for Supervised Cross-Modal Retrieval

CLIP4CMR A Comprehensive Empirical Study of Vision-Language Pre-trained Model for Supervised Cross-Modal Retrieval The original data and pre-calculate

null 9 Jan 12, 2022
Saeed Lotfi 28 Dec 12, 2022
Cross Quality LFW: A database for Analyzing Cross-Resolution Image Face Recognition in Unconstrained Environments

Cross-Quality Labeled Faces in the Wild (XQLFW) Here, we release the database, evaluation protocol and code for the following paper: Cross Quality LFW

Martin Knoche 10 Dec 12, 2022
A pytorch-based deep learning framework for multi-modal 2D/3D medical image segmentation

A 3D multi-modal medical image segmentation library in PyTorch We strongly believe in open and reproducible deep learning research. Our goal is to imp

Adaloglou Nikolas 1.2k Dec 27, 2022
Official repository for Few-shot Image Generation via Cross-domain Correspondence (CVPR '21)

Few-shot Image Generation via Cross-domain Correspondence Utkarsh Ojha, Yijun Li, Jingwan Lu, Alexei A. Efros, Yong Jae Lee, Eli Shechtman, Richard Zh

Utkarsh Ojha 251 Dec 11, 2022