Learning Better Visual Dialog Agents with Pretrained Visual-Linguistic Representation
Or READ-UP: Referring Expression Agent Dialog with Unified Pretraining.
This repo includes the training/testing code for our paper Learning Better Visual Dialog Agents with Pretrained Visual-Linguistic Representation, accepted at CVPR 2021.
Please cite the following paper if you use the code in this repository:
@inproceedings{tu2021learning,
title={Learning Better Visual Dialog Agents with Pretrained Visual-Linguistic Representation},
author={Tu, Tao and Ping, Qing and Thattai, Govindarajan and Tur, Gokhan and Natarajan, Prem},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={5622--5631},
year={2021}
}
Repository Setup
Environment
The following environment is recommended:
Instance storage: > 800 GB
pytorch 1.4.0
cuda 10.0
Set up a virtual environment and install PyTorch:
$ conda create -n read_up python=3.6
$ conda activate read_up
$ git clone https://github.com/amazon-research/read-up.git
# [IMPORTANT] pytorch 1.4.0 has no issues with parallel training
$ conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.0 -c pytorch
Install dependencies:
# Install general dependencies
$ sudo apt-get install build-essential libcap-dev
$ pip install -r requirement.txt
# Install vqa-maskrcnn-benchmark (for feature extraction only)
$ git clone https://gitlab.com/vedanuj/vqa-maskrcnn-benchmark.git
$ cd vqa-maskrcnn-benchmark
$ python setup.py build develop
Install Apex for distributed training:
# Apex is used for both `faster-rcnn feature extraction` & `distributed training`
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
Dataset
Meta-data
Download the GuessWhat?! dataset:
$ wget https://florian-strub.com/guesswhat.train.jsonl.gz -P data/
$ wget https://florian-strub.com/guesswhat.valid.jsonl.gz -P data/
$ wget https://florian-strub.com/guesswhat.test.jsonl.gz -P data/
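The downloaded files are gzip-compressed JSON Lines, with one JSON object per line. The following is a minimal sketch for peeking at them without unpacking to disk. The helpers `iter_games` and `summarize` are hypothetical names that are not part of this repo, and no field names are assumed.

```python
import gzip
import json
from collections import Counter


def iter_games(path):
    """Yield one decoded JSON object per non-empty line of a .jsonl.gz file."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)


def summarize(path, limit=None):
    """Return (number of records read, Counter of top-level keys seen)."""
    n_records = 0
    key_counts = Counter()
    for game in iter_games(path):
        if limit is not None and n_records >= limit:
            break
        n_records += 1
        key_counts.update(game.keys())
    return n_records, key_counts
```

For example, `summarize("data/guesswhat.train.jsonl.gz", limit=1000)` reads only the first 1000 records, which is usually enough to see which top-level keys the records carry.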
Prepare dict.json:
- Set up repo as instructed in https://github.com/GuessWhatGame/guesswhat
- Generate the dict.json file:
$ python src/guesswhat/preprocess_data/create_dictionary.py -data_dir data -dict_file dict.json -min_occ 3
- Copy dict.json file to read-up repo:
$ cd read-up
$ mkdir tf-pretrained-model
$ cp guesswhat/data/dict.json read-up/tf-pretrained-model/
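The `-min_occ 3` flag used when generating dict.json sets a minimum occurrence count for words kept in the vocabulary. The sketch below illustrates that idea only. It is not the code of create_dictionary.py, and the whitespace tokenization, the special tokens, and the name `build_vocab` are all assumptions made for illustration.

```python
from collections import Counter


def build_vocab(questions, min_occ=3, specials=("<padding>", "<unk>")):
    """Map words to integer ids, keeping only words seen at least min_occ times.

    Assumption: lowercase whitespace tokenization and the special tokens are
    placeholders; the real preprocessing script may differ.
    """
    counts = Counter()
    for question in questions:
        counts.update(question.lower().split())

    # Special tokens take the first ids.
    word2i = {tok: i for i, tok in enumerate(specials)}
    # Sort for a deterministic id assignment.
    for word, occ in sorted(counts.items()):
        if occ >= min_occ and word not in word2i:
            word2i[word] = len(word2i)
    return word2i
```

Words below the threshold get no id of their own, so at encoding time they would typically be mapped to the unknown token.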
Dataset for Oracle models
1. Dataset for baseline Oracle + Faster-RCNN visual features.
Under vqa-maskrcnn-benchmark/data/, download the RCNN model and COCO images:
# download RCNN model
$ wget https://dl.fbaipublicfiles.com/vilbert-multi-task/detectron_model.pth
$ wget https://dl.fbaipublicfiles.com/vilbert-multi-task/detectron_config.yaml
# download COCO data
$ wget http://images.cocodataset.org/zips/train2014.zip
$ wget http://images.cocodataset.org/zips/val2014.zip
$ wget http://images.cocodataset.org/zips/test2014.zip
$ unzip -j train2014.zip
$ unzip -j val2014.zip
$ unzip -j test2014.zip
Copy the guesswhat.train/valid/test.jsonl files to vqa-maskrcnn-benchmark/data/. Unzip the COCO images into the folder image_dir/COCO_2014/images/, and prepare an npy file for later feature extraction.
$ python bin/prepare_extract_gt_features_gw.py \
--src vqa-maskrcnn-benchmark/data/guesswhat.train.jsonl \
--img-dir vqa-maskrcnn-benchmark/image_dir/COCO_2014/images/ \
--out vqa-maskrcnn-benchmark/image_dir/COCO_2014/npy_files/guesswhat.train.npy
Repeat the same process for val and test. The generated file looks like the following:
{
{
'file_name': 'name_of_image_file',
'file_path': '<path_to_image_file_on_your_disk>',
'bbox': array([
[ x1, y1, width1, height1],
[ x2, y2, width2, height2],
...
]),
'num_box': 2
},
....
}
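Each record above stores its boxes as [x, y, width, height] rows plus a `num_box` count. The output folder name used in the extraction step (from_gt_gw_xyxy_scale) suggests that boxes are later handled in corner format, but that is an inference from the name, not something this README states. Under that assumption, a minimal sketch for sanity-checking one record and converting its boxes could look like this; `xywh_to_xyxy` and `check_record` are hypothetical helpers, not part of the repo.

```python
def xywh_to_xyxy(box):
    """Convert [x, y, width, height] to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


def check_record(record):
    """Validate num_box against the bbox rows and return corner-format boxes."""
    boxes = [list(b) for b in record["bbox"]]
    if record["num_box"] != len(boxes):
        raise ValueError(
            "num_box=%d but %d boxes found in %s"
            % (record["num_box"], len(boxes), record["file_name"])
        )
    return [xywh_to_xyxy(b) for b in boxes]
```

If the .npy file holds pickled Python dicts, as the listing suggests, loading it would likely need `numpy.load(path, allow_pickle=True)`.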
Extract features from the ground-truth bounding boxes generated before:
$ python bin/extract_features_from_gt.py \
--model_file vqa-maskrcnn-benchmark/data/detectron_model.pth \
--config_file vqa-maskrcnn-benchmark/data/detectron_config.yaml \
--imdb_gt_file vqa-maskrcnn-benchmark/image_dir/COCO_2014/npy_files/guesswhat.train.npy \
--output_folder data/rcnn/from_gt_gw_xyxy_scale/train
Repeat this process for val and test data.
2. Dataset for our Oracle model.
Download the pretrained ViLBERT model (the vanilla and 12-in-1 models perform similarly in our experiments).
# download vanilla pretrained model
$ cd vilbert-pretrained-model
$ wget https://dl.fbaipublicfiles.com/vilbert-multi-task/pretrained_model.bin
# download 12-in-1 pretrained model
$ wget https://dl.fbaipublicfiles.com/vilbert-multi-task/multi_task_model.bin
Download the features for COCO:
# create the target LMDB folders first if they do not exist yet
$ mkdir -p COCO_trainval_resnext152_faster_rcnn_genome.lmdb COCO_test_resnext152_faster_rcnn_genome.lmdb
$ wget https://dl.fbaipublicfiles.com/vilbert-multi-task/datasets/coco/features_100/COCO_trainval_resnext152_faster_rcnn_genome.lmdb/data.mdb && mv data.mdb COCO_trainval_resnext152_faster_rcnn_genome.lmdb/
$ wget https://dl.fbaipublicfiles.com/vilbert-multi-task/datasets/coco/features_100/COCO_test_resnext152_faster_rcnn_genome.lmdb/data.mdb && mv data.mdb COCO_test_resnext152_faster_rcnn_genome.lmdb/
Dataset for Q-Gen models
1. Dataset for baseline Q-Gen model [1]
$ wget www.florian-strub.com/github/ft_vgg_img.zip
$ unzip ft_vgg_img.zip -d img/
2. Dataset for VDST Q-Gen model [2]
$ python bin/extract_features.py \
--model_file vqa-maskrcnn-benchmark/data/detectron_model.pth \
--config_file vqa-maskrcnn-benchmark/data/detectron_config.yaml \
--image_dir vqa-maskrcnn-benchmark/image_dir/COCO_2014/images/ \
--output_folder data/rcnn/from_rcnn/ \
--batch_size 8
Dataset for Guesser models
1. Dataset for baseline Guesser model [1]
$ cd data/vilbert-multi-task
$ wget https://dl.fbaipublicfiles.com/vilbert-multi-task/datasets.tar.gz
$ tar -I pigz -xvf datasets.tar.gz datasets/guesswhat/
Model Training & Evaluation
Oracle
To train our Oracle model:
$ python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=1 \
--node_rank=0 \
main.py \
--command train-oracle-vilbert \
--config config_files/oracle_vilbert.yaml \
--n-jobs 8
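In PyTorch 1.4, `torch.distributed.launch` starts `--nproc_per_node` worker processes, passes each one a `--local_rank` argument, and exports `RANK` and `WORLD_SIZE` in the environment. How main.py consumes these values is not shown here, so the helper below (`parse_launch_args`, a hypothetical name) is only a sketch of how a training script can read them while ignoring its other flags.

```python
import argparse
import os


def parse_launch_args(argv=None, environ=None):
    """Read the launcher-provided rank information.

    Falls back to single-process defaults when the script is started
    directly with `python main.py ...` instead of through the launcher.
    """
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    # Unknown flags (e.g. --command, --config) are left for the real parser.
    args, _unknown = parser.parse_known_args(argv)

    world_size = int(environ.get("WORLD_SIZE", "1"))
    rank = int(environ.get("RANK", "0"))
    return {
        "local_rank": args.local_rank,
        "rank": rank,
        "world_size": world_size,
        "distributed": world_size > 1,
    }
```

With `--nproc_per_node=4` and `--nnodes=1` as in the command above, the four workers would see `local_rank` values 0 to 3 and a world size of 4.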
To evaluate our Oracle model:
$ python main.py \
--command test-oracle-vilbert \
--config config_files/oracle_vilbert.yaml \
--load ckpt/oracle_vilbert-sd0/epoch-3.pth
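The `--load` path above points at a per-epoch checkpoint. Assuming checkpoints in a run folder are all named `epoch-<N>.pth` (a pattern inferred from this one path, not confirmed by the repo), a small helper such as the hypothetical `latest_epoch_checkpoint` below can pick the most recent one to pass to `--load`.

```python
import re
from pathlib import Path

# Assumed naming pattern, inferred from ckpt/oracle_vilbert-sd0/epoch-3.pth.
_EPOCH_RE = re.compile(r"^epoch-(\d+)\.pth$")


def latest_epoch_checkpoint(ckpt_dir):
    """Return the epoch-<N>.pth path with the largest N, or None if none exist."""
    best = None
    for path in Path(ckpt_dir).iterdir():
        match = _EPOCH_RE.match(path.name)
        if match:
            epoch = int(match.group(1))
            # Compare numerically so that epoch-10 ranks above epoch-3.
            if best is None or epoch > best[0]:
                best = (epoch, path)
    return None if best is None else best[1]
```

Note that the latest epoch is not necessarily the best one on validation; the Guesser example further down loads a `best.pth` file instead.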
This repo also implements other Oracle models:
- Baseline Oracle model [1]
- Baseline Oracle model + Faster-RCNN visual features (our ablation model)
To train and evaluate these models, run main.py with the corresponding config file and command.
Guesser
To train our Guesser model:
$ python main.py \
--command train-guesser-vilbert \
--config config_files/guesser_vilbert.yaml \
--n-jobs 8
To evaluate our Guesser model:
$ python main.py \
--command test-guesser-vilbert \
--config config_files/guesser_vilbert.yaml \
--n-jobs 8 \
--load ckpt/guesser_vilbert-sd0/best.pth
This repo also implements other Guesser models:
- Baseline Guesser model [1]
To train and evaluate this model, run main.py with the corresponding config file and command.
Q-Gen
To train our Q-Gen model:
# Distributed training
$ python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=1 \
--node_rank=0 \
main.py \
--command train-qgen-vilbert \
--config config_files/qgen_vilbert.yaml \
--n-jobs 8
# Non-distributed training
$ python main.py \
--command train-qgen-vilbert \
--config config_files/qgen_vilbert.yaml \
--n-jobs 8
To evaluate our Q-Gen model:
$ python main.py \
--command test-self-play-all-vilbert \
--config config_files/self_play_all_vilbert.yaml \
--n-jobs 8
This repo also implements other Q-Gen models:
- Baseline Q-Gen model [1]
- VDST Q-Gen model [2]
To train and evaluate these models, run main.py with the corresponding config file and command.
References
[1] Strub, F., De Vries, H., Mary, J., Piot, B., Courville, A., & Pietquin, O. (2017, August). End-to-end optimization of goal-driven and visually grounded dialogue systems. In Proceedings of the 26th International Joint Conference on Artificial Intelligence (pp. 2765-2771).
[2] Pang, W., & Wang, X. (2020, April). Visual dialogue state tracking for question generation. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 34, No. 07, pp. 11831-11838).