Code for TCL: Vision-Language Pre-Training with Triple Contrastive Learning (CVPR 2022)

Overview

Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022

News

(03/16/2022) Uploaded retrieval checkpoints fine-tuned on COCO and Flickr.


This is the official PyTorch implementation of TCL.


Requirements:

conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
pip install transformers==4.8.1
pip install timm==0.4.9
conda install ruamel_yaml
pip install opencv-python
pip install --upgrade Pillow
pip install einops

Pre-training Datasets:

Downstream-task Datasets:

JSON Files for Pre-training and Downstream Tasks:

  • Refer to the Download section in ALBEF.
  • You need to change the image paths in the JSON files to match the locations of your downloaded images (see the sketch below).
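
For convenience, here is a minimal sketch of one way to rewrite the image paths. The 'image' key name and the example prefixes are assumptions and may need to be adjusted to match the JSON files you downloaded.

import json
from pathlib import Path

def rewrite_image_paths(json_file, old_prefix, new_prefix):
    # Each entry is assumed to be a dict with an 'image' key holding the image path;
    # adjust the key name if your JSON files differ.
    entries = json.loads(Path(json_file).read_text())
    for entry in entries:
        image_path = entry.get("image", "")
        if image_path.startswith(old_prefix):
            entry["image"] = new_prefix + image_path[len(old_prefix):]
    Path(json_file).write_text(json.dumps(entries))

# Hypothetical usage:
# rewrite_image_paths("coco_train.json", "/export/share/coco/", "/data/coco/images/")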

Pre-trained checkpoint:

Pre-training:

python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Pretrain.py \
--config ./configs/Pretrain.yaml \
--output_dir output/pretrain

Downstream Tasks:

Image-Text Retrieval

# zero-shot coco 
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Retrieval.py \
--config ./configs/Retrieval_coco.yaml \
--output_dir output/pretrain_e30_Retrieval_coco_zeroshot \
--checkpoint output/pretrain/checkpoint_29.pth \
--evaluate

# fine-tune flickr
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/pretrain_e30_Retrieval_flickr \
--checkpoint output/pretrain/checkpoint_29.pth

# fine-tune coco
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Retrieval.py \
--config ./configs/Retrieval_coco.yaml \
--output_dir output/pretrain_e30_Retrieval_coco \
--checkpoint output/pretrain/checkpoint_29.pth

# zero-shot flickr 
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/pretrain_e30_Retrieval_flickr_zeroshot \
--checkpoint output/pretrain_e30_Retrieval_coco/checkpoint_best.pth \
--evaluate

VQA

python -m torch.distributed.launch --nproc_per_node=8 \
--use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir output/pretrain_e30_vqa \
--checkpoint output/pretrain/checkpoint_29.pth

Visual Entailment

python -m torch.distributed.launch --nproc_per_node=8 \
--use_env VE.py \
--config ./configs/VE.yaml \
--output_dir output/pretrain_e30_VE \
--checkpoint output/pretrain/checkpoint_29.pth

NLVR2

# pre-train nlvr
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env Pretrain_nlvr.py \
--config ./configs/NLVR_pretrain.yaml \
--output_dir output/pretrain_e30_NLVR_pretrain \
--checkpoint output/pretrain/checkpoint_29.pth

# fine-tune nlvr
python -m torch.distributed.launch --nproc_per_node=8 \
--use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir output/pretrain_e30_NLVR \
--checkpoint output/pretrain_e30_NLVR_pretrain/checkpoint_00.pth

Citation:

@inproceedings{yang2022vision,
  title={Vision-Language Pre-Training with Triple Contrastive Learning},
  author={Yang, Jinyu and Duan, Jiali and Tran, Son and Xu, Yi and Chanda, Sampath and Chen, Liqun and Zeng, Belinda and Chilimbi, Trishul and Huang, Junzhou},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}

Our code is largely borrowed from ALBEF.

Comments
  • Loss is NaN when pre-training on my own dataset

    Hi, thanks for your excellent work. When I pre-train on my own Chinese dataset (so I changed bert-base-uncased to bert-base-chinese), the loss becomes NaN after several iterations. I have tried decreasing the learning rate and adding gradient clipping, but the problem still exists. My training config is attached as a screenshot.

    Can you give me some suggestions? Thanks in advance.

    opened by liangzimei 13
  • No module named 'refTools'

    Hi, I am trying to reproduce the results from inside a Docker container. After installing the dependencies, I hit the following error:

      File "Pretrain.py", line 30, in <module>
        from dataset import create_dataset, create_sampler, create_loader
      File "/workspace/dataset/__init__.py", line 6, in <module>
        from dataset.caption_dataset import re_train_dataset, re_eval_dataset, pretrain_dataset
      File "/workspace/dataset/caption_dataset.py", line 12, in <module>
        from dataset.utils import pre_caption
      File "/workspace/dataset/utils.py", line 45, in <module>
        from refTools.evaluation.refEvaluation import RefEvaluation
    ModuleNotFoundError: No module named 'refTools'
    

    every time I run:

    python -m torch.distributed.launch --nproc_per_node=8 \
    --use_env Pretrain.py \
    --config ./configs/Pretrain.yaml \
    --output_dir output/pretrain

    I have tried pip3 install reftools but it does not solve the issue. Have you run into this issue before?

    opened by PeterDykas 6
  • Roughly how much time does the zero-shot retrieval evaluation take?

    Hi,

    Many thanks for the great work and for releasing the code. How much time does zero-shot retrieval inference take? On 2 V100s, it takes around 2 hours for me on MS-COCO. Is that normal, or is something potentially wrong with my setup?

    Many thanks in advance.

    opened by yash0307 5
  • About the loss_distill

    Hi, thank you for the excellent work and the release of the code!

    I am a little confused about how loss_distill is calculated in line 1429 of xbert.py, shown below:

                      loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=1)*soft_labels,dim=-1)
    

    I think the shapes of both prediction_scores and soft_labels would be (batch_size, seq_len, vocab_size), and F.softmax is applied over the last dimension for soft_labels in line 237 of model_pretrain.py, as shown below:

                      mlm_output = self.text_encoder(input_ids, 
                                                     attention_mask = text.attention_mask,
                                                     encoder_hidden_states = image_embeds,
                                                     encoder_attention_mask = image_atts,      
                                                     return_dict = True,
                                                     labels = labels,   
                                                     soft_labels = F.softmax(logits_m,dim=-1),
                                                     alpha = alpha
                                                    )
    

    Why is F.log_softmax applied over the second dimension (the seq_len dimension) for prediction_scores?
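
    For illustration only (this is not the repository's code), here is a minimal sketch of the shape question above; the tensor sizes are made up:

    import torch
    import torch.nn.functional as F

    # Hypothetical shapes matching the question: (batch_size, seq_len, vocab_size)
    prediction_scores = torch.randn(2, 5, 30522)
    soft_labels = F.softmax(torch.randn(2, 5, 30522), dim=-1)

    # dim=1 normalizes over seq_len; dim=-1 normalizes over the vocabulary
    log_probs_seq = F.log_softmax(prediction_scores, dim=1)   # as in the snippet above
    log_probs_voc = F.log_softmax(prediction_scores, dim=-1)  # normalized over vocab_size

    loss_seq = -torch.sum(log_probs_seq * soft_labels, dim=-1)
    loss_voc = -torch.sum(log_probs_voc * soft_labels, dim=-1)
    print(loss_seq.shape, loss_voc.shape)  # both torch.Size([2, 5]), but with different values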

    opened by sushizixin 3
  • Question about Data augmentation for MoCo.

    Dear authors, thank you for your great work; I really appreciate it.

    While reading the paper alongside the code, I had a question about data augmentation: I see that you apply data augmentation to the images. Do you do the same for the text modality? If so, could you point me to the relevant line in the code?

    Thanks.

    opened by celestialxevermore 2
  • About the XBert

    Hi, thanks for this wonderful work. I am confused about the cross-attention module. In the XBERT code, when layer_num >= 6 the text_encoder switches to cross-attention layers; each of these layers first applies self-attention to text_embeds and then cross-attention between text_embeds and image_embeds. Why apply self-attention to text_embeds before the cross-attention? Could it apply self-attention to image_embeds first and then do the cross-attention, or could it do only the cross-attention? Please help me with this question when you have time. Thank you again!
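
    For reference only, here is a minimal sketch (not the repository's xbert.py) of the layer ordering described above: self-attention over the text tokens followed by cross-attention from text queries to image keys/values. Dimensions and module names are assumptions.

    import torch
    import torch.nn as nn

    class TextFusionLayer(nn.Module):
        def __init__(self, dim=768, heads=12):
            super().__init__()
            # inputs are (seq_len, batch, dim), the default layout of nn.MultiheadAttention
            self.self_attn = nn.MultiheadAttention(dim, heads)
            self.cross_attn = nn.MultiheadAttention(dim, heads)
            self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            self.norm1, self.norm2, self.norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)

        def forward(self, text_embeds, image_embeds):
            # 1) self-attention over the text tokens
            x = self.norm1(text_embeds + self.self_attn(text_embeds, text_embeds, text_embeds)[0])
            # 2) cross-attention: text queries attend to image keys/values
            x = self.norm2(x + self.cross_attn(x, image_embeds, image_embeds)[0])
            # 3) position-wise feed-forward
            return self.norm3(x + self.ffn(x))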

    opened by zsmmsz99 2
  • JSON file problem

    Do the following JSON files belong to VQA v2?

    train_file: ['../data/json_down/vqa_train.json',
                 '../data/json_down/vqa_val.json',
                 '../data/json_down/vg.json']
    test_file: ['../data/json_down/vqa_test.json']
    answer_list: '../data/json_down/answer_list.json'

    vqa_root: '../data/VQA/Images/mscoco/'  # train2014/
    vg_root: '../data/VG/VG_100K/'          # image/

    opened by fengmingfeng 1
  • Question about the MLM masking

    Hi,

    10% of the time, we replace masked input tokens with random word

        indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    

    Here is the code you use to replace a token with a random word. Is it correct to use 0.5 as the parameter here? Thank you for your answer.
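
    For context, if this follows the standard BERT recipe (select 15% of tokens, then 80% -> [MASK], 10% -> random token, 10% -> unchanged), the 0.5 is applied only to the masked positions that were not already replaced with [MASK] (roughly 20% of them), so 0.5 of 20% is about 10% overall. A minimal illustrative sketch of that standard recipe (not the repository's exact code) is:

    import torch

    def bert_mask(input_ids, vocab_size, mask_token_id, mlm_probability=0.15):
        # Standard BERT-style masking: 15% of tokens are selected; of those,
        # 80% -> [MASK], 10% -> random token, 10% -> left unchanged.
        labels = input_ids.clone()
        masked_indices = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
        labels[~masked_indices] = -100  # compute the MLM loss only on masked positions

        # 80% of the masked positions become [MASK]
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = mask_token_id

        # half of the remaining ~20% (i.e. ~10% overall) become random tokens
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        input_ids[indices_random] = torch.randint(vocab_size, labels.shape, dtype=torch.long)[indices_random]

        # the remaining ~10% keep their original tokens
        return input_ids, labels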

    opened by longkukuhi 1
  • Cannot reproduce zero-shot retrieval performance

    Hi, I have downloaded the pre-trained checkpoint TCL_4m.pth you provided and prepared Flickr30k.

    I run the following command:

    python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --use_env Retrieval.py \
    --config ./configs/Retrieval_flickr.yaml \
    --output_dir output/pretrain_e30_Retrieval_flickr_zeroshot \
    --checkpoint ./data/TCL_4M.pth \
    --evaluate
    

    Here are the results I get:

    {"val_txt_r1": 87.96844181459566, "val_txt_r5": 98.12623274161736, "val_txt_r10": 99.40828402366864, "val_txt_r_mean": 95.16765285996057, "val_img_r1": 72.07100591715977, "val_img_r5": 90.55226824457594, "val_img_r10": 94.5759368836292, "val_img_r_mean": 85.73307034845497, "val_r_mean": 90.45036160420777, "test_txt_r1": 89.4, "test_txt_r5": 98.6, "test_txt_r10": 99.6, "test_txt_r_mean": 95.86666666666667, "test_img_r1": 73.36, "test_img_r5": 92.16, "test_img_r10": 95.52, "test_img_r_mean": 87.01333333333332, "test_r_mean": 91.44, "epoch": 0}
    

    According to Table 2 in your paper, zero-shot R@1 performance on the Flickr30K test set is 93.0 (text retrieval) and 79.6 (image retrieval), but what I get is test_txt_r1 = 89.4 and test_img_r1 = 73.36.

    Did I do something wrong?

    opened by yangbang18 1
  • How to obtain `T+`?

    Hi!

    Thanks for releasing the code! Sorry to bother you during this busy CVPR week, but here's one minor question:

    Sec. 3.2 says that two sets of textual inputs, T and T+, are fed to h(.) and h_hat(.), respectively. Could you point me to where exactly T+ is obtained?

    Thanks!

    opened by juliuswang0728 1
  • About GPU Usage and Training Time

    Hi, thanks for your great work and for sharing the code.

    According to configs/Pretrain.yaml, the batch_size is set to 64 (i.e., each GPU processes 64 image-text pairs) during pre-training. I would like to know how much GPU memory is used and how long one epoch takes (on the 4M dataset) under this setting.

    By the way, I have read your excellent paper but cannot find the supplementary material on the web. Could you share a link to download it? Thanks a lot.

    opened by yangbang18 1