SiT: Self-supervised vIsion Transformer

Overview

This repository contains the official PyTorch code for self-supervised pretraining, finetuning, and evaluation of SiT (Self-supervised vIsion Transformer).

The training strategy is adapted from DeiT.

Usage

  • Create an environment

conda create -n SiT python=3.8

  • Activate the environment and install the necessary packages

conda activate SiT

conda install pytorch torchvision torchaudio cudatoolkit=11.0 -c pytorch

pip install -r requirements.txt

Self-supervised pre-training

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 72 --epochs 501 --min-lr 5e-6 --lr 1e-3 --training-mode 'SSL' --data-set 'STL10' --output 'checkpoints/SSL/STL10' --validate-every 10

Finetuning

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 120 --epochs 501 --min-lr 5e-6 --training-mode 'finetune' --data-set 'STL10' --finetune 'checkpoints/SSL/STL10/checkpoint.pth' --output 'checkpoints/finetune/STL10' --validate-every 10

Linear Evaluation

Linear projection head

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 120 --epochs 501 --lr 1e-3 --weight-decay 5e-4 --min-lr 5e-6 --training-mode 'finetune' --data-set 'STL10' --finetune 'checkpoints/SSL/STL10/checkpoint.pth' --output 'checkpoints/finetune/STL10_LE' --validate-every 10 --SiT_LinearEvaluation 1

2-layer MLP projection head

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 120 --epochs 501 --lr 1e-3 --weight-decay 5e-4 --min-lr 5e-6 --training-mode 'finetune' --data-set 'STL10' --finetune 'checkpoints/SSL/STL10/checkpoint.pth' --output 'checkpoints/finetune/STL10_LE_hidden' --validate-every 10 --SiT_LinearEvaluation 1 --representation-size 1024

Note: set the --dataset_location parameter to the location of the downloaded dataset.
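For example, the pre-training command with the dataset location passed explicitly (the path below is only a placeholder; point it at the folder where the dataset was downloaded):

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 72 --epochs 501 --min-lr 5e-6 --lr 1e-3 --training-mode 'SSL' --data-set 'STL10' --dataset_location 'data/STL10' --output 'checkpoints/SSL/STL10' --validate-every 10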

If you use this code for a paper, please cite:

@article{atito2021sit,
  title={SiT: Self-supervised vIsion Transformer},
  author={Atito, Sara and Awais, Muhammad and Kittler, Josef},
  journal={arXiv preprint arXiv:2104.03602},
  year={2021}
}

License

This repository is released under the GNU General Public License.

Comments
  • --resume issue...!

    When I resume training with "--resume", NaN values appear in the model, in both SSL and fine-tuning. I checked the input images, but there was no problem with them.

    Training then stops with the warning "Loss is nan, stopping training".

    opened by babbu3682 9
  • ZeroDivisionError: float division by zero

    Hi, thanks for the nice work. I tried to train the network on my own dataset. However, I got the issue below.

    Traceback (most recent call last):
      File "/home/Project/SiT/main.py", line 397, in <module>
        main(args)
      File "/home/Project/SiT/main.py", line 343, in main
        args.clip_grad, model_ema, mixup_fn)
      File "/home/Project/SiT/engine.py", line 125, in train_SSL
        for imgs1, rots1, imgs2, rots2 in metric_logger.log_every(data_loader, print_freq, header):
      File "/home/Project/SiT/utils.py", line 164, in log_every
        header, total_time_str, total_time / len(iterable)))
    ZeroDivisionError: float division by zero

    It seems like the iterable somehow always has length 0, yet when I checked my data loader there are thousands of samples available.
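    For reference, the failing division is the end-of-epoch summary inside log_every (the total_time / len(iterable) call in the traceback above). Below is a minimal, self-contained sketch of a guard that reports an empty loader instead of crashing; it is illustrative only, not the repository's actual code:

    def seconds_per_iteration(total_time: float, iterable) -> str:
        # Guard the summary division: an empty iterable usually means the dataset
        # path is wrong or the per-GPU split under distributed launch got no samples.
        if len(iterable) == 0:
            raise ValueError("DataLoader yielded no batches; check --dataset_location and the per-GPU split")
        return "{:.4f} s / it".format(total_time / len(iterable))

    print(seconds_per_iteration(12.3, range(100)))  # prints: 0.1230 s / it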

    opened by tianyu0207 5
  • lack finetune and linprobe code

    Hi,

    Thanks for sharing this wonderful project. When I run the following command to finetune the model, I find that some code for finetuning and linear probing is missing, which means the command below cannot run at all.

    python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 120 --epochs 501 --min-lr 5e-6 --training-mode 'finetune' --data-set 'STL10' --finetune 'checkpoints/SSL/STL10/checkpoint.pth' --output 'checkpoints/finetune/STL10' --validate-every 10
    

    Also, several functions are missing from utils.py, e.g. utils.restart_from_checkpoint, utils.fix_random_seeds(args.seed), and utils.get_sha().
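    (One of these, fix_random_seeds, is small enough to stub locally while waiting for the full file; the following is a sketch modelled on the analogous DINO helper, i.e. an assumption about what the SiT code expects rather than the authors' implementation:)

    import random
    import numpy as np
    import torch

    def fix_random_seeds(seed: int = 31) -> None:
        # Seed Python, NumPy and PyTorch (CPU and all GPUs) for reproducible runs.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)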

    Would you please share it?

    Thank you.

    Best, Vera

    opened by verazuo 4
  • How to see the image reconstruction task results

    The usage example shows how to finetune the classifier head of the model in the command line, but I'm not sure how to get the reconstructed image from this output. Can you please provide a code sample for image reconstruction? Which part of the model output can be used to visually represent inference results like in the diagram?
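    A rough sketch of the kind of visualization being asked for, assuming the SSL model's forward pass returns the reconstructed image alongside the rotation and contrastive outputs, and that inputs were normalized with ImageNet statistics; the output order and names here are guesses, not the repository's actual API:

    import torch
    from torchvision.utils import save_image

    # model: a SiT checkpoint loaded in SSL mode; imgs: a normalized batch of shape [B, 3, H, W]
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    model.eval()
    with torch.no_grad():
        _, _, recons = model(imgs)                         # hypothetical output order
    recons = recons.cpu() * imagenet_std + imagenet_mean   # undo the input normalization
    save_image(recons.clamp(0, 1), 'reconstructions.png', nrow=8)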

    opened by E-Loba 4
  • visualize results

    Hi, I fine-tuned the model on my custom dataset for object detection and now I want to visualize the images and detected bounding boxes. Any idea how to do that?

    opened by mbirkhez 4
  • Random Erase is erroneously not being used during SSL training

    Why is Random Erase not being used during SSL training? It seems that it is erroneously being turned off in main.py. See where args.reprob and args.recount are set to 0.

    https://github.com/Sara-Ahmed/SiT/blob/1aacd6adcd39b71efc903d16b4e9095b97dda76f/main.py#L182-L186

    opened by mattroos 3
  • An Error has occurred in self-supervised pre-training

    @Sara-Ahmed Thank you for sharing your wonderful achievements!

    When I ran the self-supervised pre-training as described, the following subprocess.CalledProcessError was raised. Can you please help me solve this problem?

    Typed command: python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 72 --epochs 501 --min-lr 5e-6 --lr 1e-3 --training-mode 'SSL' --data-set 'STL10' --output 'checkpoints/SSL/STL10' --validate-every 10

    Error encountered: subprocess.CalledProcessError: Command '['/usr/bin/python', '-u', 'main.py', '--batch-size', '72', '--epochs', '501', '--min-lr', '5e-6', '--lr', '1e-3', '--training-mode', 'SSL', '--data-set', 'STL10', '--output', 'checkpoints/SSL/STL10', '--validate-every', '10']' returned non-zero exit status 2.

    opened by mtakamat 2
  • Data augmentation step before applying rotation

    Hi Sara,

    in your paper, you write: "We found that the network struggles to distinguish between the rotated image and the rotation of the flipped image as two different classes. Instead, we included the horizontal flipping to the data augmentation step before applying rotation, and hence, the network is trained to classify the image and the flipped image to the same class."

    In: https://github.com/Sara-Ahmed/SiT/blob/1767b9146f77883b101ac790ff24b1220ce8c2dd/datasets/datasets_utils.py#L31-L47

    You do it the other way around, or am I mistaken? You first take the same batch, apply random rotation independently, and then apply the standard augmentation routine. So, in that case, the original image and the flipped one may have two different rotational classes, or not?

    Thanks in advance for clarification, Best, Julian

    opened by julian-carpenter 2
  • Passing token logits to the loss

    First of all, I want to thank you for making the code available. It is well written and easily understandable.

    I have a question about the rotational token and the tokens used for reconstructing the original image. I am not an expert in PyTorch, as I have always used TensorFlow, so please forgive me if I ask stupid things.

    As far as I can see, you're defining all the heads for the SSL loss here: https://github.com/Sara-Ahmed/SiT/blob/27cc31d2206168bdab7b9f4f9e10412ff3641d69/vision_transformer_SiT.py#L193-L206 And as far as I understand it, there is no final activation on these heads; they return logits, am I correct?

    In the train_SSL routine, you then pass these logits to your criterion routine: https://github.com/Sara-Ahmed/SiT/blob/27cc31d2206168bdab7b9f4f9e10412ff3641d69/engine.py#L146-L148

    Which, if the train_mode is SSL, is the MTL_loss routine. There the logits are passed directly into their respective loss functions: https://github.com/Sara-Ahmed/SiT/blob/27cc31d2206168bdab7b9f4f9e10412ff3641d69/losses.py#L32-L51

    Am I missing something? Especially in the reconstruction case, I think this cannot work, as you have normalized original images and unnormalized reconstructed image logits, and you calculate the L1 loss between them. The contrastive loss should be fine, as you normalize the logits in the loss function. However, the CE loss for the rotational token is also calculated without a prior activation, and I wonder why.

    Could you point me to the error in my thinking? Thanks

    opened by julian-carpenter 2
  • utils.py is outdated

    I am having trouble running the code, mostly because of error messages indicating that utils.py is missing several functions.

    AttributeError: module 'utils' has no attribute 'get_params_groups'
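    (A stop-gap sketch of the missing helper, modelled on the analogous get_params_groups in DINO's utils.py, which splits parameters into weight-decayed and non-decayed groups; this is an assumption about what the SiT code expects, not the authors' implementation:)

    def get_params_groups(model):
        # Biases and 1-D tensors (e.g. norm layers) are excluded from weight decay.
        regularized, not_regularized = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name.endswith('.bias') or param.ndim == 1:
                not_regularized.append(param)
            else:
                regularized.append(param)
        return [{'params': regularized},
                {'params': not_regularized, 'weight_decay': 0.0}]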

    Thanks!

    opened by hongjuny 1
  • Questions about updated code

    Hello @Sara-Ahmed,

    Thank you for posting the PyTorch implementation of SiT! It seems like the current commit is unable to run, and I see in https://github.com/Sara-Ahmed/SiT/issues/28 you mention you will update the GitHub repository soon. I have tried running the code from commit 1aacd6adcd39b71efc903d16b4e9095b97dda76f and have nearly gotten the pre-training working. However, I can't find a version of torchvision that works with both my GPU (GeForce RTX 3060, sm_86) and the SiT code from this commit. I have to update timm to version 0.4.12 to avoid this error, which seems to break compatibility with your code.

    Do you think this will be fixed in the next release of the SiT code on GitHub? Also, could you please include all package versions used in requirements.txt (or include a copy of pip freeze)?

    Thank you for all your help! Roshan

    opened by roshankern 1
  • single node multi-GPU hangs

    Hi, I am running SSL training on a single node with two GPUs. It runs only when --nproc_per_node=1. When I set nproc_per_node=2 it gets stuck after init for the second GPU.

    init_distributed_mode .... | distributed init (rank 0): env:// | distributed init (rank 1): env://

    Setting dist_url to env://127.0.0.1 didn't fix it. I also tried --world_size=2.

    opened by memphizz 4
Owner
Sara Ahmed

Related projects

This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

null 1 Dec 24, 2021
[CVPR 2021] "The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models" Tianlong Chen, Jonathan Frankle, Shiyu Chang, Sijia Liu, Yang Zhang, Michael Carbin, Zhangyang Wang

The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models Codes for this paper The Lottery Tickets Hypo

VITA 59 Dec 28, 2022
The Self-Supervised Learner can be used to train a classifier with fewer labeled examples needed using self-supervised learning.

Published by SpaceML • About SpaceML • Quick Colab Example Self-Supervised Learner The Self-Supervised Learner can be used to train a classifier with

SpaceML 92 Nov 30, 2022
Multi-Scale Vision Longformer: A New Vision Transformer for High-Resolution Image Encoding

Vision Longformer This project provides the source code for the vision longformer paper. Multi-Scale Vision Longformer: A New Vision Transformer for H

Microsoft 209 Dec 30, 2022
Repository providing a wide range of self-supervised pretrained models for computer vision tasks.

Hierarchical Pretraining: Research Repository This is a research repository for reproducing the results from the project "Self-supervised pretraining

Colorado Reed 53 Nov 9, 2022
PyTorch code for Vision Transformers training with the Self-Supervised learning method DINO

Self-Supervised Vision Transformers with DINO PyTorch implementation and pretrained models for DINO. For details, see Emerging Properties in Self-Supe

Facebook Research 4.2k Jan 3, 2023
[CVPR 21] Vectorization and Rasterization: Self-Supervised Learning for Sketch and Handwriting, IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), 2021.

Vectorization and Rasterization: Self-Supervised Learning for Sketch and Handwriting, CVPR 2021. Ayan Kumar Bhunia, Pinaki nath Chowdhury, Yongxin Yan

Ayan Kumar Bhunia 44 Dec 12, 2022
EsViT: Efficient self-supervised Vision Transformers

Efficient Self-Supervised Vision Transformers (EsViT) PyTorch implementation for EsViT, built with two techniques: A multi-stage Transformer architect

Microsoft 352 Dec 25, 2022
PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

MAE for Self-supervised ViT Introduction This is an unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-sup

null 36 Oct 30, 2022
The official implementation of ELSA: Enhanced Local Self-Attention for Vision Transformer

ELSA: Enhanced Local Self-Attention for Vision Transformer By Jingkai Zhou, Pich

DamoCV 87 Dec 19, 2022
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Swin Transformer 1.4k Dec 30, 2022
The implementation of "Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer"

Shuffle Transformer The implementation of "Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer" Introduction Very recently, window-

null 87 Nov 29, 2022
Unofficial implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (https://arxiv.org/abs/2103.14030)

Swin-Transformer-Tensorflow A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Sh

null 52 Dec 29, 2022
CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

CSWin-Transformer This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". Th

Microsoft 409 Jan 6, 2023
Patch Rotation: A Self-Supervised Auxiliary Task for Robustness and Accuracy of Supervised Models

Patch-Rotation(PatchRot) Patch Rotation: A Self-Supervised Auxiliary Task for Robustness and Accuracy of Supervised Models Submitted to Neurips2021 To

null 4 Jul 12, 2021
Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR

UniSpeech The family of UniSpeech: UniSpeech (ICML 2021): Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR UniSpeech-

Microsoft 282 Jan 9, 2023
This is a Pytorch implementation of the paper: Self-Supervised Graph Transformer on Large-Scale Molecular Data.

This is a Pytorch implementation of the paper: Self-Supervised Graph Transformer on Large-Scale Molecular Data.

null 212 Dec 25, 2022
Self-Supervised Pre-Training for Transformer-Based Person Re-Identification

Self-Supervised Pre-Training for Transformer-Based Person Re-Identification [pdf] The official repository for Self-Supervised Pre-Training for Transfo

Hao Luo 45 Dec 3, 2021