TWIST: Self-Supervised Learning by Estimating Twin Class Distributions
Code and pretrained models for TWIST:
```
@article{wang2021self,
  title={Self-Supervised Learning by Estimating Twin Class Distributions},
  author={Wang, Feng and Kong, Tao and Zhang, Rufeng and Liu, Huaping and Li, Hang},
  journal={arXiv preprint arXiv:2110.07402},
  year={2021}
}
```
TWIST is a novel self-supervised representation learning method that classifies large-scale unlabeled datasets in an end-to-end way. It employs a siamese network terminated by a softmax operation to produce twin class distributions for two augmented views of the same image. Without supervision, TWIST enforces the class distributions of different augmentations to be consistent. At the same time, it regularizes the class distributions to be sharp and diverse. TWIST naturally avoids trivial solutions without specific designs such as an asymmetric network, a stop-gradient operation, or a momentum encoder.
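For intuition, the objective above can be sketched in a few lines of PyTorch. This is a minimal two-view illustration, not the repository's implementation: the function name `twist_loss`, the `eps` smoothing, and the sign conventions tying `lam1`/`lam2` to the `--lam1`/`--lam2` flags are assumptions, and the real code additionally handles multi-crop views and self-labeling.

```python
# Minimal two-view sketch of a TWIST-style objective (illustrative only).
import torch

def twist_loss(logits1, logits2, lam1=1.0, lam2=1.0, eps=1e-6):
    """logits1, logits2: (B, K) class logits for two augmentations of a batch."""
    p1, p2 = logits1.softmax(dim=1), logits2.softmax(dim=1)
    # Consistency: symmetric cross-entropy between the twin distributions.
    consistency = -0.5 * ((p1 * (p2 + eps).log()).sum(1)
                          + (p2 * (p1 + eps).log()).sum(1)).mean()
    # Sharpness: keep each sample's class distribution low-entropy.
    sharpness = -0.5 * ((p1 * (p1 + eps).log()).sum(1)
                        + (p2 * (p2 + eps).log()).sum(1)).mean()
    # Diversity: keep the batch-averaged distribution high-entropy,
    # spreading samples over classes and avoiding collapse.
    mean_p = 0.5 * (p1.mean(0) + p2.mean(0))
    diversity = -(mean_p * (mean_p + eps).log()).sum()
    return consistency + lam1 * sharpness - lam2 * diversity
```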
Models and Results
Main Models for Representation Learning
Models with multi-crop and self-labeling:

arch | params | epochs | linear top-1 | download
---|---|---|---|---
ResNet-50 | 24M | 850 | 75.5% | backbone only / full ckpt / args / log / eval logs
ResNet-50w2 | 94M | 250 | 77.7% | backbone only / full ckpt / args / log / eval logs
DeiT-S | 21M | 300 | 75.6% | backbone only / full ckpt / args / log / eval logs
ViT-B | 86M | 300 | 77.3% | backbone only / full ckpt / args / log / eval logs

Model without multi-crop and self-labeling:

arch | params | epochs | linear top-1 | download
---|---|---|---|---
ResNet-50 | 24M | 800 | 72.6% | backbone only / full ckpt / args / log / eval logs
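The "backbone only" ResNet-50 checkpoints can typically be dropped into a standard torchvision model for downstream use. A minimal sketch, assuming the file holds a torchvision-compatible (possibly nested) state dict; the local filename is a hypothetical placeholder:

```python
# Minimal sketch of loading a "backbone only" ResNet-50 checkpoint.
# Inspect missing/unexpected keys rather than trusting strict loading,
# since the actual key layout may differ.
import torch
import torchvision

model = torchvision.models.resnet50()
state = torch.load("twist_resnet50_backbone.pth", map_location="cpu")
state = state.get("state_dict", state)  # unwrap if nested (assumption)
missing, unexpected = model.load_state_dict(state, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)
```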
Model for unsupervised classification
arch | params | epochs | NMI | AMI | ARI | ACC | download
---|---|---|---|---|---|---|---
ResNet-50 | 24M | 800 | 74.4 | 57.7 | 30.1 | 40.5 | backbone only / full ckpt / args / log
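For reference, the four metrics in this table are standard clustering scores: NMI, AMI, and ARI are available in scikit-learn, while ACC requires a Hungarian matching between predicted cluster ids and ground-truth classes. A self-contained helper (the function name `clustering_scores` is illustrative):

```python
# Clustering metrics: NMI/AMI/ARI via scikit-learn, ACC via Hungarian matching.
# y_true and y_pred are integer numpy arrays of the same length.
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score,
                             normalized_mutual_info_score)

def clustering_scores(y_true, y_pred):
    # Build the (cluster, class) co-occurrence matrix and find the best
    # one-to-one assignment of clusters to classes.
    d = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((d, d), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        w[p, t] += 1
    rows, cols = linear_sum_assignment(-w)  # negate to maximize matches
    return {
        "NMI": normalized_mutual_info_score(y_true, y_pred),
        "AMI": adjusted_mutual_info_score(y_true, y_pred),
        "ARI": adjusted_rand_score(y_true, y_pred),
        "ACC": w[rows, cols].sum() / y_pred.size,
    }
```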
Top-3 predictions for unsupervised classification
Semi-Supervised Results
arch | 1% labels | 10% labels | 100% labels
---|---|---|---
ResNet-50 | 61.5% | 71.7% | 78.4%
ResNet-50w2 | 67.2% | 75.3% | 80.3%
Detection Results
Task | AP | AP50 | AP75
---|---|---|---
VOC07+12 detection | 58.1 | 84.2 | 65.4
COCO detection | 41.9 | 62.6 | 45.7
COCO instance segmentation | 37.9 | 59.7 | 40.6
Single-node Training
ResNet-50 (requires 8 GPUs, Top-1 Linear 72.6%)
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env train.py \
    --data-path ${DATAPATH} \
    --output_dir ${OUTPUT} \
    --aug barlow \
    --batch-size 256 \
    --dim 32768 \
    --epochs 800
```
Multi-node Training
ResNet-50 (requires 16 GPUs split across 2 nodes for multi-crop training, Top-1 Linear 75.5%)
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
    --nnodes=${WORKER_NUM} \
    --node_rank=${MACHINE_ID} \
    --master_addr=${HOST} \
    --master_port=${PORT} train.py \
    --data-path ${DATAPATH} \
    --output_dir ${OUTPUT}
```
ResNet-50w2 (requires 32 GPUs split across 4 nodes for multi-crop training, Top-1 Linear 77.7%)
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
    --nnodes=${WORKER_NUM} \
    --node_rank=${MACHINE_ID} \
    --master_addr=${HOST} \
    --master_port=${PORT} train.py \
    --data-path ${DATAPATH} \
    --output_dir ${OUTPUT} \
    --backbone 'resnet50w2' \
    --batch-size 60 \
    --bunch-size 240 \
    --epochs 250 \
    --mme_epochs 200
```
DeiT-S (requires 16 GPUs split across 2 nodes for multi-crop training, Top-1 Linear 75.6%)
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
    --nnodes=${WORKER_NUM} \
    --node_rank=${MACHINE_ID} \
    --master_addr=${HOST} \
    --master_port=${PORT} train.py \
    --data-path ${DATAPATH} \
    --output_dir ${OUTPUT} \
    --backbone 'vit_s' \
    --batch-size 128 \
    --bunch-size 256 \
    --clip_norm 3.0 \
    --epochs 300 \
    --mme_epochs 300 \
    --lam1 -0.6 \
    --lam2 1.0 \
    --local_crops_number 6 \
    --lr 0.0005 \
    --momentum_start 0.996 \
    --momentum_end 1.0 \
    --optim admw \
    --use_momentum_encoder 1 \
    --weight_decay 0.06 \
    --weight_decay_end 0.06
```
ViT-B (requires 32 GPUs split across 4 nodes for multi-crop training, Top-1 Linear 77.3%)
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
    --nnodes=${WORKER_NUM} \
    --node_rank=${MACHINE_ID} \
    --master_addr=${HOST} \
    --master_port=${PORT} train.py \
    --data-path ${DATAPATH} \
    --output_dir ${OUTPUT} \
    --backbone 'vit_b' \
    --batch-size 64 \
    --bunch-size 256 \
    --clip_norm 3.0 \
    --epochs 300 \
    --mme_epochs 300 \
    --lam1 -0.6 \
    --lam2 1.0 \
    --local_crops_number 6 \
    --lr 0.00075 \
    --momentum_start 0.996 \
    --momentum_end 1.0 \
    --optim admw \
    --use_momentum_encoder 1 \
    --weight_decay 0.06 \
    --weight_decay_end 0.06
```
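The ViT runs above enable an EMA teacher via `--use_momentum_encoder`, with the coefficient ramped from `--momentum_start` to `--momentum_end`. A minimal sketch of such an update, assuming a DINO-style cosine ramp (the repository's exact schedule may differ):

```python
# EMA teacher update with a cosine-ramped coefficient (schedule is assumed).
import math
import torch

@torch.no_grad()
def update_momentum_encoder(student, teacher, step, total_steps,
                            m_start=0.996, m_end=1.0):
    # Ramp the EMA coefficient from m_start (step 0) to m_end (final step).
    m = m_end - (m_end - m_start) * (math.cos(math.pi * step / total_steps) + 1) / 2
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(m).add_(p_s, alpha=1.0 - m)
```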
Linear Classification
For ResNet-50
```bash
python3 evaluate.py \
    ${DATAPATH} \
    ${OUTPUT}/checkpoint.pth \
    --weight-decay 0 \
    --checkpoint-dir ${OUTPUT}/linear_multihead/ \
    --batch-size 1024 \
    --val_epoch 1 \
    --lr-classifier 0.2
```
For DeiT-S
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 evaluate_vitlinear.py \
    --arch vit_s \
    --pretrained_weights ${OUTPUT}/checkpoint.pth \
    --lr 0.02 \
    --data_path ${DATAPATH} \
    --output_dir ${OUTPUT}
```
For ViT-B
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 evaluate_vitlinear.py \
    --arch vit_b \
    --pretrained_weights ${OUTPUT}/checkpoint.pth \
    --lr 0.0015 \
    --data_path ${DATAPATH} \
    --output_dir ${OUTPUT}
```
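All three evaluations above follow the usual linear-probe protocol: the pretrained backbone is frozen and only a linear classifier is trained on top. A condensed sketch of that protocol; the 2048-d/1000-class shapes and the 0.2 learning rate assume the ResNet-50 ImageNet setting:

```python
# Condensed linear-probe sketch: frozen backbone, trainable linear head.
import torch
import torch.nn as nn
import torchvision

backbone = torchvision.models.resnet50()   # load pretrained TWIST weights here
backbone.fc = nn.Identity()                # expose 2048-d pooled features
backbone.eval()
for p in backbone.parameters():
    p.requires_grad = False

classifier = nn.Linear(2048, 1000)         # ImageNet-1k classes
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.2,
                            momentum=0.9, weight_decay=0)

def training_step(images, labels):
    with torch.no_grad():
        feats = backbone(images)           # no gradients through the backbone
    loss = nn.functional.cross_entropy(classifier(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```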
Semi-supervised Learning
Commands for training semi-supervised classification:
1% labels (61.5%)
```bash
python3 evaluate.py ${DATAPATH} ${MODELPATH} \
    --weights finetune \
    --lr-backbone 0.04 \
    --lr-classifier 0.2 \
    --train-percent 1 \
    --weight-decay 0 \
    --epochs 20 \
    --backbone 'resnet50'
```
10% labels (71.7%)
```bash
python3 evaluate.py ${DATAPATH} ${MODELPATH} \
    --weights finetune \
    --lr-backbone 0.02 \
    --lr-classifier 0.2 \
    --train-percent 10 \
    --weight-decay 0 \
    --epochs 20 \
    --backbone 'resnet50'
```
100% labels (78.4%)
```bash
python3 evaluate.py ${DATAPATH} ${MODELPATH} \
    --weights finetune \
    --lr-backbone 0.01 \
    --lr-classifier 0.2 \
    --train-percent 100 \
    --weight-decay 0 \
    --epochs 30 \
    --backbone 'resnet50'
```
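The `--lr-backbone`/`--lr-classifier` pair suggests fine-tuning with two learning rates, which in PyTorch is expressed as optimizer parameter groups. A sketch under that assumption, using the 1%-labels values from above:

```python
# Two-learning-rate fine-tuning via optimizer parameter groups (assumed to be
# what --lr-backbone / --lr-classifier control in evaluate.py).
import torch
import torch.nn as nn
import torchvision

model = torchvision.models.resnet50()      # load pretrained TWIST weights here
model.fc = nn.Linear(2048, 1000)           # fresh classification head

backbone_params = [p for n, p in model.named_parameters()
                   if not n.startswith("fc.")]
optimizer = torch.optim.SGD(
    [{"params": backbone_params, "lr": 0.04},        # --lr-backbone (1% labels)
     {"params": model.fc.parameters(), "lr": 0.2}],  # --lr-classifier
    momentum=0.9, weight_decay=0)
```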
Detection
Instructions
1. Install detectron2.

2. Convert a pre-trained TWIST model to detectron2's format (a sketch of such a converter follows the training commands below):

   ```bash
   python3 detection/convert-pretrain-to-detectron2.py ${MODELPATH} ${OUTPUTPKLPATH}
   ```
3. Put the dataset under the "detection/datasets" directory, following the directory structure required by detectron2.
4. Training:

   VOC

   ```bash
   cd detection/
   python3 train_net.py \
       --config-file voc_fpn_1fc/pascal_voc_R_50_FPN_24k_infomin.yaml \
       --num-gpus 8 \
       MODEL.WEIGHTS ../${OUTPUTPKLPATH}
   ```
   COCO

   ```bash
   python3 train_net.py \
       --config-file infomin_configs/R_50_FPN_1x_infomin.yaml \
       --num-gpus 8 \
       MODEL.WEIGHTS ../${OUTPUTPKLPATH}
   ```
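The converter script name in step 2 follows MoCo's detection setup, so a converter of this kind most likely renames torchvision ResNet keys to detectron2's layout and pickles the result. A hedged sketch modeled on MoCo's public converter; the repository's own script is authoritative and its key layout may differ:

```python
# Hedged sketch of a torchvision-ResNet -> detectron2 weight converter,
# modeled on MoCo's public convert-pretrain-to-detectron2.py.
import pickle as pkl
import sys

import torch

if __name__ == "__main__":
    input_path, output_path = sys.argv[1], sys.argv[2]
    obj = torch.load(input_path, map_location="cpu")
    obj = obj.get("state_dict", obj)  # unwrap if nested (assumption)

    newmodel = {}
    for k, v in obj.items():
        if k.startswith("fc"):
            continue  # the classification head is not used by the detector
        if "layer" not in k:
            k = "stem." + k  # conv1/bn1 belong to detectron2's stem
        for t in [1, 2, 3, 4]:
            k = k.replace(f"layer{t}", f"res{t + 1}")
        for t in [1, 2, 3]:
            k = k.replace(f"bn{t}", f"conv{t}.norm")
        k = k.replace("downsample.0", "shortcut")
        k = k.replace("downsample.1", "shortcut.norm")
        newmodel[k] = v.numpy()

    res = {"model": newmodel, "__author__": "TWIST", "matching_heuristics": True}
    with open(output_path, "wb") as f:
        pkl.dump(res, f)
```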