SEED
Implementation of the ICLR 2021 paper: SEED: Self-supervised Distillation For Visual Representation.
```bibtex
@Article{fang2020seed,
  author  = {Fang, Zhiyuan and Wang, Jianfeng and Wang, Lijuan and Zhang, Lei and Yang, Yezhou and Liu, Zicheng},
  title   = {SEED: Self-supervised Distillation For Visual Representation},
  journal = {International Conference on Learning Representations},
  year    = {2021},
}
```
Introduction
This paper is concerned with self-supervised learning for small models. The problem is motivated by our empirical finding that, while widely used contrastive self-supervised learning methods have shown great progress in training large models, they do not work well for small models. To address this problem, we propose a new learning paradigm, named SElf-SupErvised Distillation (SEED), where we leverage a larger network (as Teacher) to transfer its representational knowledge into a smaller architecture (as Student) in a self-supervised fashion. Instead of directly learning from unlabeled data, we train a student encoder to mimic the similarity score distribution inferred by a teacher over a set of instances. We show that SEED dramatically boosts the performance of small networks on downstream tasks. Compared with self-supervised baselines, SEED improves the top-1 accuracy on ImageNet-1k from 42.2% to 67.6% on EfficientNet-B0 and from 36.3% to 68.2% on MobileNetV3-Large. SEED also improves ResNet-50 from 67.4% to 74.3% over the previous MoCo-V2 baseline.
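The distillation objective described above can be sketched in plain Python. This is a minimal sketch assuming the formulation in the paper: both encoders produce L2-normalized embeddings, the teacher's embedding of the current image is appended to a queue of past teacher embeddings, and the loss is the cross-entropy between the teacher's and the student's softmax-normalized similarity scores over that anchor set. The function names and temperature values are hypothetical illustrations, not this repository's API; the real training code operates on batched tensors.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def l2_normalize(v: Vector) -> List[float]:
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def similarity_distribution(z: Vector, anchors: Sequence[Vector], temp: float) -> List[float]:
    """Softmax over temperature-scaled similarities between z and every anchor."""
    return softmax([sum(a * b for a, b in zip(z, d)) / temp for d in anchors])


def seed_loss(z_student: Vector, z_teacher: Vector, queue: Sequence[Vector],
              temp_student: float, temp_teacher: float) -> float:
    """Cross-entropy between teacher and student similarity distributions.

    Hypothetical helper: the anchor set is a queue of past teacher embeddings
    plus the teacher's embedding of the current image, and the student is
    pushed to reproduce the teacher's distribution over those anchors.
    """
    zs = l2_normalize(z_student)
    zt = l2_normalize(z_teacher)
    anchors = [l2_normalize(d) for d in queue] + [zt]
    p_teacher = similarity_distribution(zt, anchors, temp_teacher)
    p_student = similarity_distribution(zs, anchors, temp_student)
    return -sum(pt * math.log(ps) for pt, ps in zip(p_teacher, p_student))


if __name__ == "__main__":
    queue = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.6, 0.8, 0.0]]  # made-up teacher embeddings
    teacher = [1.0, 0.0, 0.0]
    print(seed_loss([0.9, 0.1, 0.0], teacher, queue, temp_student=0.2, temp_teacher=0.2))
```

Because this is a cross-entropy, the loss is minimized (down to the entropy of the teacher's distribution) exactly when the student's similarity distribution matches the teacher's.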
Preparation
Note: This repository does not include ImageNet dataset construction; please refer to MoCo-V2 for the environment setup and dataset preparation. Be careful if you use Facebook's ImageNet dataset implementation, as the dataloader provided here expects a TSV-formatted ImageNet source.
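The exact TSV layout is defined by this repository's dataloader and is not documented here. The sketch below assumes a hypothetical layout of one image per line with three tab-separated columns (image key, integer label, base64-encoded image bytes), only to illustrate how images can be packed into and read back from a TSV file; check the dataset code before converting ImageNet.

```python
import base64
import io
from typing import Iterable, Iterator, Tuple


def encode_row(key: str, label: int, image_bytes: bytes) -> str:
    """Serialize one image as a tab-separated line (hypothetical column order)."""
    b64 = base64.b64encode(image_bytes).decode("ascii")
    return f"{key}\t{label}\t{b64}\n"


def iter_rows(lines: Iterable[str]) -> Iterator[Tuple[str, int, bytes]]:
    """Yield (key, label, raw image bytes) for each line of a TSV stream."""
    for line in lines:
        key, label, b64 = line.rstrip("\n").split("\t")
        # The raw bytes could then be decoded, e.g. with PIL.Image.open(io.BytesIO(data)).
        yield key, int(label), base64.b64decode(b64)


if __name__ == "__main__":
    buf = io.StringIO(encode_row("example_0001.JPEG", 0, b"\xff\xd8 not a real JPEG"))
    for key, label, data in iter_rows(buf):
        print(key, label, len(data))
```

Base64 keeps arbitrary binary image data free of tabs and newlines, so each record stays on a single line.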
Self-Supervised Distillation Training
To run distillation, we use SWAV's 400_ep ResNet-50 model as the Teacher architecture for a Student EfficientNet-B1 model with the multi-view strategy. Place the pre-trained teacher checkpoint in the ./output directory. Remember to rename the parameters in the checkpoint, as some modules provided by SimCLR, MoCo-V2 and SWAV use parameter names that are inconsistent with regular PyTorch implementations. We provide the pre-trained SWAV/MoCo-V2/SimCLR checkpoints below for convenience, but all credit belongs to their original authors.
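Renaming checkpoint parameters amounts to rewriting the keys of the state dict. MoCo-V2 checkpoints, for instance, store the query encoder under the `module.encoder_q.` prefix (with the projection head under `module.encoder_q.fc`), alongside the momentum encoder and the queue. The helper below is a hypothetical sketch that strips such a prefix to obtain plain torchvision-style names; the names SEED's loader actually expects are an assumption here, so check the loading code in main_small-patch.py first.

```python
from typing import Dict, Iterable, TypeVar

T = TypeVar("T")


def rename_state_dict(state_dict: Dict[str, T], prefix: str, new_prefix: str = "",
                      drop: Iterable[str] = ()) -> Dict[str, T]:
    """Keep only keys starting with `prefix`, replacing it with `new_prefix`.

    Keys starting with any entry of `drop` (e.g. a projection head) are
    discarded, as are keys outside `prefix` (e.g. MoCo's momentum encoder).
    """
    drop = tuple(drop)
    renamed = {}
    for key, value in state_dict.items():
        if key.startswith(drop):  # an empty tuple never matches
            continue
        if key.startswith(prefix):
            renamed[new_prefix + key[len(prefix):]] = value
    return renamed


# In real use the dict would come from torch.load(path, map_location="cpu"),
# often under a "state_dict" entry, and be written back with torch.save.
if __name__ == "__main__":
    moco = {"module.encoder_q.conv1.weight": "w1",
            "module.encoder_q.fc.0.weight": "head",
            "module.encoder_k.conv1.weight": "w_k",
            "module.queue": "queue"}
    print(rename_state_dict(moco, "module.encoder_q.", drop=("module.encoder_q.fc",)))
```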
Teacher Arch. | SSL Method | Teacher SSL-epochs | Link |
---|---|---|---|
ResNet-50 | MoCo-V1 | 200 | URL |
ResNet-50 | SimCLR | 200 | URL |
ResNet-50 | MoCo-V2 | 200 | URL |
ResNet-50 | MoCo-V2 | 800 | URL |
ResNet-50 | SWAV | 800 | URL |
ResNet-101 | MoCo-V2 | 200 | URL |
ResNet-152 | MoCo-V2 | 200 | URL |
ResNet-152 | MoCo-V2 | 800 | URL |
ResNet-50X2 | SWAV | 400 | URL |
ResNet-50X4 | SWAV | 400 | URL |
ResNet-50X5 | SWAV | 400 | URL |
To run training on one GPU on a single node using distributed training:
```shell
python -m torch.distributed.launch --nproc_per_node=1 main_small-patch.py \
    -a efficientnet_b1 \
    -k resnet50 \
    --teacher_ssl swav \
    --distill ./output/swav_400ep_pretrain.pth.tar \
    --lr 0.03 \
    --batch-size 16 \
    --temp 0.2 \
    --workers 4 \
    --output ./output \
    --data [your TSV imagenet-folder with train folders]
```
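The `--temp` flag presumably sets the softmax temperature applied to the similarity scores (which of the two encoders it applies to is not stated here; check the argument's help text). Assuming that reading, the sketch below shows the effect of the temperature on a distribution over made-up similarity scores: smaller values sharpen it toward the most similar instance.

```python
import math
from typing import List, Sequence


def softmax_with_temperature(scores: Sequence[float], temp: float) -> List[float]:
    """Divide similarity scores by `temp`, then apply a numerically stable softmax."""
    if temp <= 0:
        raise ValueError("temperature must be positive")
    logits = [s / temp for s in scores]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    sims = [0.9, 0.5, 0.1]  # made-up cosine similarities to three queued instances
    for temp in (1.0, 0.2, 0.05):
        probs = softmax_with_temperature(sims, temp)
        print(temp, [round(p, 3) for p in probs])
```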
To run linear evaluation on the ImageNet validation split:
```shell
python -m torch.distributed.launch --nproc_per_node=1 main_lincls.py \
    -a efficientnet_b0 \
    --lr 30 \
    --batch-size 32 \
    --output ./output \
    [your TSV imagenet-folder with val folders]
```
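Linear evaluation results are reported as Top-1 and Top-5 accuracy. As a minimal illustration of that metric (the function name is hypothetical, and a real evaluation loop would operate on tensors of logits), top-k accuracy counts a sample as correct when its true class is among the k highest-scoring classes:

```python
from typing import Sequence


def topk_accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Percentage of samples whose true class is among the k highest-scoring classes."""
    hits = 0
    for row, target in zip(logits, targets):
        top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += target in top
    return 100.0 * hits / len(targets)


if __name__ == "__main__":
    scores = [[0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.2, 0.3, 0.5]]
    labels = [1, 2, 0]
    print(topk_accuracy(scores, labels, 1), topk_accuracy(scores, labels, 2))
```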
Checkpoints by SEED
Here we provide some pre-trained checkpoints after distillation with SEED. Note: the 800-epoch checkpoints are trained with the small-view strategy and have better performance.
Student-Arch. | Teacher-Arch. | Teacher SSL | Student SEED-epochs | Link |
---|---|---|---|---|
ResNet-18 | ResNet-50 | MoCo-V2 | 200 | URL |
ResNet-18 | ResNet-50W2 | SWAV | 400 | URL |
MobileNetV3-Large | ResNet-50 | MoCo-V2 | 200 | URL |
EfficientNet-B0 | ResNet-50W4 | SWAV | 400 | URL |
EfficientNet-B0 | ResNet-50W2 | SWAV | 800 | URL |
EfficientNet-B1 | ResNet-50 | SWAV | 200 | URL |
EfficientNet-B1 | ResNet-152 | SWAV | 200 | URL |
ResNet-50 | ResNet-50W4 | SWAV | 400 | URL |
Performance at a Glance
ImageNet-1k test accuracy (%) using KNN and linear classification for multiple students and deeper teacher architectures pre-trained with MoCo-V2. ✗ denotes the MoCo-V2 self-supervised learning baselines before distillation. * indicates using a deeper teacher encoder pre-trained by SWAV, where additional small patches are also used during distillation and the student is trained for 800 epochs. K denotes Top-1 accuracy using KNN. T-1 and T-5 denote Top-1 and Top-5 accuracy using linear evaluation. The first column shows the Top-1 accuracy of the teacher network. The first row shows the supervised performance of the student networks.
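The K column reports Top-1 accuracy of a k-nearest-neighbor classifier on frozen features. The caption does not specify the number of neighbors or the weighting scheme, so the sketch below assumes cosine similarity with an unweighted majority vote; all names are hypothetical.

```python
import math
from collections import Counter
from typing import List, Sequence


def normalize(v: Sequence[float]) -> List[float]:
    """Unit-normalize so that a dot product equals cosine similarity."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def knn_predict(query: Sequence[float], bank: Sequence[Sequence[float]],
                bank_labels: Sequence[int], k: int) -> int:
    """Majority vote over the k training features most similar to `query`.

    `bank` is expected to be unit-normalized already.
    """
    q = normalize(query)
    sims = [sum(a * b for a, b in zip(q, f)) for f in bank]
    nearest = sorted(range(len(bank)), key=lambda i: sims[i], reverse=True)[:k]
    votes = Counter(bank_labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


def knn_top1(queries: Sequence[Sequence[float]], query_labels: Sequence[int],
             bank: Sequence[Sequence[float]], bank_labels: Sequence[int], k: int) -> float:
    """Top-1 accuracy (%) of the kNN classifier on frozen features."""
    unit_bank = [normalize(f) for f in bank]
    correct = sum(knn_predict(q, unit_bank, bank_labels, k) == y
                  for q, y in zip(queries, query_labels))
    return 100.0 * correct / len(queries)
```

Unlike linear evaluation, this needs no training on top of the frozen encoder, which makes it a cheap check of representation quality.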
Acknowledgements
This implementation is largely based on MoCo-V2. Thanks to SWAV and SimCLR for the pre-trained SSL checkpoints.
This work was done jointly by the ASU-APG lab and the Microsoft Azure-Florence Group. Thanks to my collaborators.
License
SEED is released under the MIT license.