Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime
Created by Prarthana Bhattacharyya.
Disclaimer: This is not an official product; it is a proof of concept intended for academic/educational use only.
This repository contains the PyTorch implementation for the paper Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime, to be presented at ICASSP-2022.
Self-supervision has shown outstanding results for natural language processing and, more recently, for image recognition. Simultaneously, vision transformers and their variants have emerged as a promising and scalable alternative to convolutions on various computer vision tasks. In this paper, we are the first to ask whether self-supervised vision transformers (SSL-ViTs) can be adapted to two important computer vision tasks in the low-label, high-data regime: few-shot image classification and zero-shot image retrieval. The motivation is to reduce the number of manual annotations required to train a visual embedder, and to produce generalizable, semantically meaningful and robust embeddings.
Results
- SSL-ViT + few-shot image classification:
- Qualitative analysis for base-classes chosen by supervised CNN and SSL-ViT for few-shot distribution calibration:
- SSL-ViT + zero-shot image retrieval:
Pretraining Self-Supervised ViT
- Run DINO with a ViT-small network on a single node with 4 GPUs for 100 epochs using the following command:
cd dino/
python -m torch.distributed.launch --nproc_per_node=4 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
- For mini-ImageNet pretraining, we use the classes listed in:
ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_mini.txt
- For tiered-ImageNet pretraining, we use the classes listed in:
ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_tiered.txt
- For CUB-200, Cars-196 and SOP, we use the pretrained DINO ViT-S/16 model from PyTorch Hub:
import torch
# Self-supervised DINO ViT-S/16 backbone from PyTorch Hub (weights are downloaded on first use)
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
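The mini- and tiered-ImageNet pretraining runs restrict DINO to the classes in the split files listed above. One possible way to do this, sketched here under assumptions, is to build a folder of symlinks to only the listed class directories of the ImageNet train set and pass that folder as --data_path. The split-file format (one class ID per line, matching ImageNet folder names such as n01532829) and the helper names read_split and link_split are hypothetical and not taken from this repository.

```python
from pathlib import Path


def read_split(split_file):
    """Return the class IDs in a split file.

    Assumed format (hypothetical): one class ID per line, e.g. a WordNet ID like n01532829.
    """
    with open(split_file) as f:
        return [line.strip() for line in f if line.strip()]


def link_split(imagenet_train, split_file, out_dir):
    """Symlink out_dir/<class_id> -> imagenet_train/<class_id> for every listed class.

    Returns the class IDs that were listed but not found under imagenet_train.
    """
    src_root = Path(imagenet_train)
    dst_root = Path(out_dir)
    dst_root.mkdir(parents=True, exist_ok=True)
    missing = []
    for class_id in read_split(split_file):
        src = src_root / class_id
        if not src.is_dir():
            missing.append(class_id)
            continue
        dst = dst_root / class_id
        # Skip links that already exist so the function can be re-run safely
        if not dst.is_symlink() and not dst.exists():
            dst.symlink_to(src.resolve(), target_is_directory=True)
    return missing
```

For example, link_split('/path/to/imagenet/train', 'ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_mini.txt', '/path/to/imagenet_mini_ssl') returns the IDs it could not find; the output folder could then be passed as --data_path, assuming the loader follows symlinked class folders. Symlinks avoid duplicating the images on disk.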
Dataset Preparation
Please follow the instructions in FRN (few-shot image classification) and RevisitDML (zero-shot image retrieval) to download the datasets, and put them in the ssl-vit-fewshot/data and DIML/data folders, respectively.
Training and Evaluation for few-shot image classification
- The first step is to extract SSL-ViT features for the base and novel classes with get_dino_miniimagenet_feats.ipynb.
- To use CUB or tiered-ImageNet instead of mini-ImageNet, change the hyper-parameter data_path.
- The SSL-ViT checkpoints for the various datasets are provided below (note: these models were trained without any labels). We also provide the extracted features, which need to be stored in ssl-vit-fewshot/dino_features_data/.
arch | dataset | download | extracted-train | extracted-test |
---|---|---|---|---|
ViT-S/16 | mini-ImageNet | mini_imagenet_checkpoint.pth | train.p | test.p |
ViT-S/16 | tiered-ImageNet | tiered_imagenet_checkpoint.pth | train.p | test.p |
ViT-S/16 | CUB | cub_checkpoint.pth | train.p | test.p |
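Before running the evaluation notebook, a quick check of the downloaded train.p and test.p files can catch a misplaced or truncated download. The sketch below assumes, without confirmation from this repository, that each file is a pickled dict mapping a class label to a list of equal-length feature vectors; load_features and summarize are hypothetical helpers.

```python
import pickle


def load_features(path):
    """Load a pickled {class_label: [feature_vector, ...]} dict (assumed format)."""
    with open(path, "rb") as f:
        feats = pickle.load(f)
    if not isinstance(feats, dict):
        raise TypeError(f"expected a dict of class -> features, got {type(feats).__name__}")
    return feats


def summarize(feats):
    """Return (number of classes, samples per class, feature dimension)."""
    counts = {label: len(vectors) for label, vectors in feats.items()}
    dims = {len(v) for vectors in feats.values() for v in vectors}
    if len(dims) != 1:
        raise ValueError(f"expected a single feature dimension, found {sorted(dims)}")
    return len(feats), counts, dims.pop()
```

For example, summarize(load_features('ssl-vit-fewshot/dino_features_data/train.p')) would report the class count, per-class sample counts and feature dimension if the assumed format holds. Because unpickling can execute arbitrary code, only load files from a source you trust.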
- For N-way K-shot evaluation, we provide miniimagenet_evaluate_dinoDC.ipynb.
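Few-shot classification here builds on distribution calibration (DC), which, as the Results section notes, relies on base classes selected for each novel class. The sketch below illustrates that idea for N-way K-shot episodes; it is a deliberately simplified version, not the logic of miniimagenet_evaluate_dinoDC.ipynb. Each support feature is power-transformed, its mean is averaged with the means of its k nearest base classes, a diagonal variance is borrowed from those classes plus a constant alpha, extra features are sampled from the resulting Gaussian, and queries go to the nearest class centroid. The diagonal covariance, the nearest-centroid classifier (the original DC method trains a linear classifier on the support and sampled features), the signed power transform (ViT features can be negative), and the defaults k=2, alpha=0.21, lam=0.5 and n_samples=50 are all assumptions; every helper name is hypothetical.

```python
import math
import random


def base_statistics(base_feats):
    """(mean, per-dimension variance) for each base class in a {label: [features]} dict."""
    stats = []
    for vectors in base_feats.values():
        n = len(vectors)
        cols = list(zip(*vectors))
        mean = [sum(c) / n for c in cols]
        var = [sum((x - m) ** 2 for x in c) / max(n - 1, 1) for c, m in zip(cols, mean)]
        stats.append((mean, var))
    return stats


def power_transform(x, lam=0.5):
    """Signed power transform sign(v) * |v| ** lam (assumed variant of Tukey's ladder)."""
    return [math.copysign(abs(v) ** lam, v) for v in x]


def sq_dist(a, b):
    """Squared Euclidean distance."""
    return sum((p - q) ** 2 for p, q in zip(a, b))


def calibrate(x, base_stats, k=2, alpha=0.21):
    """Average x with the means of its k nearest base classes; borrow their variance plus alpha."""
    nearest = sorted(base_stats, key=lambda s: sq_dist(s[0], x))[:k]
    n = len(nearest)
    mean = [(sum(m[d] for m, _ in nearest) + x[d]) / (n + 1) for d in range(len(x))]
    var = [sum(v[d] for _, v in nearest) / n + alpha for d in range(len(x))]
    return mean, var


def sample_episode(features, n_way, k_shot, n_query, rng):
    """Draw an N-way K-shot episode: (support dict, query list, query labels)."""
    classes = rng.sample(sorted(features), n_way)
    support, query, query_labels = {}, [], []
    for c in classes:
        picked = rng.sample(features[c], k_shot + n_query)
        support[c] = picked[:k_shot]
        query.extend(picked[k_shot:])
        query_labels.extend([c] * n_query)
    return support, query, query_labels


def classify(support, query, base_stats, k=2, alpha=0.21, n_samples=50, seed=0):
    """Nearest-centroid prediction over support features augmented with calibrated samples."""
    rng = random.Random(seed)
    centroids = {}
    for label, feats in support.items():
        pool = []
        for x in feats:
            x = power_transform(x)
            mean, var = calibrate(x, base_stats, k, alpha)
            pool.append(x)
            # Diagonal Gaussian sampling (a simplification of a full-covariance Gaussian)
            for _ in range(n_samples):
                pool.append([rng.gauss(m, math.sqrt(s)) for m, s in zip(mean, var)])
        centroids[label] = [sum(c) / len(pool) for c in zip(*pool)]
    preds = []
    for q in query:
        q = power_transform(q)
        preds.append(min(centroids, key=lambda c: sq_dist(centroids[c], q)))
    return preds
```

A possible evaluation loop under these assumptions: compute base = base_statistics(base_feats) once, then for each episode from sample_episode(novel_feats, 5, 1, 15, rng) score the fraction of classify(support, query, base) predictions that equal the query labels, and average over episodes.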
Training and Evaluation for zero-shot image retrieval
- To train the baseline CNN models, run the scripts in DIML/scripts/baselines. The checkpoints are saved in the Training_Results folder. For example:
cd DIML/
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh
- To train the supervised ViT and self-supervised ViT:
cp -r ssl-vit-retrieval/architectures/* DIML/ssl-vit-retrieval/architectures/
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh --arch vits
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh --arch dino
- To test the models, first edit the checkpoint paths in test_diml.py, then run:
CUDA_VISIBLE_DEVICES=0 ./scripts/diml/test_diml.sh cub200
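Zero-shot retrieval on CUB-200, Cars-196 and SOP is commonly reported as Recall@K: a query is a hit if at least one of its K nearest neighbours in the test set, excluding the query itself, shares its class. The sketch below computes it with cosine similarity in plain Python to illustrate the metric; it is not the code path of test_diml.py, and recall_at_k is a hypothetical helper.

```python
import math


def cosine(a, b):
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def recall_at_k(embeddings, labels, ks=(1, 2, 4, 8)):
    """Fraction of queries with a same-class item among their k nearest neighbours."""
    n = len(embeddings)
    max_k = max(ks)
    hits = dict.fromkeys(ks, 0)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        # Rank every other item by similarity to query i, most similar first
        others.sort(key=lambda j: cosine(embeddings[i], embeddings[j]), reverse=True)
        retrieved = [labels[j] for j in others[:max_k]]
        for k in ks:
            if labels[i] in retrieved[:k]:
                hits[k] += 1
    return {k: hits[k] / n for k in ks}
```

The brute-force ranking is quadratic in the number of test images, which is fine for a sanity check but slow for a large test set such as SOP; a vectorized similarity matrix would be the practical choice there.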
dataset | Loss | SSL-ViT-download |
---|---|---|
CUB | Margin | cub_ssl-vit-margin.pth |
CUB | Proxy-NCA | cub_ssl-vit-proxynca.pth |
CUB | Multi-Similarity | cub_ssl-vit-ms.pth |
Cars-196 | Margin | cars_ssl-vit-margin.pth |
Cars-196 | Proxy-NCA | cars_ssl-vit-proxynca.pth |
Cars-196 | Multi-Similarity | cars_ssl-vit-ms.pth |
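For reference, the Margin rows correspond to a margin-based loss (Wu et al., 2017), which for a pair (i, j) with distance D_ij computes max(0, alpha + y_ij * (D_ij - beta)), where y_ij = +1 for same-class pairs and -1 otherwise; positives are pushed inside beta - alpha and negatives outside beta + alpha. The sketch below evaluates it over all pairs of L2-normalized embeddings. The original method pairs it with distance-weighted sampling and a learnable beta, both omitted here, and alpha=0.2, beta=1.2 and averaging over pairs with non-zero loss are assumptions rather than this repository's settings; margin_loss is a hypothetical helper.

```python
import math


def l2_normalize(v):
    """Scale v to unit length (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def margin_loss(embeddings, labels, alpha=0.2, beta=1.2):
    """Margin loss over all pairs, averaged over pairs with a non-zero loss."""
    z = [l2_normalize(e) for e in embeddings]
    total, active = 0.0, 0
    for i in range(len(z)):
        for j in range(i + 1, len(z)):
            d = math.dist(z[i], z[j])  # Euclidean distance (Python 3.8+)
            y = 1.0 if labels[i] == labels[j] else -1.0
            loss = max(0.0, alpha + y * (d - beta))
            if loss > 0.0:
                total += loss
                active += 1
    return total / active if active else 0.0
```

Averaging over active pairs keeps the loss scale stable as most pairs become easy during training; a plain mean over all pairs would shrink toward zero instead.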
Acknowledgement
The code is based on: