PyTorch reimplementation of the Vision Transformer (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)

Overview

Vision Transformer

PyTorch reimplementation of Google's repository for the ViT model released with the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby.

This paper shows that a Transformer applied directly to sequences of image patches and pre-trained on large datasets works very well on image recognition tasks.

fig1

Vision Transformer achieves state-of-the-art results on image recognition tasks using a standard Transformer encoder and fixed-size patches. To perform classification, the authors use the standard approach of adding an extra learnable "classification token" to the sequence.

fig2
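
For intuition, the patch embedding and classification token can be sketched in a few lines of PyTorch. This is an illustration only, not the repository's modeling code, and all names here are made up:

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into fixed-size patches, embed them, and prepend a
    learnable [class] token plus position embeddings (illustrative sketch)."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, hidden_size=768):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # A strided convolution is equivalent to flattening each patch and
        # projecting it with a shared linear layer.
        self.proj = nn.Conv2d(in_chans, hidden_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))

    def forward(self, x):                      # x: (B, 3, 224, 224)
        x = self.proj(x)                       # (B, hidden_size, 14, 14)
        x = x.flatten(2).transpose(1, 2)       # (B, 196, hidden_size)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)         # (B, 197, hidden_size)
        return x + self.pos_embed              # sequence fed to the Transformer encoder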

Usage

1. Download Pre-trained model (Google's Official Checkpoint)

  • Available models: ViT-B_16 (85.8M), R50+ViT-B_16 (97.96M), ViT-B_32 (87.5M), ViT-L_16 (303.4M), ViT-L_32 (305.5M), ViT-H_14 (630.8M)
    • imagenet21k pre-train models
      • ViT-B_16, ViT-B_32, ViT-L_16, ViT-L_32, ViT-H_14
    • imagenet21k pre-train + imagenet2012 fine-tuned models
      • ViT-B_16-224, ViT-B_16, ViT-B_32, ViT-L_16-224, ViT-L_16, ViT-L_32
    • Hybrid model (ResNet50 + Transformer)
      • R50-ViT-B_16
# imagenet21k pre-train
wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz

# imagenet21k pre-train + imagenet2012 fine-tuning
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{MODEL_NAME}.npz
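
The checkpoints are plain NumPy .npz archives, so you can sanity-check a download by listing its contents. A small sketch, assuming the file was saved under checkpoint/:

import numpy as np

# np.load on a .npz file returns a dict-like archive of named arrays.
weights = np.load("checkpoint/ViT-B_16.npz")

# Print a few parameter names and shapes (patch embedding, encoder blocks, ...).
for name in sorted(weights.files)[:10]:
    print(name, weights[name].shape)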

2. Train Model

python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz

CIFAR-10 and CIFAR-100 are downloaded automatically and used for training. To use a different dataset, you need to customize data_utils.py.
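
In practice a custom dataset just needs to be exposed as a pair of PyTorch DataLoaders. The sketch below only illustrates that shape for an ImageFolder-style directory; the exact interface expected by data_utils.py may differ, so adapt the names to the repository's code:

import torch
from torchvision import datasets, transforms

def get_custom_loader(data_dir, img_size=224, batch_size=512):
    """Hypothetical helper: build train/test loaders from an ImageFolder dataset."""
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.05, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    trainset = datasets.ImageFolder(f"{data_dir}/train", transform=transform_train)
    testset = datasets.ImageFolder(f"{data_dir}/val", transform=transform_test)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                               shuffle=True, num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                              shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, test_loader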

The default batch size is 512. If GPU memory is insufficient, you can still train by increasing --gradient_accumulation_steps, which splits each effective batch into several smaller forward/backward passes.
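
Conceptually, gradient accumulation runs several small forward/backward passes and then takes a single optimizer step, so the effective batch size stays the same while peak memory drops. A simplified sketch of the idea (not the exact loop in train.py):

import torch.nn as nn

def train_with_accumulation(model, train_loader, optimizer, accum_steps=8, device="cuda"):
    """Accumulate gradients over accum_steps mini-batches before each optimizer step."""
    criterion = nn.CrossEntropyLoss()
    model.train()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        loss = criterion(model(x), y)
        (loss / accum_steps).backward()   # scale so accumulated gradients match a full batch
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()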

You can also use Automatic Mixed Precision (AMP) to reduce memory usage and speed up training:

python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz --fp16 --fp16_opt_level O2
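
The --fp16 path uses NVIDIA Apex (train.py imports apex.amp). If Apex is not installed, a comparable setup with PyTorch's built-in torch.cuda.amp looks roughly like this sketch of a single training step:

import torch
from torch.cuda.amp import GradScaler, autocast

def train_step_amp(model, batch, optimizer, scaler, device="cuda"):
    """One mixed-precision training step (alternative to the Apex-based --fp16 flag)."""
    x, y = (t.to(device) for t in batch)
    optimizer.zero_grad()
    with autocast():                      # run the forward pass in reduced precision where safe
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()         # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()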

Results

To verify that the converted model weights are correct, we simply compare our results with the authors' reported numbers. We trained using mixed precision, with --fp16_opt_level set to O2.

imagenet-21k

| model        | dataset   | resolution | acc (official) | acc (this repo) | time    |
|--------------|-----------|------------|----------------|-----------------|---------|
| ViT-B_16     | CIFAR-10  | 224x224    | -              | 0.9908          | 3h 13m  |
| ViT-B_16     | CIFAR-10  | 384x384    | 0.9903         | 0.9906          | 12h 25m |
| ViT-B_16     | CIFAR-100 | 224x224    | -              | 0.923           | 3h 9m   |
| ViT-B_16     | CIFAR-100 | 384x384    | 0.9264         | 0.9228          | 12h 31m |
| R50-ViT-B_16 | CIFAR-10  | 224x224    | -              | 0.9892          | 4h 23m  |
| R50-ViT-B_16 | CIFAR-10  | 384x384    | 0.99           | 0.9904          | 15h 40m |
| R50-ViT-B_16 | CIFAR-100 | 224x224    | -              | 0.9231          | 4h 18m  |
| R50-ViT-B_16 | CIFAR-100 | 384x384    | 0.9231         | 0.9197          | 15h 53m |
| ViT-L_32     | CIFAR-10  | 224x224    | -              | 0.9903          | 2h 11m  |
| ViT-L_32     | CIFAR-100 | 224x224    | -              | 0.9276          | 2h 9m   |
| ViT-H_14     | CIFAR-100 | 224x224    | -              | WIP             |         |

imagenet-21k + imagenet2012

| model        | dataset   | resolution | acc    |
|--------------|-----------|------------|--------|
| ViT-B_16-224 | CIFAR-10  | 224x224    | 0.99   |
| ViT-B_16-224 | CIFAR-100 | 224x224    | 0.9245 |
| ViT-L_32     | CIFAR-10  | 224x224    | 0.9903 |
| ViT-L_32     | CIFAR-100 | 224x224    | 0.9285 |

Shorter training

  • In the experiments below, we used a resolution of 224x224.
  • tensorboard
| upstream    | model    | dataset   | total_steps/warmup_steps | acc (official) | acc (this repo) |
|-------------|----------|-----------|--------------------------|----------------|-----------------|
| imagenet21k | ViT-B_16 | CIFAR-10  | 500/100                  | 0.9859         | 0.9859          |
| imagenet21k | ViT-B_16 | CIFAR-10  | 1000/100                 | 0.9886         | 0.9878          |
| imagenet21k | ViT-B_16 | CIFAR-100 | 500/100                  | 0.8917         | 0.9072          |
| imagenet21k | ViT-B_16 | CIFAR-100 | 1000/100                 | 0.9115         | 0.9216          |

Visualization

ViT consists of a standard Transformer encoder, and each encoder block consists of a self-attention module and an MLP module. The attention map for an input image can be visualized from the attention scores of the self-attention layers.

Visualization code can be found at visualize_attention_map.
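
A common recipe is attention rollout: average each block's attention over heads, add the identity to account for residual connections, renormalize, and multiply across layers; the classification token's row then gives a per-patch saliency map. A sketch assuming the model exposes one (batch, heads, tokens, tokens) attention tensor per encoder block:

import torch

def attention_rollout(attn_per_layer):
    """attn_per_layer: list of (B, num_heads, N, N) attention tensors, one per block.
    Returns a (B, N-1) map of how strongly the [class] token attends to each patch."""
    rollout = None
    for attn in attn_per_layer:
        a = attn.mean(dim=1)                               # average over heads -> (B, N, N)
        a = a + torch.eye(a.size(-1), device=a.device)     # add identity for residual connections
        a = a / a.sum(dim=-1, keepdim=True)                # renormalize rows
        rollout = a if rollout is None else torch.bmm(a, rollout)
    return rollout[:, 0, 1:]                               # [class]-token row, patch columns only

# The per-patch scores can be reshaped to the patch grid (14x14 for ViT-B_16 at
# 224x224) and upsampled to overlay on the input image.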

fig3


Citations

@article{dosovitskiy2020,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and  Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
Comments
  • Model Architecture For Fine-tuning

    In the original paper, the authors state that "we remove the whole head (two linear layers) and replace it by a single, zero-initialized linear layer outputting the number of classes required by the target dataset. We found this to be a little more robust than simply re-initializing the very last layer."

    May I know which code snippet is related to this?
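
    For context, the head replacement described in the quote amounts to a zero-initialized linear layer sized for the target dataset; a minimal illustrative sketch (independent of the actual code in models/modeling.py):

    import torch.nn as nn

    # Replace the pre-training head with a single zero-initialized linear layer
    # for the target dataset (e.g. 10 classes for CIFAR-10); 768 is the hidden
    # size of ViT-B_16.
    head = nn.Linear(768, 10)
    nn.init.zeros_(head.weight)
    nn.init.zeros_(head.bias)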

    opened by chaoyanghe 11
  • Low training speed on RTX 3090

    Training on the 3090 gets slower and slower as time goes on, but the 2080 Ti doesn't have this problem.

    torch 1.8.0.dev20201130+cu110, torchvision 0.9.0.dev20201130+cu110, NVIDIA-SMI 455.23.04, Driver Version 455.23.04, CUDA Version 11.1

    opened by lingorX 4
  • The Encoder implementation is different from the original "Attention Is All You Need" paper?

    Hi, I checked your code at https://github.com/jeonsworld/ViT-pytorch/blob/878ebc5bd12255d2fffd6c0257b83ee075607a79/models/modeling.py#L154.

    Your implementation is Attention(LayerNorm(x)) + x, but the original Transformer uses LayerNorm(x + Attention(x)). Is this an error, or was it deliberately implemented this way?
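
    For reference, a minimal sketch contrasting the two orderings (illustrative names only; the pre-norm form is the one described in the ViT paper):

    def post_norm_block(x, attn, norm):
        # Original "Attention Is All You Need" ordering: LayerNorm(x + Attention(x))
        return norm(x + attn(x))

    def pre_norm_block(x, attn, norm):
        # Pre-norm ordering used by ViT: x + Attention(LayerNorm(x))
        return x + attn(norm(x))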

    opened by chaoyanghe 4
  • How do you save the TensorBoard logs?

    @jeonsworld OK, this is a completely different question, but I have to ask because I haven't seen it answered anywhere else: how did you save the TensorBoard logs so that we can just click on them and view them? I want to do the same. Should I save them in some format or do anything special? Please point me to any link/material that can help with that. Thanks :)

    opened by seyeeet 3
  • KeyError: 'Transformer/encoderblock_0\\MultiHeadDotProductAttention_1/query\\kernel is not a file in the archive'

    When I run the code, the error occurs at this location:

     models\modeling.py", line 195, in load_from
    query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
    File "d:\Anaconda3\lib\site-packages\numpy\lib\npyio.py", line 259, in __getitem__
    raise KeyError("%s is not a file in the archive" % key)
    KeyError: 'Transformer/encoderblock_0\\MultiHeadDotProductAttention_1/query\\kernel is not a file in the archive'
    

    I would like to ask where I should put ViT-H_14.npz. I created a checkpoint folder and just put ViT-H_14.npz in there, but I got this error. The INFO output: 01/12/2021 19:51:55 - INFO - models.modeling - load_pretrained: resized variant: torch.Size([1, 257, 1280]) to torch.Size([1, 730, 1280]). My input: img_size 384x384, batch size 64 (train batch = eval batch). Is there anything I haven't modified?

    opened by tianle-BigRice 3
  • HTTP Error 403: Forbidden

    Hi

    I tried your notebook, but I think the link is dead; I got a Forbidden error.

    HTTPError                                 Traceback (most recent call last)
    <ipython-input-4-14f159ea9fa3> in <module>
          1 # Test Image
          2 img_url = "https://images.mypetlife.co.kr/content/uploads/2019/04/09192811/welsh-corgi-1581119_960_720.jpg"
    ----> 3 urlretrieve(img_url, "attention_data/img.jpg")
          4 
          5 # Prepare Model
    ...
    HTTP Error 403: Forbidden
    
    opened by vietvo89 2
  • Request for pre-trained weights only on Imagenet2012

    Thanks for your hard work! I wonder if there are pre-trained weights that use only ImageNet-2012? I found that the pre-trained ResNet provided by torchvision may be pre-trained only on ImageNet-2012, so I want to use ViT and ResNet for a fair comparison.

    opened by JingyeChen 2
  • ImportError: cannot import name 'UnencryptedCookieSessionFactoryConfig' from 'pyramid.session' (unknown location)

    Hi, by executing this

    python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz
    

    I encounter the error:

    Traceback (most recent call last):
      File "train.py", line 17, in <module>
        from apex import amp
      File "/home/tiger/.local/lib/python3.7/site-packages/apaex/__init__.py", line 13, in <module>
        from pyramid.session import UnencryptedCookieSessionFactoryConfig
    ImportError: cannot import name 'UnencryptedCookieSessionFactoryConfig' from 'pyramid.session' (unknown location)
    
    opened by yhangchen 2
  • Loss can't drop

    Thank you so much for sharing your code. I tried to employ ViT as the encoder, followed by a common decoder, to build a segmentation network. I train it from scratch but found that the loss doesn't drop from the beginning of training, and the results stay near 0. Is there any trick to training ViT correctly? Is it very important to load the pre-trained model and fine-tune? Here is my configuration: patch_size=16, hidden_size=16*16*3, mlp_dim=3072, dropout_rate=0.1, num_heads=12, num_layers=12, lr=3e-4, opt=Adam, weight_decay=0.0

    opened by QiushiYang 2
  • About imagenet-21k

    Thanks for your great repo!

    I cannot find a link to download the ImageNet-21k dataset, so is there any way to download ImageNet-21k now? Thanks a lot~

    opened by zhangzjn 2
  • Training accuracy much lower than validation accuracy

    Thanks for creating and uploading this easily usable repo!

    In addition to the validation accuracy on the entire validation set that is printed out by default, we printed out the training accuracies of the model and we observe that the training accuracy is 6-8% lower than the validation accuracy. Is that reasonable/accurate since we usually expect the training accuracy to be higher than the validation accuracy?

    This was for a ViT-B_16 model, pretrained on ImageNet-21k, during the fine-tuning phase on CIFAR-10. To get the training accuracies, we used model(x)[0] to get the logits, loss and predictions for each batch and used the AverageMeter() to calculate the running accuracies. Additionally, to get the accurate training accuracy over the entire training set, we passed the training set to a copy of valid() (with only changes to print statements). Both the running training accuracy and the training accuracy over the entire training set were lower than the validation accuracy by 6-8%. For instance, after 10k steps, training accuracy was 92.9% (over the entire train set) and validation accuracy was 98.7%. We used most of the default hyperparameters (besides batch size and fp16) and did not make other changes to the code.

    Please let us know if this lower training accuracy is expected or if its calculation is incorrect. Thanks in advance.

    opened by ganeshkumar5699 1
  • Loss doesn't drop in the example

    Hi, thanks for releasing this code.

    I have tried to run the CIFAR-10 (as well as CIFAR-100) example, but in both cases the validation (and training) loss does not decrease, and the validation accuracy gets stuck at 0.01. Is there any hyper-parameter that I need to change from the example code?

    Thanks!

    opened by josedolz 0
  • Why does the model give the same logits for both classes?

    Hi, I am using pre-trained ViT-H_14 to perform binary classification of biomedical images. The dataset I have available is very small: I use about 300 images for fine-tuning and about 30 images for validation. The goal is to classify the images by the aggressiveness of the tumor they show (low grade (0) vs. high grade (1)). However, I noticed that during prediction every image is assigned label 0, and looking at the logits, I found that identical logit pairs are always produced (e.g. [[ 6.877057e-10 -6.877057e-10]]), which translate into probability pairs of about (0.49, 0.51).

    Searching various forums I found many different tips: vary the learning rate (which I decreased to 1e-8), decrease the batch size (from 8 to 2), etc. Unfortunately none of this works. The last thing I want to try is to increase the number of epochs considerably (at the moment I have trained for only 100 epochs), but before doing so I wanted to see if someone had a more specific suggestion, or could tell me whether this architecture is simply too big for such a small dataset.

    Thanks a lot in advance

    opened by Evap6 0