Align before Fuse: Vision and Language Representation Learning with Momentum Distillation

Overview

Align before Fuse: Vision and Language Representation Learning with Momentum Distillation (Salesforce Research)

This is the official PyTorch implementation of the ALBEF paper [Blog]. This repository supports pre-training on custom datasets, as well as finetuning on VQA, SNLI-VE, NLVR2, Image-Text Retrieval on MSCOCO and Flickr30k, and visual grounding on RefCOCO+. Pre-trained and finetuned checkpoints are released.

Requirements:

  • pytorch 1.8.0
  • transformers 4.8.1
  • timm 0.4.9
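These can be installed with pip, for example:

pip install torch==1.8.0 transformers==4.8.1 timm==0.4.9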

Download:

Visualization:

We provide code in visualize.ipynb to visualize the important areas in an image for each word in a text. Here is an example visualization using the visual grounding checkpoint.
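The heatmaps are Grad-CAM style maps computed from the cross-attention of the multimodal encoder, weighted by the gradient of the image-text matching score. A minimal sketch of the idea (the function and tensor names below are illustrative, not the exact API of this repository):

import torch

def gradcam_from_cross_attention(attn_map, attn_grad):
    # attn_map:  [batch, heads, text_len, num_patches] cross-attention probabilities
    # attn_grad: [batch, heads, text_len, num_patches] gradient of the ITM score w.r.t. attn_map
    cam = attn_map * torch.clamp(attn_grad, min=0)  # keep positively contributing patches
    return cam.mean(dim=1)                          # average over heads -> [batch, text_len, num_patches]

Each row of the output can be reshaped to the patch grid and upsampled to the image size to highlight the regions most relevant to that word.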

Pre-training on custom datasets:

  1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}. An example is sketched after this list.
  2. In configs/Pretrain.yaml, set the paths for the json files.
  3. Pre-train the model using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain 
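For reference, a minimal sketch of producing such a json file (the file name and image paths are placeholders):

import json

annotations = [
    {'image': 'images/0001.jpg', 'caption': 'a dog running on the beach'},
    {'image': 'images/0002.jpg', 'caption': 'two people riding bicycles down a street'},
]

with open('my_pretrain_data.json', 'w') as f:
    json.dump(annotations, f)

The resulting file(s) are then listed under train_file in configs/Pretrain.yaml.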

Image-Text Retrieval:

  1. Download MSCOCO or Flickr30k datasets from the original websites.
  2. Download and extract the provided dataset json files.
  3. In configs/Retrieval_coco.yaml or configs/Retrieval_flickr.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/Retrieval_flickr \
--checkpoint [Pretrained checkpoint]
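At inference time, retrieval is performed in two stages as described in the paper: the image-text contrastive (ITC) similarity ranks all pairs cheaply, and the top-k candidates are re-ranked with the image-text matching (ITM) head. A minimal sketch, assuming pre-computed normalized features and a hypothetical itm_score function:

import torch

def rank_texts_for_image(image_feat, text_feats, itm_score, k=128):
    # image_feat: [dim] normalized image feature from the unimodal encoder
    # text_feats: [num_texts, dim] normalized text features
    # itm_score:  callable(text_index) -> matching score from the multimodal encoder (hypothetical)
    sims = text_feats @ image_feat                                 # ITC similarity to every text (cheap)
    topk = sims.topk(min(k, sims.numel())).indices                 # shortlist of candidates
    itm = torch.tensor([float(itm_score(int(i))) for i in topk])   # expensive ITM re-ranking on the shortlist only
    return topk[itm.argsort(descending=True)]                      # text indices, best match first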

VQA:

  1. Download VQA v2 dataset and Visual Genome dataset from the original websites.
  2. Download and extract the provided dataset json files.
  3. In configs/VQA.yaml, set the paths for the json files and the image paths.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir output/vqa \
--checkpoint [Pretrained checkpoint]
  5. Evaluate the result using the official evaluation server.
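The evaluation server expects a json list with one entry per question. A minimal sketch of writing the submission file (the file name and values are illustrative):

import json

# predictions: (question_id, answer) pairs produced by the finetuned model
predictions = [(262148000, 'yes'), (262148001, '2')]

result = [{'question_id': qid, 'answer': ans} for qid, ans in predictions]
with open('vqa_result.json', 'w') as f:
    json.dump(result, f)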

Visual Entailment:

  1. Download SNLI-VE dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/VE.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VE.py \
--config ./configs/VE.yaml \
--output_dir output/VE \
--checkpoint [Pretrained checkpoint]
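SNLI-VE is treated as a three-way classification problem (entailment / neutral / contradiction) over the fused [CLS] representation. A minimal sketch of such a classification head (names and sizes are illustrative):

import torch.nn as nn

class VEHead(nn.Module):
    # three-way classifier on top of the fused [CLS] embedding
    def __init__(self, hidden_size=768, num_classes=3):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, cls_embedding):  # cls_embedding: [batch, hidden_size]
        return self.classifier(cls_embedding)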

Visual Grounding on RefCOCO+:

  1. Download MSCOCO dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/Grounding.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
--config ./configs/Grounding.yaml \
--output_dir output/RefCOCO \
--gradcam_mode itm \ 
--block_num 8 \
--checkpoint [Pretrained checkpoint]
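Grounding is weakly supervised: a Grad-CAM heatmap is taken from the cross-attention of the ITM head (controlled by --gradcam_mode and --block_num) and used to score candidate regions. A minimal sketch of scoring region proposals with such a heatmap (an illustration, not the repository's exact procedure):

import torch

def score_proposals(heatmap, proposals):
    # heatmap:   [H, W] Grad-CAM map for the referring expression, resized to the image resolution
    # proposals: list of (x1, y1, x2, y2) candidate boxes in integer pixel coordinates
    scores = []
    for x1, y1, x2, y2 in proposals:
        region = heatmap[y1:y2, x1:x2]
        scores.append(region.mean() if region.numel() > 0 else heatmap.new_tensor(0.0))
    return torch.stack(scores)  # the box with the highest score is returned as the grounding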

NLVR2:

NLVR2 requires an additional pre-training step with text-assignment (TA) to adapt the model for image-pair inputs. In order to perform TA, first set the paths for the json training files in configs/NLVR_pretrain.yaml, then run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain_nlvr.py \
--config ./configs/NLVR_pretrain.yaml \
--output_dir output/NLVR_pretrain \
--checkpoint [Pretrained checkpoint]
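As described in the paper, text-assignment pairs each text with two images and trains a three-way classifier to decide whether the text matches the first image, the second image, or neither. A minimal sketch of one possible form of this objective (the fuse and ta_head callables are hypothetical; the repository's implementation may differ):

import torch
import torch.nn.functional as F

def text_assignment_loss(fuse, ta_head, text, image_a, image_b, labels):
    # fuse:    callable(text, image) -> fused [CLS] embedding of shape [batch, hidden]
    # ta_head: e.g. nn.Linear(2 * hidden, 3), predicting {first image, second image, neither}
    # labels:  [batch] long tensor with values in {0, 1, 2}
    cls_a = fuse(text, image_a)  # text fused with the first image
    cls_b = fuse(text, image_b)  # text fused with the second image
    logits = ta_head(torch.cat([cls_a, cls_b], dim=-1))
    return F.cross_entropy(logits, labels)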

We provide the checkpoint after TA pre-training, which can be fine-tuned with the following steps.

  1. Download NLVR2 dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/NLVR.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir output/NLVR \
--checkpoint [TA pretrained checkpoint]

Citation

If you find this code to be useful for your research, please consider citing:

@article{ALBEF,
      title={Align before Fuse: Vision and Language Representation Learning with Momentum Distillation}, 
      author={Junnan Li and Ramprasaath R. Selvaraju and Akhilesh Deepak Gotmare and Shafiq Joty and Caiming Xiong and Steven Hoi},
      year={2021},
      journal={arXiv preprint arXiv:2107.07651},
}
Comments
  • Problem of NLVR_pretrain.yaml file

    Problem 1: Hello, I found that "train_file" in the config lists the following json files: train_file: ['/export/home/project/VL/dataset/caption/coco_karpathy_train.json', '/export/home/project/VL/dataset/caption/vg_caption.json',
    '/export/home/project/VL/dataset/pretrain_caption/conceptual_caption_train.json', '/export/home/project/VL/dataset/pretrain_caption/conceptual_caption_val.json', '/export/home/project/VL/dataset/pretrain_caption/sbu_caption.json'
    ]. Could you please provide these JSON files?

    Problem 2: When fine-tuning on the NLVR2 task, do I need to first run pre-training with NLVR_pretrain.yaml and then fine-tune with NLVR.yaml?

    opened by haoshuai714 17
  • how to test the model?

    Hi, I just want to test the model. Here is my test command:

    python Retrieval.py --config ./configs/Retrieval_flickr.yaml --output_dir output/Retrieval_flickr --checkpoint model_file/ALBEF.pth --evaluate True

    I have changed the relevant configuration files.

    Am I right to test like this?

    thx!

    opened by CQUTWangHong 13
  • Pretrain phase problem

    I have a problem in the pre-training phase: partway through the run, the program shuts down with the following output:

    WARNING:torch.distributed.elastic.agent.server.api:Received 1 death signal, shutting down workers
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6593 closing signal SIGHUP
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6593 closing signal SIGTERM
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6594 closing signal SIGTERM
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6595 closing signal SIGTERM
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6596 closing signal SIGTERM
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6597 closing signal SIGTERM
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6598 closing signal SIGTERM
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6599 closing signal SIGTERM
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 6600 closing signal SIGTERM
    Traceback (most recent call last):
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/agent/server/api.py", line 709, in run result = self._invoke_run(role)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/agent/server/api.py", line 843, in _invoke_run time.sleep(monitor_interval)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 60, in _terminate_process_handler raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
    torch.distributed.elastic.multiprocessing.api.SignalException: Process 6523 got signal: 1

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main "main", mod_spec)
      File "/usr/lib/python3.6/runpy.py", line 85, in _run_code exec(code, run_globals)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/launch.py", line 193, in main()
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/launch.py", line 189, in main launch(args)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/launch.py", line 174, in launch run(args)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/run.py", line 713, in run )(*cmd_args)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/launcher/api.py", line 131, in call return launch_agent(self._config, self._entrypoint, list(args))
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/launcher/api.py", line 252, in launch_agent result = agent.run()
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper result = f(*args, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/agent/server/api.py", line 716, in run self._shutdown(e.sigval)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 190, in _shutdown self._pcontext.close(death_sig)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 330, in close self._close(death_sig=death_sig, timeout=timeout)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 709, in _close if handler.proc.poll() is None:
      File "/usr/lib/python3.6/subprocess.py", line 875, in poll return self._internal_poll()
      File "/usr/lib/python3.6/subprocess.py", line 1403, in _internal_poll pid, sts = _waitpid(self.pid, _WNOHANG)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 60, in _terminate_process_handler raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
    torch.distributed.elastic.multiprocessing.api.SignalException: Process 6523 got signal: 1

    Have you ever had a similar problem?

    opened by haoshuai714 11
  • Result worse than it in the paper

    I used the 4M-image dataset setting to pre-train the ALBEF model and fine-tuned it on the image-text retrieval task. On the Flickr dataset I get a TR R@1 of 82.14, while the result in your paper is 94.3. My other results are also worse than those reported in the paper. I did not change any settings in the code. Do you know where the problem might be? Are there any tricks in the training?

    opened by DandelionYoungL 10
  • Cannot load image from CC3M

    I get the following error: PIL.UnidentifiedImageError: cannot identify image file '/home/ubuntu/data/CC3M/DownloadConceptualCaptions/validation/10481_3355970027'

    The error is generated by this code in caption_dataset.py: image = Image.open(ann['image']).convert('RGB')

    BTW, I can only download 2.4M images from CC3M/training, how did you download 2.95M images? Thanks.

    opened by viyjy 8
  • question about ITM loss

    Hi,

    Thanks for the great work. After reading the code for calculating the ITM loss, I have the question below:

    The ITM labels for positive and negative samples are in a "fixed" order instead of being shuffled. I'm wondering whether this order could be an issue for the ITM loss to work correctly. In some other VLP models such as ViLT, the ITM loss is calculated on shuffled positive-negative batches, as detailed at https://github.com/dandelin/ViLT/blob/762fd3975c180db6fc88f577cf39549983fa373a/vilt/modules/objectives.py#L207

    Thanks in advance for your kind reply.

    opened by Qiulin-W 7
  • test in VQA dataset

    When I test on the VQA task, I run:

    python -m torch.distributed.launch --nproc_per_node=8 --use_env VQA.py \
    --config ./configs/VQA.yaml \
    --output_dir output/vqa \
    --checkpoint ./ALBEF.pth

    But after running one epoch, the result folder is empty. How do I get the result? Only after all training epochs end?

    opened by haoshuai714 7
  • pretraining datasets json files

    Hi @LiJunnan1992,

    Congrats on your great work, and thanks for releasing the code!! To help reproduce the pretraining experiments, could you release the dataset json files for the pretraining datasets as well? Thanks!

    Best, Jie

    opened by jayleicn 7
  • Image-Text Retrieval Task, ITC score for ranking

    I saw that the original setting uses the ITM score s_{itm} for ranking, but it requires more computation. Is it OK to use only the feature similarity score s_{itc} for ranking during inference?

    opened by yxoh 6
  • About some json files

    Hi, thanks for the excellent work. I would like to know how to generate the json files refcoco+_train.json, refcoco+_val.json, and refcoco+_test.json in data.tar.gz. How can I get the corresponding json files for the RefCOCO and RefCOCOg datasets?

    opened by TungWg 4
  • Questions about Visual Grounding checkpoint and visualization

    Thank you very much for your outstanding work! I'm having problems with the visual grounding task:

    1. How did you get refcoco.pth? When fine-tuning with the provided procedure, a 3.3G checkpoint_best.pth file is generated, as large as the pre-trained model; however, the refcoco.pth checkpoint you released is only 800M. Could you explain how you managed to shrink the model size? I tried distill: True and distill: False in the config file, which made no difference to the final size.

    python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
    --config ./configs/Grounding.yaml \
    --output_dir output/RefCOCO \
    --gradcam_mode itm \
    --block_num 8 \
    --checkpoint [Pretrained checkpoint, size 3.3G]

    2. How do I evaluate refcoco.pth? Setting distill: False in the config file does not work for me.

    python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
    --config ./configs/Grounding.yaml \
    --output_dir output/RefCOCO_albefpth \
    --gradcam_mode itm \
    --block_num 8 \
    --evaluate \
    --checkpoint refcoco.pth

    This fails with the following KeyError:

    Traceback (most recent call last):
      File "Grounding.py", line 295, in <module> main(args, config)
      File "Grounding.py", line 187, in main state_dict = checkpoint['model']
    KeyError: 'model'

    3. How do I visualize the 3.3G checkpoint_best.pth file generated by fine-tuning? During fine-tuning, the printed [val, test_A, test_B] metrics look fine. However, visualization.ipynb only works for refcoco.pth and not for the 3.3G checkpoint_best.pth generated by fine-tuning; the heat map is a total mess, not as expected. There seems to be a gap between checkpoint_best.pth and refcoco.pth.

    opened by zzzzzigzag 4
  • change english text_encoder to other language?

    Hello, thanks for the great work! I want to use ALBEF to train an image-text multimodal model for another language, and I am a little confused about the fine-tuning procedure.

    Here is my plan:

    1. Load the .pth file from this repo and iterate over its parameters.
    2. Load parameters from the bert-base-chinese BERT model into the ALBEF tensors whose names contain text_encoder.
    3. Freeze the ALBEF parameters whose tensor names contain visual_encoder.

    Code like this (args and config come from the training script):

    import torch
    from transformers import BertTokenizer
    from models.model_retrieval import ALBEF  # assumed import path; adjust to the model actually used

    tokenizer = BertTokenizer.from_pretrained(args.text_encoder)  # load the Chinese BERT tokenizer
    model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer)
    model_dict = model.state_dict()

    # load parameters from the released checkpoint, but leave out tensors whose name contains text_encoder
    temp = {}
    pretrained_dict = torch.load(args.checkpoint, map_location='cpu')['model']
    for k, v in pretrained_dict.items():
        if k.find("text_encoder") == -1 and model_dict[k].shape == v.shape:
            temp[k] = v

    # copy the kept parameters and try to freeze the visual_encoder
    # (note: this sets requires_grad on the loaded tensors, not on the model's parameters)
    temp_update = {}
    for k, v in model_dict.items():
        if k in temp.keys():
            if k.find("visual_encoder") != -1:
                temp[k].requires_grad = False
            temp_update[k] = temp[k]
        else:
            temp_update[k] = v
    model_dict.update(temp_update)
    model.load_state_dict(model_dict)

    Finally, I got a bad recall score on the Flickr-CN dataset. Could you give me some advice?

    opened by jammyWolf 0
  • NLVR2 Pretrain

    Hi, as we know, there is a text-assignment (TA) pre-training task for NLVR2, which is a single-pass, three-class task. I have read the code, but I don't understand why it can be written this way. Could you explain it?

    opened by lonestar234028 1
  • TypeError: add_code_sample_docstrings() got an unexpected keyword argument "tokenizer_class"

    Hello, when running this code I found a problem in "xbert.py": TypeError: add_code_sample_docstrings() got an unexpected keyword argument "tokenizer_class". Could you please tell me why?

    opened by dongxinfeng1 2
  • support other visual grounding datasets?

    Hey, you conduct the visual grounding experiment on RefCOCO+. Have you tried other datasets such as RefCOCO or RefCOCOg? If I want to do this, how can I get the data? In your release, only the json files for RefCOCO+ are provided. Were these json files generated by yourself, or downloaded from somewhere else? (I noticed that the data format of your ALBEF visual grounding is not the same as TransVG's.) Thank you very much. Looking forward to your reply.

    opened by PaulTHong 6
  • About VQA annotations

    Hello, thanks for your excellent work. I'm reproducing the results in the repo and found that the vqa_train annotation files differ from the original VQAv2 annotations. There are some answers in vqa_train that I can't find in either the VQAv2 or VQAv1 annotations. Is there any data augmentation, or am I missing something? An example: "what is written on the bus" with answers ['buddy holly', 'buddy holly and crickets']. Neither answer exists in the answer pools or in the annotation files.

    opened by simplelifetime 1