DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification

Overview

Created by Yongming Rao, Wenliang Zhao, Benlin Liu, Jiwen Lu, Jie Zhou, Cho-Jui Hsieh

This repository contains PyTorch implementation for DynamicViT.

We introduce a dynamic token sparsification framework to prune redundant tokens in vision transformers progressively and dynamically based on the input:

[Figure: overview of the dynamic token sparsification framework]
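
As a rough illustration of the idea, the sketch below shows one way such a token sparsification step can be written: a lightweight predictor scores each token, a binary keep-mask is sampled with Gumbel-Softmax so the decision stays differentiable during training, and at inference the lowest-scoring tokens are simply dropped. This is a minimal sketch under those assumptions (the class name TokenPredictor and the hidden sizes are illustrative), not the repository's exact prediction module.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TokenPredictor(nn.Module):
    """Minimal sketch: predict a keep/drop score for every token."""
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim // 4),
                                 nn.GELU(), nn.Linear(dim // 4, 2))

    def forward(self, tokens, keep_ratio, training):
        logits = self.mlp(tokens)                                   # (B, N, 2): drop/keep logits
        if training:
            # Differentiable hard decision via straight-through Gumbel-Softmax;
            # pruned tokens are masked out downstream, not physically removed.
            keep_mask = F.gumbel_softmax(logits, hard=True)[..., 1]  # (B, N)
            return keep_mask
        # At inference, actually discard the lowest-scoring tokens.
        num_keep = int(tokens.shape[1] * keep_ratio)
        keep_idx = logits[..., 1].topk(num_keep, dim=1).indices      # (B, num_keep)
        return torch.gather(tokens, 1,
                            keep_idx.unsqueeze(-1).expand(-1, -1, tokens.shape[-1]))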

Our code is based on pytorch-image-models, DeiT, and LV-ViT.

[Project Page] [arXiv]

Model Zoo

We provide our DynamicViT models pretrained on ImageNet:

| name | arch | rho | acc@1 | acc@5 | FLOPs | url |
| --- | --- | --- | --- | --- | --- | --- |
| DynamicViT-256/0.7 | deit_256 | 0.7 | 76.532 | 93.118 | 1.3G | Google Drive / Tsinghua Cloud |
| DynamicViT-384/0.7 | deit_small | 0.7 | 79.316 | 94.676 | 2.9G | Google Drive / Tsinghua Cloud |
| DynamicViT-LV-S/0.5 | lvvit_s | 0.5 | 81.970 | 95.756 | 3.7G | Google Drive / Tsinghua Cloud |
| DynamicViT-LV-S/0.7 | lvvit_s | 0.7 | 83.076 | 96.252 | 4.6G | Google Drive / Tsinghua Cloud |
| DynamicViT-LV-M/0.7 | lvvit_m | 0.7 | 83.816 | 96.584 | 8.5G | Google Drive / Tsinghua Cloud |

Usage

Requirements

  • torch>=1.7.0
  • torchvision>=0.8.1
  • timm==0.4.5

Data preparation: download and extract ImageNet images from http://image-net.org/. The directory structure should be

│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
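
For reference, a directory laid out this way can be read with torchvision's standard ImageFolder dataset. The snippet below is only an illustration of the expected layout (the transform values are the usual ImageNet defaults), not necessarily the repository's exact data pipeline.

from torchvision import datasets, transforms

# Standard ImageNet evaluation preprocessing for 224x224 models (illustrative values).
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ImageFolder expects exactly the class-per-subdirectory layout shown above.
val_set = datasets.ImageFolder('/path/to/ILSVRC2012/val', transform=val_transform)
print(len(val_set.classes))  # 1000 synset folders (n01440764, ...)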

Model preparation: download pre-trained DeiT and LV-ViT models for training DynamicViT:

sh download_pretrain.sh

Demo

We provide a Jupyter notebook where you can run the visualization of DynamicViT.

To run the demo, you need to install matplotlib.

[Figure: DynamicViT visualization demo]

Evaluation

To evaluate a pre-trained DynamicViT model on ImageNet val with a single GPU, run:

python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --model-path /path/to/model --base_rate 0.7 

Training

To train DynamicViT models on ImageNet, run:

DeiT-small

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_deit-small --arch deit_small --input-size 224 --batch-size 96 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

LV-ViT-S

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_lvvit-s --arch lvvit_s --input-size 224 --batch-size 64 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

LV-ViT-M

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_lvvit-m --arch lvvit_m --input-size 224 --batch-size 48 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

You can train models with different keeping ratios by adjusting base_rate. DynamicViT can also achieve comparable performance with only 15 epochs of training (around 0.1% lower accuracy).
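
As described in the paper, tokens are pruned hierarchically at three stages with the same keeping ratio, so a base_rate of rho keeps roughly rho, rho^2, and rho^3 of the tokens after each stage. The snippet below is just a back-of-the-envelope check of the surviving token counts, assuming 196 patch tokens for a 224x224 input with 16x16 patches.

# Rough sanity check of how many tokens survive each pruning stage (illustrative only).
num_tokens = 196        # (224/16) * (224/16) patch tokens, class token excluded
base_rate = 0.7

kept = num_tokens
for stage in range(1, 4):
    kept = int(kept * base_rate)
    print(f"after stage {stage}: ~{kept} tokens kept")
# after stage 1: ~137 tokens kept
# after stage 2: ~95 tokens kept
# after stage 3: ~66 tokens kept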

License

MIT License

Citation

If you find our work useful in your research, please consider citing:

@article{rao2021dynamicvit,
  title={DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification},
  author={Rao, Yongming and Zhao, Wenliang and Liu, Benlin and Lu, Jiwen and Zhou, Jie and Hsieh, Cho-Jui},
  journal={arXiv preprint arXiv:2106.02034},
  year={2021}
}
Comments
  • Question about the Attention Masking contribution

    Question about the Attention Masking contribution

    Regarding the Attention Masking in the paper, I have a fairly big question about this contribution. Personally, I think it would be enough to set the masked tokens' content to -inf according to the mask, rather than constraining the softmax with the mask after the token interaction has finished.

    Also, does the method proposed in this paper only reduce computation at inference time? During training, self-attention is still computed over all tokens, so the training-time computation does not decrease; in fact, it increases because of the added prediction module.

    opened by xmu-xiaoma666 3
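
    For context on the question above (and on the later question about subtracting the max), here is a simplified sketch of a softmax that applies a keep-mask ("policy") to the attention scores. It is an illustration of the general idea only, not a verbatim copy of the repository's implementation; shapes and the eps value are assumptions.

    import torch

    def masked_softmax(attn_logits, policy, eps=1e-6):
        # attn_logits: (B, H, N, N) raw attention scores; policy: (B, N) keep-mask,
        # 1 for kept tokens and 0 for pruned ones (soft values during training).
        B, H, N, _ = attn_logits.shape
        mask = policy.reshape(B, 1, 1, N)
        # Subtracting the row-wise max does not change the softmax result but
        # keeps exp() from overflowing -- that is why the max is removed first.
        stable = attn_logits - attn_logits.max(dim=-1, keepdim=True).values
        weights = torch.exp(stable) * mask            # pruned keys get zero weight
        return weights / (weights.sum(dim=-1, keepdim=True) + eps)
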
  • [Questions] Dyswin feature output shape

    [Questions] Dyswin feature output shape

    Hi~ Thx for your great work and excellent code!

    I have some questions about dyswin code, hope you could help me out.

    1. In code: https://github.com/raoyongming/DynamicViT/blob/84b4e2a9b1f11199bd1e2ff506969b0d64e6f55b/models/dyswin.py#L683

      when I input a tensor with shape (2, 3, 224, 224), the if condition is activated (len(x) is equal to 2), and the following operation obviously goes wrong.

      However, when I set the batch size to another number, this error disappears. Could you please explain the code here?

    2. When I use the lvvit-s pre-trained model as the inference backbone, I find that the output token length of lvvit-s is shorter than the standard lvvit output token length, which is correct, right?

      But when I change the backbone from lvvit-s to swin-s, the output token length is the same as standard swin-s.

      For example, for an input tensor with shape (4, 3, 224, 224), the output shape of dyswin (temporarily ignoring the avgpool and later layers) is (4, 49, 768), while the standard swin-s also outputs a tensor of this shape. It seems that token length reduction has not been achieved.

      Could you please explain the code here? Any advice would be greatly appreciated! :)

    opened by zafirshi 2
  • Is the subscript 'i' of Z_global in equation 4 a typo?

    Is the subscript 'i' of Z_global in equation 4 a typo?

    The shape of Z_global is (C,) and the shape of Z_local is (N,C).

    I have checked your code at line 263 of DynamicViT/models/dyswin.py; I think your idea is to concatenate Z_local_i with Z_global, since there is no dimension N in Z_global.

    So I think the subscript 'i' of Z_global in equation 4 of your paper may be a typo?

    opened by LucasZhan 2
  • Some questions about your code

    Some questions about your code

    Hello, thank you for your code. I spent some time reading your code carefully, but I still cannot understand the following line (https://github.com/raoyongming/DynamicViT/blob/84b4e2a9b1f11199bd1e2ff506969b0d64e6f55b/models/dyvit.py#L174). Could you please give me some advice? Why subtract the max value from the attn?

    opened by leoozy 2
  • Flops tools

    Flops tools

    Hi, this is wonderful and solid work. I have several questions about FLOPs. In your paper, you report the model's FLOPs (in GFLOPs). Which package can compute the GFLOPs? As far as I know, the popular ptflops package (from ptflops import get_model_complexity_info) outputs MACs, not GFLOPs. Thanks in advance.

    opened by waynelrs 2
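
    For what it's worth, ptflops reports multiply-accumulates (MACs), and many vision-transformer papers report GMACs under the name GFLOPs, so the difference is often just a unit convention. The snippet below is purely illustrative (it uses a torchvision model as a stand-in), not the counting setup used for the numbers in the table above.

    import torchvision
    from ptflops import get_model_complexity_info

    model = torchvision.models.resnet18()   # stand-in model for illustration
    macs, params = get_model_complexity_info(model, (3, 224, 224),
                                             as_strings=False,
                                             print_per_layer_stat=False)
    # 1 MAC = 1 multiply + 1 add, so doubling gives FLOPs in the strict sense.
    print(f"{macs / 1e9:.2f} GMACs (~{2 * macs / 1e9:.2f} GFLOPs if counting mul and add separately)")
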
  • DynamicVIT training stored checkpoint

    DynamicVIT training stored checkpoint

    Hi, thanks for the great work! I am training DynamicViT for deit-small, and the stored checkpoint has a size of 364 MB, while the original model and the weights you shared are ~90 MB. Am I making an error?

    opened by SwapnilDreams100 2
  • Can't reproduce the accuracy of pre-trained models

    Can't reproduce the accuracy of pre-trained models

    Tried arch: deit_small, deit_256. Dataset: ImageNet-1k val. File structure:

    │ILSVRC2012_val/
    ├──val/
    │  ├── 1(image label)
    │  │   ├── ILSVRC2012_val_00000293.JPEG
    │  │   ├── ILSVRC2012_val_00002138.JPEG
    │  │   ├── ......
    │  ├── 2(image label)
    │  │   ├── ILSVRC2012_val_00000293.JPEG
    │  │   ├── ILSVRC2012_val_00002138.JPEG
    │  │   ├── ......
    

    When I ran python3 infer.py --data-path /home/ubuntu/datasets/ILSVRC2012_val/ --arch deit_small --model-path /home/ubuntu/models/dynamic-vit_384_r0.7.pth --base_rate 0.7, the result is Acc@1 0.080 Acc@5 0.582.

    The directory names (image labels) are decided by ILSVRC2012_validation_ground_truth.txt in the development kit. Is this problem due to the wrong directory names, which cause the model's predictions to be mapped to different labels than the real ones? Should I rename the directories to WNIDs? But the val dataset has no WNIDs, so how could I confirm that?

    Thanks

    opened by xiyiyia 2
  • About distill

    About distill

    Why does LVViT_Teacher return aux_head(x[:, 1:]) instead of tokens?

    LVViT_Teacher:

    x = self.norm(x)
    x_cls = self.head(x[:,0])
    x_aux = self.aux_head(x[:,1:])
    return x_cls, x_aux
    

    VisionTransformerTeacher:

    feature = self.norm(x)
    cls = feature[:, 0]
    tokens = feature[:, 1:]
    cls = self.pre_logits(cls)
    cls = self.head(cls)
    return cls, tokens
    

    And, can I use the LVViT_Teacher to distill deit_small?

    opened by hegc 2
  • Attention mask computation during training

    Attention mask computation during training

    Hello,

    Thank you for your work. I'm reading your implementation code for computing the attention mask. In lvvit.py line 665, you used F.gumbel_softmax(pred_score) to find the binary mask, but the last layer for computing pred_score is a log_softmax (line 529, same file). Shouldn't it be just a linear unit rather than a log_softmax, since F.gumbel_softmax takes logits rather than probabilities?

    Thank you.

    opened by mtchiu2 2
  • pretrained model download

    pretrained model download

    I can use the command below to download lvvit_m-56M-224-84.0.tar:

    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1iZc6d27EuEnlfUpJoNhsZEkt6GVgPy7-' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1iZc6d27EuEnlfUpJoNhsZEkt6GVgPy7-" -O lvvit_m-56M-224-84.0.tar && rm -rf /tmp/cookies.txt

    But when I run !tar -xvf lvvit_m-56M-224-84.0.tar to extract the file, it says:

    tar: This does not look like a tar archive
    tar: Skipping to next header
    tar: Exiting with failure status due to previous errors

    opened by dirtycomputer 2
  • Loss is nan when training my own dataset

    Loss is nan when training my own dataset

    It happens randomly in any epoch, and it still appears after setting a lower learning rate or turning off AMP.

    Then I found it usually occurs in the 'PredictorLG' module: if 'policy' is an all-zero matrix, there will be NaN in global_x.

    Is it ok to add a very small value in the denominator of the global_x, e.g. 1e-6?

    opened by InfinityBox 2
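
    Adding a small epsilon is a common fix for exactly this situation. Below is a hedged sketch of a masked global average pooling where an all-zero policy would otherwise divide by zero; the function name, shapes, and eps value are illustrative assumptions, not the repository's exact code.

    import torch

    def masked_global_pool(x, policy, eps=1e-6):
        # x: (B, N, C) token features; policy: (B, N, 1) keep-mask.
        # If policy is all zeros for a sample, the denominator would be 0 -> NaN,
        # so a small eps keeps the division finite.
        return (x * policy).sum(dim=1, keepdim=True) / (policy.sum(dim=1, keepdim=True) + eps)
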
  • Fail to reproduce accuracy of DynamicViT-B/0.7: lower accuracy than reported

    Fail to reproduce accuracy of DynamicViT-B/0.7: lower accuracy than reported

    Hi, I follow the training command:

    python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamicvit_deit-b --model deit-b --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 30 --base_rate 0.7 --lr 1e-3

    and get the final results:

    • Acc@1 80.754 Acc@5 94.950 loss 0.888 Accuracy of the model on the 50000 test images: 80.8% Max accuracy: 80.80%

    • Acc@1 80.760 Acc@5 94.970 loss 0.887 Accuracy of the model EMA on 50000 test images: 80.8% Max EMA accuracy: 80.86%

    I fail to reproduce the 81.3% reported in the paper for deit-base with a 0.7 keeping ratio (~0.5% drop). But I got 79.2% for deit-small (79.3% in the paper). These two experiments were conducted in the same Python environment.

    My environment: python 3.10, torch 1.12.1, torchvision 0.13.1, timm 0.3.2

    Could you provide the training instructions/checkpoints that achieved ~81.3% accuracy for deit-base?

    Thanks!

    opened by ShiFengyuan1999 8
  • Temperature in Gumbel Softmax

    Temperature in Gumbel Softmax

    Hi, thanks for your inspiring work!

    I notice that you used the default temperature=1 in all your F.gumbel_softmax calls, and it is not annealed toward 0. Do you have any suggestions on why we should keep this temperature fixed? I thought we should decrease the temperature during training so the samples get closer and closer to the real categorical distribution, as indicated in the Gumbel-Softmax paper?

    opened by kaikai23 1
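
    For reference, F.gumbel_softmax exposes the temperature through its tau argument, so an annealing schedule is easy to try if you want to experiment. The sketch below is purely illustrative (the schedule, function name, and the assumption that the last logit dimension is drop/keep are mine), not the training recipe used in this repository.

    import torch.nn.functional as F

    def keep_mask(logits, epoch, total_epochs, tau_start=1.0, tau_end=0.1):
        # logits: (..., 2) drop/keep scores. Linearly anneal the Gumbel-Softmax
        # temperature from tau_start to tau_end over training.
        tau = tau_start + (tau_end - tau_start) * epoch / max(total_epochs - 1, 1)
        return F.gumbel_softmax(logits, tau=tau, hard=True)[..., 1]
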
  • Structural downsampling and static token sparsification

    Structural downsampling and static token sparsification

    Hi, it's quite a solid and promising work, but I have some questions.

    (1) In the paper, you perform an average pooling with kernel size 2 × 2 after the sixth block for the structural downsampling. But in Table 3, you show the results of both structural downsampling and static token sparsification. What is the difference between structural downsampling and static token sparsification, since their accuracies are not the same?

    (2) I'm interested in the average pooling with kernel size 2 × 2. Did you do extra experiments on the position of this structural downsampling, such as after the seventh block or the tenth block of the ViT?

    (3) Could you provide the code for reproducing the results of structural downsampling and static token sparsification in Table 3 and the probability heat-map in Figure 6?

    Thanks for your help!

    opened by Yeez-lee 2