Composed Image Retrieval using Pretrained LANguage Transformers (CIRPLANT)

Overview

CIRPLANT

This repository contains the code and pre-trained models for Composed Image Retrieval using Pretrained LANguage Transformers (CIRPLANT).

For details, please see our ICCV 2021 paper, Image Retrieval on Real-life Images with Pre-trained Vision-and-Language Models.

Demo image from CIRR data

If you find this repository useful, we would appreciate it if you could give us a star.

You are currently viewing the code & model repository. For more information, see our Project homepage.

Introduction

CIRPLANT is a transformer-based model that leverages rich pre-trained vision-and-language (V&L) knowledge to modify visual features conditioned on natural language. To the best of our knowledge, this is the first attempt to repurpose a V&L pre-trained (VLP) model for composed image retrieval, a task that requires language-conditioned image feature modification.

Our intention is to extend current methods to the open domain. Together with the release of the CIRR dataset, we hope this work can inspire further research on composed image retrieval.

Installation & Dataset Preparation

Check INSTALL.md for installation instructions.

Training

To train the model and reproduce our published results on CIRR:

python trainval_oscar.py --dataset cirr --usefeat nlvr-resnet152_w_empty \
    --max_epochs 300 --model CIRPLANT-img --model_type 'bert' \
    --model_name_or_path data/Oscar_pretrained_models/base-vg-labels/ep_107_1192087 \
    --task_name cirr --gpus 1 --img_feature_dim 2054 --max_img_seq_length 1 \
    --do_lower_case --max_seq_length 40 --learning_rate 1e-05 --loss_type xe \
    --seed 88 --drop_out 0.3 --weight_decay 0.05 --warmup_steps 0 --loss st \
    --batch_size 32 --num_batches 529 --pin_memory --num_workers_per_gpu 0 \
    --comment input_your_comments --output saved_models/cirr_rc2_iccv_release_test \
    --log_by recall_inset_top1_correct_composition
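The flags --img_feature_dim 2054 and --max_img_seq_length 1 suggest that each image enters the model as a single 2054-dimensional visual token. A minimal sketch of how such a token could be assembled follows. It assumes the Oscar convention of a 2048-d visual feature followed by a 6-d normalized box-geometry vector, and it assumes the box covers the whole image. The helper name whole_image_token and the geometry layout are hypothetical and are not taken from this repository.

```python
from typing import List, Sequence

RESNET_DIM = 2048   # assumed size of the pooled ResNet-152 image feature
GEOMETRY_DIM = 6    # assumed Oscar-style box geometry: x1, y1, x2, y2, width, height

def whole_image_token(resnet_feature: Sequence[float]) -> List[float]:
    """Build one visual token by appending whole-image box geometry (hypothetical helper)."""
    if len(resnet_feature) != RESNET_DIM:
        raise ValueError(f"expected {RESNET_DIM} values, got {len(resnet_feature)}")
    # Normalized box spanning the full image: corners (0, 0) and (1, 1), width 1, height 1
    geometry = [0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
    return list(resnet_feature) + geometry

# A single token per image, matching --max_img_seq_length 1
token = whole_image_token([0.5] * RESNET_DIM)
print(len(token))  # 2054, the value passed to --img_feature_dim
```

Under these assumptions, 2048 + 6 = 2054 accounts for the feature dimension. Check the feature-loading code before relying on the exact geometry values.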

To use pre-trained weights to reproduce results in our ICCV 2021 paper, please see DOWNLOAD.md.

Developing

To develop based on our code, we highly recommend first getting familiar with PyTorch Lightning.

You can train models as described above; the results will be saved to the folder you specify with --output.

To inspect results, we recommend using TensorBoard to load the saved events.out.tfevents file. Alternatively, all information is also dumped to a text file, log.txt.

PyTorch Lightning automatically saves the latest checkpoint, last.ckpt, in the same output directory. Additionally, you can specify a validation score to monitor with --log_by [...], which enables saving the best checkpoint.
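The checkpointing behaviour described above keeps two things: the most recent checkpoint, which is overwritten after every validation round, and the best checkpoint according to the monitored metric. The standard-library sketch below shows only that bookkeeping. It is not PyTorch Lightning's implementation, where this is normally configured through the ModelCheckpoint callback's monitor, mode and save_last arguments. The class name is hypothetical, and the sketch assumes higher values of the --log_by metric are better, which holds for recall.

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class BestCheckpointTracker:
    """Simplified 'keep last + keep best by monitored metric' logic (illustrative only)."""
    monitor: str                  # e.g. the metric name passed to --log_by
    mode: str = "max"             # "max": higher is better (assumed for recall)
    best_score: Optional[float] = None
    best_epoch: Optional[int] = None
    last_epoch: Optional[int] = None

    def __post_init__(self) -> None:
        if self.mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")

    def update(self, epoch: int, metrics: Dict[str, float]) -> bool:
        """Record one validation round; return True if this epoch becomes the new best."""
        self.last_epoch = epoch  # the 'last' checkpoint is always replaced
        if self.monitor not in metrics:
            return False  # metric missing: only the last checkpoint is kept
        score = metrics[self.monitor]
        if self.best_score is None:
            improved = True
        elif self.mode == "max":
            improved = score > self.best_score
        else:
            improved = score < self.best_score
        if improved:
            self.best_score, self.best_epoch = score, epoch
        return improved

tracker = BestCheckpointTracker(monitor="recall_inset_top1_correct_composition")
for epoch, r1 in enumerate([0.30, 0.42, 0.39]):
    tracker.update(epoch, {"recall_inset_top1_correct_composition": r1})
print(tracker.best_epoch, tracker.last_epoch)  # 1 2
```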

Test-split Evaluation

We do not publish the ground truth for the test split of CIRR. Instead, we host an evaluation server for those who wish to publish results on the test split.

To generate the .json files for upload to the test server, load a trained checkpoint and enable --testonly.

As an example, compare the following arguments with the training arguments above.

python trainval_oscar.py --dataset cirr --usefeat nlvr-resnet152_w_empty \
    --max_epochs 300 --model CIRPLANT-img --model_type 'bert' \
    --model_name_or_path data/Oscar_pretrained_models/base-vg-labels/ep_107_1192087 \
    --task_name cirr --gpus 1 --img_feature_dim 2054 --max_img_seq_length 1 \
    --do_lower_case --max_seq_length 40 --learning_rate 1e-05 --loss_type xe \
    --seed 88 --drop_out 0.3 --weight_decay 0.05 --warmup_steps 0 --loss st \
    --batch_size 32 --num_batches 529 --pin_memory --num_workers_per_gpu 0 \
    --comment input_your_comments --output saved_models/cirr_rc2_iccv_release_test \
    --log_by recall_inset_top1_correct_composition \
    --check_val_every_n_epoch 1 --testonly --load_from_checkpoint $CKPT_PATH

Two .json files will be saved to the output directory: one for Recall evaluation and the other for Recall_Subset. Upload them to our test server to get the results.
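The server scores the uploaded rankings. The sketch below shows the general idea under stated assumptions: rank candidates by model score, compute a per-query Recall@K locally where ground truth exists (for example on the validation split), and write ranked ids per query into a submission dictionary. The file layout (a "version"/"metric" header plus one ranked list per pair id), the use of top 50 ids for recall and top 3 for recall_subset, and excluding the reference image from its own ranking are all assumptions. Consult the test server's instructions for the exact schema. All function names are hypothetical.

```python
import json
from typing import Dict, List, Sequence

def rank(scores: Dict[str, float], exclude: str) -> List[str]:
    """Candidate ids by descending score; the reference image is dropped (assumption)."""
    return sorted((c for c in scores if c != exclude), key=lambda c: scores[c], reverse=True)

def recall_at_k(ranked: Sequence[str], target: str, k: int) -> float:
    """Per-query Recall@K: 1.0 if the target is among the top-k candidates, else 0.0."""
    return 1.0 if target in ranked[:k] else 0.0

def build_submission(rankings: Dict[str, List[str]], metric: str, top_k: int) -> Dict[str, object]:
    """Hypothetical file layout: a version/metric header plus top-k ids per pair id."""
    submission: Dict[str, object] = {"version": "rc2", "metric": metric}
    for pair_id, ranked in rankings.items():
        submission[str(pair_id)] = list(ranked[:top_k])
    return submission

# Toy query with reference image "ref"; "b" is the true target
scores = {"ref": 0.99, "a": 0.8, "b": 0.6, "c": 0.7}
full = rank(scores, exclude="ref")                               # ['a', 'c', 'b']
subset = rank({k: scores[k] for k in ("ref", "a", "b")}, "ref")  # ['a', 'b']
print(recall_at_k(full, target="b", k=2))    # 0.0
print(recall_at_k(subset, target="b", k=2))  # 1.0
print(json.dumps(build_submission({"0": full}, "recall", top_k=50)))
# {"version": "rc2", "metric": "recall", "0": ["a", "c", "b"]}
```

Averaging recall_at_k over all queries gives the split-level Recall@K. A Recall_Subset file would be built the same way from the subset rankings, for example with metric "recall_subset" and top_k=3.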

Citation

Please consider citing this paper if you use the code:

@article{liu2021cirr,
      title={Image Retrieval on Real-life Images with Pre-trained Vision-and-Language Models}, 
      author={Zheyuan Liu and Cristian Rodriguez-Opazo and Damien Teney and Stephen Gould},
      journal={arXiv preprint arXiv:2108.04024},
      year={2021},
}
Comments
  • Which PyTorch Lightning version to use?

    Hello, thanks for the great work!

    I am trying to run your code, but I am running into issues with PyTorch Lightning (specifically, its API changes very quickly).

    With the latest version (1.6.5 at the moment), an error is thrown because val_dataloaders is not a valid argument for Trainer.validate:

    https://github.com/Cuberick-Orion/CIRPLANT/blob/4592c979eb8638ccd0d8590a68507df26c27cb89/trainval_oscar.py#L271

    However, if I revert to 1.3.1 (specified here), an error is thrown because it cannot import lightning from pytorch_lightning.core (see the full traceback below).

    After some trial and error, I found that 1.5.1 works. Perhaps this could be specified in the README?

    P.S.:

      File "/private/home/sgvaze/CIRPLANT/trainval_oscar.py", line 242, in <module>
        from _trainval_base import init_main
      File "/private/home/sgvaze/CIRPLANT/_trainval_base.py", line 34, in <module>
        from pytorch_lightning.core import lightning
      File "/private/home/sgvaze/miniconda3/envs/cirr/lib/python3.9/site-packages/pytorch_lightning/__init__.py", line 20, in <module>
        from pytorch_lightning import metrics  # noqa: E402
      File "/private/home/sgvaze/miniconda3/envs/cirr/lib/python3.9/site-packages/pytorch_lightning/metrics/__init__.py", line 15, in <module>
        from pytorch_lightning.metrics.classification import (  # noqa: F401
      File "/private/home/sgvaze/miniconda3/envs/cirr/lib/python3.9/site-packages/pytorch_lightning/metrics/classification/__init__.py", line 14, in <module>
        from pytorch_lightning.metrics.classification.accuracy import Accuracy  # noqa: F401
      File "/private/home/sgvaze/miniconda3/envs/cirr/lib/python3.9/site-packages/pytorch_lightning/metrics/classification/accuracy.py", line 18, in <module>
        from pytorch_lightning.metrics.utils import deprecated_metrics, void
      File "/private/home/sgvaze/miniconda3/envs/cirr/lib/python3.9/site-packages/pytorch_lightning/metrics/utils.py", line 22, in <module>
        from torchmetrics.utilities.data import get_num_classes as _get_num_classes
    ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (/private/home/sgvaze/miniconda3/envs/cirr/lib/python3.9/site-packages/torchmetrics/utilities/data.py)
    opened by sgvaze 2
  • I want to know about the loss function more definitely

    In the paper, the loss function is $\mathcal{L} = \log\left[1 + \exp\left(k(\varphi_i, \phi^-_{i,j}) - k(\varphi_i, \phi^+_i)\right)\right]$. For this loss to approach zero, $k(\varphi_i, \phi^-_{i,j})$ must be small and $k(\varphi_i, \phi^+_i)$ must be large.

    But I think $k(\varphi_i, \phi^+_i)$ should be small, because it is the L2 distance between the prediction and the target, and $k(\varphi_i, \phi^-_{i,j})$ should be large, because it is the distance between the prediction and a non-target image's feature. The loss function should therefore be $\mathcal{L} = \log\left[1 + \exp\left(k(\varphi_i, \phi^+_i) - k(\varphi_i, \phi^-_{i,j})\right)\right]$.

    Is my reasoning correct?

    opened by SeolMuah 1
  • Question about a difference in the value of the author's recall and my recall.

    My top-k recall values differ from the author's. When validating with the checkpoint posted by the author, I get the results shown in the following screenshot. They differ from the values reported by the author (https://github.com/Cuberick-Orion/CIRPLANT/blob/main/DOWNLOAD.md), and the loss value is 0.3. What is the reason? [Screenshot: "Capture"]

    opened by GaEunKim-study 7
  • I have three questions about the code

    I have three questions about the code in "model/OSCAR/OSCAR_CIRPLANT.py", in "def comput_soft_triplet_loss" at line 211:

    1. Why do you reshape the size to 160?
    2. In vvv=1, why do you remove -1 from vvv?
    3. The paper describes the metric learning as follows: the average loss is obtained over all sampled triplets. Where is the code for that part?
    opened by GaEunKim-study 0
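The loss quoted in the comments above is a soft triplet (softplus) loss, $\log[1 + \exp(x)]$ with $x = k(\varphi_i, \phi^-_{i,j}) - k(\varphi_i, \phi^+_i)$. Whether that sign order is appropriate depends on what $k$ measures. If $k$ is a similarity (larger means closer), this order pushes the loss towards zero when the target is closer than the negative. If $k$ were a plain distance, the order would need to be flipped. The minimal sketch below assumes $k$ is a similarity, here the negative squared L2 distance, and averages the loss over every (query, negative) triplet. It is an illustration under that assumption, not the repository's comput_soft_triplet_loss, and the function names are hypothetical.

```python
import math
from typing import Sequence

Vector = Sequence[float]

def similarity(a: Vector, b: Vector) -> float:
    """Assumed kernel k: negative squared L2 distance, so larger means more similar."""
    return -sum((x - y) ** 2 for x, y in zip(a, b))

def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def soft_triplet_loss(queries: Sequence[Vector],
                      positives: Sequence[Vector],
                      negatives: Sequence[Sequence[Vector]]) -> float:
    """Mean of log(1 + exp(k(q, neg) - k(q, pos))) over all sampled triplets.

    queries[i] is the composed feature phi_i, positives[i] its target phi_i^+,
    and negatives[i] a list of non-target features phi_{i,j}^-.
    """
    if not (len(queries) == len(positives) == len(negatives)):
        raise ValueError("queries, positives and negatives must have equal length")
    losses = []
    for q, pos, negs in zip(queries, positives, negatives):
        k_pos = similarity(q, pos)
        for neg in negs:
            losses.append(softplus(similarity(q, neg) - k_pos))
    if not losses:
        raise ValueError("no triplets to average over")
    return sum(losses) / len(losses)

# Target close to the query, negative far away: small loss
good = soft_triplet_loss([[0.0, 0.0]], [[0.1, 0.0]], [[[1.0, 0.0]]])
# Roles swapped: the negative is now closer, so the loss grows
bad = soft_triplet_loss([[0.0, 0.0]], [[1.0, 0.0]], [[[0.1, 0.0]]])
print(good < bad)  # True
```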
Owner
Zheyuan (David) Liu