(Preprint) Official PyTorch implementation of "How Do Vision Transformers Work?"

Overview

How Do Vision Transformers Work?

This repository provides a PyTorch implementation of "How Do Vision Transformers Work?" In the paper, we show that multi-head self-attentions (MSAs) for computer vision are NOT for capturing long-range dependencies. In particular, we address the following three key questions about MSAs and Vision Transformers (ViTs):

  1. What properties of MSAs do we need to better optimize NNs? Do the long-range dependencies of MSAs help NNs learn?
  2. Do MSAs act like Convs? If not, how are they different?
  3. How can we harmonize MSAs with Convs? Can we just leverage their advantages?

We demonstrate that (1) MSAs flatten the loss landscapes, (2) MSAs and convolutions (Convs) are complementary because MSAs are low-pass filters and Convs are high-pass filters, and (3) MSAs at the end of a stage significantly improve accuracy.

Let's find the detailed answers below!

I. What Properties of MSAs Do We Need to Improve Optimization?

MSAs improve not only accuracy but also generalization by flattening the loss landscapes. Such improvement is primarily attributable to their data specificity, NOT their long-range dependency 😱 On the other hand, their weak inductive bias disrupts NN training. In particular, ViTs suffer from non-convex losses: MSAs allow negative Hessian eigenvalues in small data regimes. Large datasets and loss landscape smoothing methods alleviate this problem.
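
A minimal sketch of estimating the dominant Hessian eigenvalue with power iteration on Hessian-vector products is shown below; the toy model and helper name are illustrative, not part of this repository:

    import torch
    import torch.nn as nn

    def max_hessian_eigenvalue(model, loss_fn, x, y, iters=20):
        # Power iteration on Hessian-vector products computed with autograd.
        params = [p for p in model.parameters() if p.requires_grad]
        loss = loss_fn(model(x), y)
        grads = torch.autograd.grad(loss, params, create_graph=True)
        v = [torch.randn_like(p) for p in params]            # random initial direction
        for _ in range(iters):
            norm = torch.sqrt(sum((u * u).sum() for u in v))
            v = [u / norm for u in v]                         # normalize the direction
            hv = torch.autograd.grad(grads, params, grad_outputs=v, retain_graph=True)
            eig = sum((h * u).sum() for h, u in zip(hv, v))   # Rayleigh quotient v^T H v
            v = [h.detach() for h in hv]
        return eig.item()

    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
    x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
    print(max_hessian_eigenvalue(model, nn.CrossEntropyLoss(), x, y))

A negative estimate indicates a non-convex region of the loss, which is what the paper reports for ViTs in small data regimes.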

II. Do MSAs Act Like Convs?

MSAs and Convs exhibit opposite behaviors. For example, MSAs are low-pass filters, whereas Convs are high-pass filters. In addition, Convs are vulnerable to high-frequency noise, but MSAs are not. Therefore, MSAs and Convs are complementary.
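
This low-pass versus high-pass behavior can be probed with the Fourier spectrum of feature maps. A minimal sketch of such a frequency analysis (illustrative, not the notebook's exact code):

    import torch

    def log_amplitude_spectrum(feature_map):
        # feature_map: (B, C, H, W). Average the log amplitude of the centered 2D FFT
        # over batch and channels; the center is low frequency, the corners are high frequency.
        fft = torch.fft.fft2(feature_map, norm="ortho")
        fft = torch.fft.fftshift(fft, dim=(-2, -1))
        return torch.log(fft.abs() + 1e-6).mean(dim=(0, 1))

    spectrum = log_amplitude_spectrum(torch.randn(4, 16, 32, 32))
    h, w = spectrum.shape
    print("low frequency (center):", spectrum[h // 2, w // 2].item())
    print("high frequency (corner):", spectrum[0, 0].item())

Comparing such spectra before and after an MSA or a Conv block shows which frequency components each block attenuates.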

III. How Can We Harmonize MSAs With Convs?

Multi-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage play a key role in prediction. Based on these insights, we propose design rules to harmonize MSAs with Convs. An NN stage following this design pattern consists of a number of CNN blocks and one (or a few) MSA block. The design pattern naturally derives the structure of the canonical Transformer, which has one MLP block for every MSA block.


In addition, we also introduce AlterNet, a model in which Conv blocks at the end of a stage are replaced with MSA blocks. Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes. This contrasts with canonical ViTs, which perform poorly on small amounts of data.
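
Schematically, a stage under this design rule is a few Conv blocks followed by one MSA block at the end. The ConvBlock and MSABlock below are simplified placeholders, not the modules of this repository:

    import torch
    import torch.nn as nn

    class ConvBlock(nn.Module):
        # Placeholder residual 3x3 Conv block (illustrative only).
        def __init__(self, dim):
            super().__init__()
            self.body = nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1),
                                      nn.BatchNorm2d(dim), nn.ReLU())

        def forward(self, x):
            return x + self.body(x)

    class MSABlock(nn.Module):
        # Placeholder self-attention block over spatial tokens (illustrative only).
        def __init__(self, dim, heads=4):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

        def forward(self, x):
            b, c, h, w = x.shape
            tokens = self.norm(x.flatten(2).transpose(1, 2))   # (B, HW, C)
            out, _ = self.attn(tokens, tokens, tokens)
            return x + out.transpose(1, 2).reshape(b, c, h, w)

    def make_stage(dim, depth):
        # Design rule: several Conv blocks, then one MSA block at the end of the stage.
        return nn.Sequential(*[ConvBlock(dim) for _ in range(depth - 1)], MSABlock(dim))

    print(make_stage(dim=64, depth=4)(torch.randn(2, 64, 8, 8)).shape)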

This repository is based on the official implementation of "Blurs Make Results Clearer: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness". In that paper, we show that a simple (non-trainable) 2 ✕ 2 box blur filter improves accuracy, uncertainty, and robustness simultaneously by ensembling spatially nearby feature maps of CNNs. MSA is not simply a generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!
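
For intuition, the non-trainable 2 ✕ 2 box blur of that paper can be sketched as follows (a minimal illustration, not the official spatial smoothing module):

    import torch
    import torch.nn as nn

    class BoxBlur2x2(nn.Module):
        # Non-trainable 2x2 box blur: each output is the average of a 2x2 neighborhood,
        # i.e. an ensemble of spatially nearby feature map values.
        def __init__(self):
            super().__init__()
            self.pad = nn.ReplicationPad2d((0, 1, 0, 1))   # keep the spatial size unchanged
            self.pool = nn.AvgPool2d(kernel_size=2, stride=1)

        def forward(self, x):
            return self.pool(self.pad(x))

    x = torch.randn(1, 8, 16, 16)
    print(BoxBlur2x2()(x).shape)  # torch.Size([1, 8, 16, 16])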

Getting Started

The following packages are required:

  • pytorch
  • matplotlib
  • notebook
  • ipywidgets
  • timm
  • einops
  • tensorboard
  • seaborn (optional)

We mainly use the docker image pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime for the code.

See classification.ipynb for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet.

Metrics. We provide several metrics for measuring accuracy and uncertainty: Accuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↑), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.
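
For reference, ECE can be computed with a simple binning scheme like the sketch below (a common convention, not necessarily identical to the notebook's implementation):

    import torch

    def expected_calibration_error(probs, labels, n_bins=15):
        # probs: (N, K) softmax outputs, labels: (N,). ECE is the weighted average gap
        # between confidence and accuracy over equally spaced confidence bins.
        confidence, prediction = probs.max(dim=1)
        correct = prediction.eq(labels).float()
        bins = torch.linspace(0, 1, n_bins + 1)
        ece = torch.zeros(())
        for lo, hi in zip(bins[:-1], bins[1:]):
            in_bin = (confidence > lo) & (confidence <= hi)
            if in_bin.any():
                gap = (correct[in_bin].mean() - confidence[in_bin].mean()).abs()
                ece += in_bin.float().mean() * gap
        return ece.item()

    probs = torch.softmax(torch.randn(100, 10), dim=1)
    labels = torch.randint(0, 10, (100,))
    print(expected_calibration_error(probs, labels))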

Models. We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default.

Visualizing the Loss Landscapes

Refer to losslandscape.ipynb for exploring the loss landscapes. It requires a trained model. Run all cells to get the predictive performance of the model on a weight-space grid. We provide a sample loss landscape result.
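
Conceptually, the notebook evaluates the loss on a two-dimensional grid in weight space. The stripped-down sketch below uses plain random directions (without the filter normalization typically used for such plots) and illustrative names:

    import copy
    import torch

    @torch.no_grad()
    def loss_grid(model, loss_fn, x, y, span=1.0, steps=5):
        # Evaluate the loss on a (steps x steps) grid around the trained weights,
        # moving along two random directions d1, d2 in parameter space.
        base = copy.deepcopy(model.state_dict())
        d1 = {k: torch.randn_like(v) for k, v in base.items() if v.is_floating_point()}
        d2 = {k: torch.randn_like(v) for k, v in base.items() if v.is_floating_point()}
        alphas = torch.linspace(-span, span, steps)
        grid = torch.zeros(steps, steps)
        for i, a in enumerate(alphas):
            for j, b in enumerate(alphas):
                perturbed = {k: v + a * d1[k] + b * d2[k] if v.is_floating_point() else v
                             for k, v in base.items()}
                model.load_state_dict(perturbed)
                grid[i, j] = loss_fn(model(x), y)
        model.load_state_dict(base)  # restore the original weights
        return grid

    model = torch.nn.Linear(10, 2)
    print(loss_grid(model, torch.nn.functional.cross_entropy,
                    torch.randn(16, 10), torch.randint(0, 2, (16,))))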

Evaluating Robustness on Corrupted Datasets

Refer to robustness.ipynb for evaluating corruption robustness on corrupted datasets such as CIFAR-10-C and CIFAR-100-C. It requires a trained model. Run all cells to get the predictive performance of the model on datasets consisting of data corrupted by 15 different corruption types at 5 levels of intensity each. We provide a sample robustness result.
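
At a high level, the evaluation averages top-1 accuracy over corruption types and severity levels. A schematic sketch assuming a hypothetical load_corrupted(name, severity) data loader (not an API of this repository):

    import torch

    CORRUPTIONS = ["gaussian_noise", "shot_noise", "impulse_noise"]  # CIFAR-10-C provides 15 types

    @torch.no_grad()
    def corruption_accuracy(model, load_corrupted, device="cpu"):
        # Mean top-1 accuracy over all corruption types and the 5 severity levels.
        accs = []
        for name in CORRUPTIONS:
            for severity in range(1, 6):
                correct, total = 0, 0
                for x, y in load_corrupted(name, severity):
                    pred = model(x.to(device)).argmax(dim=1).cpu()
                    correct += (pred == y).sum().item()
                    total += y.numel()
                accs.append(correct / total)
        return sum(accs) / len(accs)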

How to Apply MSA to Your Own Model

We find that MSA complements Conv (rather than replacing it), and that an MSA closer to the end of a stage improves predictive performance significantly. Based on these insights, we propose the following build-up rules:

  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA block.
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.

In the animation above, we replace Convs of ResNet with MSAs one by one according to the build-up rules. Note that several MSAs in c3 harm the accuracy, but the MSA at the end of c2 improves it. As a result, surprisingly, the model with MSAs following the appropriate build-up rule outperforms CNNs even in the small data regime, e.g., CIFAR!

Caution: Investigate Loss Landscapes and Hessians With l2 Regularization on Augmented Datasets

Two common mistakes ⚠️ are investigating loss landscapes and Hessians (1) without considering l2 regularization and (2) on clean datasets. However, note that NNs are optimized with l2 regularization on augmented datasets. Therefore, it is appropriate to visualize 'NLL + l2' on 'augmented datasets'. Measuring these criteria without l2 on clean datasets can give incorrect (even opposite) results.
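
Concretely, this means adding the l2 term that the model was trained with to the loss being visualized, and evaluating it on augmented inputs. A minimal sketch (illustrative; the 1/2 factor and the weight_decay value should match your training setup):

    import torch
    import torch.nn.functional as F

    def nll_plus_l2(model, x_augmented, y, weight_decay=5e-4):
        # Visualize NLL plus the l2 penalty the model was actually optimized with,
        # evaluated on augmented (not clean) inputs.
        nll = F.cross_entropy(model(x_augmented), y)
        l2 = sum(p.pow(2).sum() for p in model.parameters())
        return nll + 0.5 * weight_decay * l2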

Citation

If you find this useful, please consider citing 📑 the paper and starring 🌟 this repository. Please do not hesitate to contact Namuk Park (email: namuk.park at gmail dot com, twitter: xxxnell) with any comments or feedback.

BibTex is TBD.

License

All code is available to you under Apache License 2.0. CNN models build off the torchvision models, which are BSD licensed. ViTs build off PyTorch Image Models and Vision Transformer - Pytorch, which are Apache 2.0 and MIT licensed, respectively.

Copyright the maintainers.

Comments
  • What exactly makes MSAs data specificity?

    In the paper, the authors state that "A key feature of MSAs is data specificity (not long-range dependency)".

    Can you explain the "data specificity" part? What is it, and how does it behave?

    Furthermore, can you elaborate on how MSAs achieve data specificity (through visualization, formulas, etc.)?

    opened by iumyx2612 7
  • Conclusion about long-range dependency seems not true

    The authors state in the paper that "Contrary to popular belief, the long-range dependency hinders NN optimization." However, recent models that adopt long-range dependency achieve really great results, such as VAN, ConvNeXt, or RepLKNet.

    Therefore, the statement I mentioned above seems a little bit wrong? I know there is an issue that discusses large-kernel Convs; however, that issue did not mention the statement above.

    Moreover, in the experiments in Fig. 7, you use Convolutional SANs. This model has 2 variants: 1D-CSANs and 2D-CSANs. The one you are doing experiments on is 2D-CSANs, right? It considers not only the interaction among tokens but also the interaction among different heads. The "long-range dependency" is still very beneficial in 1D-CSANs (figure below), which is what I typically consider the true "long-range dependency" in self-attention.

    When using 2D-CSANs, it considers both aspects, interaction among heads and among tokens, which brings negative performance when scaling up window sizes. The results align with the Convolutional SANs paper.

    However, I don't take 2D-CSANs' negative performance when scaling up window sizes to mean that "long-range dependency hinders NN optimization", since the model considers both aspects. Sorry for writing this long; if you don't understand any part of my question, I can clarify it for you.

    opened by iumyx2612 6
  • model size

    Hello, I have a question about why you use ViT-S and ViT-Tiny while the counterpart is ResNet-50; these sizes are not equal. I know you have explained this on OpenReview. I want to know whether ViT-Base's eigenvalue spectrum looks like ViT-Tiny's in your paper, just stretched to the right.

    opened by forever10086 5
  • Plot for Relative log amplitudes of Fourier transformed feature maps

    Hi, thank you for the great paper. Could you please release the code or give an implementation example of plotting the "relative log amplitudes of Fourier transformed feature maps"? Thanks!

    opened by xingchenzhao 5
  • Code for Alter-ResNet-50

    Hi, awesome work and really good points about MSAs! I'm very much interested in the AlterNet mentioned in the paper (based on ResNet-50 and SwinTBlock), but I can't find its implementation in this repo. Did I miss it? If not, can you release the code, maybe?

    Thanks a lot!

    opened by DarrenIm 5
  • Findings not compatible with other work?

    In Figure 1 of the paper, the authors state that MSA flattens the loss landscape. However, "When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations" states that ViT converges at sharp local minima, which contrasts with your findings?

    Furthermore, the authors claim that "The magnitude of the Hessian eigenvalues of ViT is smaller than that of ResNet during the training phase" (still Fig. 1). However, in the above paper, the "Hessian dominant eigenvalue" of ViT is "orders of magnitude larger than that of ResNet" (Table 1).

    Loss landscape and Hessian max eigenvalue of your work: [image]

    Loss landscape and Hessian max eigenvalue of other work: [image]

    opened by iumyx2612 4
  • how to compute feature map variances?

    hello,

    Thank you for your great work!

    I wonder how you get the feature map variances. According to my understanding, you first need to extract representations of all the samples, which gives a vector of length D per sample (let's just flatten the 2D tensor or concatenate all tokens). Then you calculate the variance of each element of this vector over all the samples, which gives D variances. Finally, you take the mean of all D variances and get the variance to report.

    Did I get you correctly? Sorry if I didn't catch up with your existing documentation or description.

    Thank you and I'm looking forward to your reply.

    Best,

    opened by LostXine 4
  • how is robustness calculated?

    Hi,

    thank you for this wonderful work on vision transformers and how to understand them. I have some simple questions, which I must apologize for. I tried to reproduce Figure 12 independently of your code base, but I struggle a bit to understand the code. Is it correct that you define robustness as robustness = mean(accuracy(y_val_true, y_val_pred))? Related to this, do I understand correctly that you compute this accuracy on batches of the validation dataset? These batches are of size 256, right?

    Thanks.

    opened by psteinb 4
  • Potential mistake in loss landscape visualization.

    Hi, thanks for your great work. I'd like to discuss the L2 loss problem in the loss landscape visualization. I found that your calculated L2 loss is significantly larger (10x) than the classification loss, so the landscape visualization is basically a visualization of the L2 loss. In fact, "weight decay" is slightly different from "L2 loss" in the PyTorch implementation: simply calculating the sum of norms as an L2 loss is different from applying weight decay in Adam-like optimizers in PyTorch. See the blog at https://bbabenko.github.io/weight-decay/. Although one might find that the L2 loss is significantly larger than the classification loss, in the practice of ViT the weight decay term does not dominate the classification loss, due to the implementation of weight decay in PyTorch.
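
    For context, a rough illustration of the two conventions contrasted in this issue (an illustrative sketch, not code from this repository):

        import torch

        model = torch.nn.Linear(10, 2)
        wd = 5e-2

        # (a) L2 penalty added to the loss: the penalty shows up in the reported loss value.
        l2 = sum(p.pow(2).sum() for p in model.parameters())
        loss_with_l2 = torch.nn.functional.mse_loss(model(torch.randn(4, 10)),
                                                    torch.zeros(4, 2)) + 0.5 * wd * l2

        # (b) Decoupled weight decay (AdamW): the weights shrink inside the optimizer step,
        #     so the decay never appears in the loss that is logged or visualized.
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=wd)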

    opened by sjtuytc 3
  •  code about robustness for noise frequency exp (Fig. 2b)

    Hello, authors. First of all, thank you for a great paper full of insights.

    I have been reading your paper and using the code to run various analyses. Among them, I would like to carry out the analysis of robustness to noise frequency in Fig. 2b. However, this part does not seem to be included in the code, so I am asking here.

    It appears to use the FreqAttack class. Could you share the experiment code that applies random noise per frequency so that this experiment can be reproduced?

    Thank you.

    opened by DoyoungYoon 2
  • Hi

    When I run the forward function of the LocalAttention class, some errors occur.

    x.shape = [1, 128, 84, 64] and self.window_size = 8. The rearrange function cannot run correctly because 84 is not divisible by 8 (n1 = 84 // 8).

    If I change window_size to 7/6/5, there may be other images whose height or width is not divisible.

    I also tried setting window_size dynamically, but it didn't succeed.

    The images come from the COCO dataset.

    Do you have any good suggestions?

    The code is

        b, c, h, w = x.shape
        p = self.window_size
        n1 = h // p
        n2 = w // p
        mask = torch.zeros(p ** 2, p ** 2, device=x.device) if mask is None else mask
        mask = mask + self.pos_embedding[self.rel_index[:, :, 0], self.rel_index[:, :, 1]]
        x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p)
        x, attn = self.attn(x, mask)
        x = rearrange(x, "(b n1 n2) c p1 p2 -> b c (n1 p1) (n2 p2)", n1=n1, n2=n2, p1=p, p2=p)
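
    One common workaround (not an official fix from this repository) is to pad the feature map up to a multiple of the window size before the rearrange and crop it back afterwards:

        import torch
        import torch.nn.functional as F

        def pad_to_window(x, p):
            # Pad H and W (bottom/right) up to multiples of the window size p;
            # crop the output back to (h, w) after the inverse rearrange.
            h, w = x.shape[-2:]
            pad_h, pad_w = (p - h % p) % p, (p - w % p) % p
            return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)

        x, (h, w) = pad_to_window(torch.randn(1, 128, 84, 64), 8)
        print(x.shape)  # torch.Size([1, 128, 88, 64])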
    
    opened by ross-Hr 2
  • How to plot the Hessian max eigenvalue spectra?

    I read your paper and learned a lot from it. I would also like to see the code for plotting the Hessian max eigenvalue spectra. May I know if you have any plans to update it?

    Best,

    opened by Dong1P 3
  • In convit.py file, where does ConVit come from, really?

    https://github.com/xxxnell/how-do-vits-work/blob/8752f4e330a38877c628dfa40d57fa9404bb3131/models/convit.py#L1-L6

    You said it's not the same as ConViT by d'Ascoli, Stéphane, et al. Then where does this ConViT come from? I ask because if I reuse this code, I want to know whom I should cite.

    opened by dinhanhx 15
  • What are the attributes of the large-kernel CNN?

    Great analysis! I wonder about the attributes of large-kernel CNNs. In your paper, the basic 3x3 ResNet is fully explored: 3x3 Convs extract detailed local patterns and thus may contribute to high-pass filtering. However, recent works investigate the effect of larger kernels. Might the attributes of the ResNet change with larger kernels and become similar to those of a ViT?

    opened by ccx1997 2
[Preprint] "Chasing Sparsity in Vision Transformers: An End-to-End Exploration" by Tianlong Chen, Yu Cheng, Zhe Gan, Lu Yuan, Lei Zhang, Zhangyang Wang

Chasing Sparsity in Vision Transformers: An End-to-End Exploration Codes for [Preprint] Chasing Sparsity in Vision Transformers: An End-to-End Explora

VITA 64 Dec 8, 2022
[Preprint] "Bag of Tricks for Training Deeper Graph Neural Networks A Comprehensive Benchmark Study" by Tianlong Chen*, Kaixiong Zhou*, Keyu Duan, Wenqing Zheng, Peihao Wang, Xia Hu, Zhangyang Wang

Bag of Tricks for Training Deeper Graph Neural Networks: A Comprehensive Benchmark Study Codes for [Preprint] Bag of Tricks for Training Deeper Graph

VITA 101 Dec 29, 2022
[Preprint] ConvMLP: Hierarchical Convolutional MLPs for Vision, 2021

Convolutional MLP ConvMLP: Hierarchical Convolutional MLPs for Vision Preprint link: ConvMLP: Hierarchical Convolutional MLPs for Vision By Jiachen Li

SHI Lab 143 Jan 3, 2023
ALBERT-pytorch-implementation - ALBERT pytorch implementation

ALBERT-pytorch-implementation developing... ๋ชจ๋ธ์˜ ๊ฐœ๋…์ดํ•ด๋ฅผ ๋•๊ธฐ ์œ„ํ•œ ๊ตฌํ˜„๋ฌผ๋กœ ํ˜„์žฌ ๋ณ€์ˆ˜๋ช…์„ ์ƒ์„ธํžˆ ์ ์—ˆ๊ณ 

BG Kim 3 Oct 6, 2022
Official PyTorch implementation for paper Context Matters: Graph-based Self-supervised Representation Learning for Medical Images

Context Matters: Graph-based Self-supervised Representation Learning for Medical Images Official PyTorch implementation for paper Context Matters: Gra

null 49 Nov 23, 2022
StyleGAN2-ADA - Official PyTorch implementation

Abstract: Training generative adversarial networks (GAN) using too little data typically leads to discriminator overfitting, causing training to diverge. We propose an adaptive discriminator augmentation mechanism that significantly stabilizes training in limited data regimes.

NVIDIA Research Projects 3.2k Dec 30, 2022
Official PyTorch implementation of Joint Object Detection and Multi-Object Tracking with Graph Neural Networks

This is the official PyTorch implementation of our paper: "Joint Object Detection and Multi-Object Tracking with Graph Neural Networks". Our project website and video demos are here.

Richardย Wang 443 Dec 6, 2022
Official pytorch implementation of paper "Image-to-image Translation via Hierarchical Style Disentanglement".

HiSD: Image-to-image Translation via Hierarchical Style Disentanglement Official pytorch implementation of paper "Image-to-image Translation

null 364 Dec 14, 2022
Official pytorch implementation of paper "Inception Convolution with Efficient Dilation Search" (CVPR 2021 Oral).

IC-Conv This repository is an official implementation of the paper Inception Convolution with Efficient Dilation Search. Getting Started Download Imag

Jie Liu 111 Dec 31, 2022
Official PyTorch Implementation of Unsupervised Learning of Scene Flow Estimation Fusing with Local Rigidity

UnRigidFlow This is the official PyTorch implementation of UnRigidFlow (IJCAI2019). Here are two sample results (~10MB gif for each) of our unsupervis

Liang Liu 28 Nov 16, 2022
Official implementation of our paper "LLA: Loss-aware Label Assignment for Dense Pedestrian Detection" in Pytorch.

LLA: Loss-aware Label Assignment for Dense Pedestrian Detection This project provides an implementation for "LLA: Loss-aware Label Assignment for Dens

null 35 Dec 6, 2022
An official implementation of "SFNet: Learning Object-aware Semantic Correspondence" (CVPR 2019, TPAMI 2020) in PyTorch.

PyTorch implementation of SFNet This is the implementation of the paper "SFNet: Learning Object-aware Semantic Correspondence". For more information,

CV Lab @ Yonsei University 87 Dec 30, 2022
Old Photo Restoration (Official PyTorch Implementation)

Bringing Old Photo Back to Life (CVPR 2020 oral)

Microsoft 11.3k Dec 30, 2022
Official PyTorch implementation of Spatial Dependency Networks.

Spatial Dependency Networks: Neural Layers for Improved Generative Image Modeling ฤorฤ‘e Miladinoviฤ‡ โ€ƒ Aleksandar Staniฤ‡ โ€ƒ Stefan Bauer โ€ƒ Jรผrgen Schmid

Djordje Miladinovic 34 Jan 19, 2022
Official implementation of our CVPR2021 paper "OTA: Optimal Transport Assignment for Object Detection" in Pytorch.

OTA: Optimal Transport Assignment for Object Detection This project provides an implementation for our CVPR2021 paper "OTA: Optimal Transport Assignme

null 217 Jan 3, 2023
This is the official PyTorch implementation of the paper "TransFG: A Transformer Architecture for Fine-grained Recognition" (Ju He, Jie-Neng Chen, Shuai Liu, Adam Kortylewski, Cheng Yang, Yutong Bai, Changhu Wang, Alan Yuille).

TransFG: A Transformer Architecture for Fine-grained Recognition Official PyTorch code for the paper: TransFG: A Transformer Architecture for Fine-gra

Ju He 307 Jan 3, 2023
StyleGAN2-ADA - Official PyTorch implementation

Need Help? If youโ€™re new to StyleGAN2-ADA and looking to get started, please check out this video series from a course Lia Coleman and I taught in Oct

Derrick Schultz 217 Jan 4, 2023
Official PyTorch implementation of "ArtFlow: Unbiased Image Style Transfer via Reversible Neural Flows"

ArtFlow Official PyTorch implementation of the paper: ArtFlow: Unbiased Image Style Transfer via Reversible Neural Flows Jie An*, Siyu Huang*, Yibing

null 123 Dec 27, 2022
Official PyTorch implementation of RobustNet (CVPR 2021 Oral)

RobustNet (CVPR 2021 Oral): Official Project Webpage Codes and pretrained models will be released soon. This repository provides the official PyTorch

Sungha Choi 173 Dec 21, 2022