Twins: Revisiting the Design of Spatial Attention in Vision Transformers

Overview

Very recently, a variety of vision transformer architectures for dense prediction tasks have been proposed, and they show that the design of spatial attention is critical to their success in these tasks. In this work, we revisit the design of the spatial attention and demonstrate that a carefully devised yet simple spatial attention mechanism performs favourably against state-of-the-art schemes. As a result, we propose two vision transformer architectures, namely Twins-PCPVT and Twins-SVT. Our proposed architectures are highly efficient and easy to implement, only involving matrix multiplications that are highly optimized in modern deep learning frameworks. More importantly, the proposed architectures achieve excellent performance on a wide range of visual tasks, including image-level classification as well as dense detection and segmentation. The simplicity and strong performance suggest that our proposed architectures may serve as stronger backbones for many vision tasks.

Figure 1. Twins-SVT-S architecture (the right side shows the inside of two consecutive Transformer encoders).
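
To make the figure concrete, here is a minimal, self-contained sketch (an illustration, not the released Twins code) of the two attention forms that the consecutive encoders alternate between: locally-grouped self-attention (LSA), restricted to ws x ws windows, and global sub-sampled attention (GSA), whose keys and values come from a strided-convolution summary of the feature map. The tensor sizes and the use of torch.nn.MultiheadAttention as a stand-in are illustrative assumptions.

    import torch
    import torch.nn as nn

    B, H, W, C, heads, ws, sr = 2, 56, 56, 64, 4, 7, 8
    attn = nn.MultiheadAttention(C, heads, batch_first=True)   # stand-in attention
    sub = nn.Conv2d(C, C, kernel_size=sr, stride=sr)           # sub-samples keys/values for GSA

    def lsa(x):
        # locally-grouped self-attention: attend only within ws x ws windows
        b, h, w, c = x.shape
        xw = x.reshape(b, h // ws, ws, w // ws, ws, c)
        xw = xw.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, c)   # (B*windows, ws*ws, C)
        out, _ = attn(xw, xw, xw)
        out = out.reshape(b, h // ws, w // ws, ws, ws, c)
        return out.permute(0, 1, 3, 2, 4, 5).reshape(b, h, w, c)

    def gsa(x):
        # global sub-sampled attention: every query attends to a reduced key/value set
        b, h, w, c = x.shape
        q = x.reshape(b, h * w, c)
        kv = sub(x.permute(0, 3, 1, 2)).flatten(2).transpose(1, 2)  # (B, H*W/sr^2, C)
        out, _ = attn(q, kv, kv)
        return out.reshape(b, h, w, c)

    x = torch.randn(B, H, W, C)
    print(lsa(x).shape, gsa(x).shape)   # both torch.Size([2, 56, 56, 64])

In Twins-SVT these two forms are interleaved block by block within each stage, which is what the pair of consecutive encoders in the figure depicts.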

Model Zoo

Image Classification

We provide baseline Twins models pretrained on ImageNet 2012.

Name | Alias in paper | Acc@1 (%) | FLOPs (G) | #Params (M) | URL
---- | -------------- | --------- | --------- | ----------- | ---
PVT+CPVT-Small | Twins-PCPVT-S | 81.2 | 3.7 | 24.1 | pcpvt_small.pth
PVT+CPVT-Base | Twins-PCPVT-B | 82.7 | 6.4 | 43.8 | pcpvt_base.pth
ALT-GVT-Small | Twins-SVT-S | 81.3 | 2.8 | 24.0 | alt_gvt_small.pth
ALT-GVT-Base | Twins-SVT-B | 83.1 | 8.3 | 56.0 | alt_gvt_base.pth
ALT-GVT-Large | Twins-SVT-L | 83.3 | 14.8 | 99.2 | alt_gvt_large.pth

Note: Our code will be released soon.
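
Assuming the checkpoints above are standard PyTorch weight files and that gvt.py registers the classification models with timm under the names main.py uses (e.g. alt_gvt_small), loading one of them for evaluation might look roughly like the sketch below; apart from the file names in the table, everything here is an assumption rather than documented usage.

    import torch
    from timm import create_model

    import gvt  # noqa: F401  -- importing the model file runs its timm @register_model decorators

    model = create_model('alt_gvt_small', pretrained=False, num_classes=1000)
    ckpt = torch.load('alt_gvt_small.pth', map_location='cpu')
    state_dict = ckpt.get('model', ckpt) if isinstance(ckpt, dict) else ckpt  # DeiT-style wrapping
    model.load_state_dict(state_dict)
    model.eval()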

Citation

@article{chu2021Twins,
  title={Twins: Revisiting the Design of Spatial Attention in Vision Transformers},
  author={Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
  journal={arXiv preprint arXiv:2104.13840},
  url={https://arxiv.org/pdf/2104.13840.pdf},
  year={2021}
}
Comments
  • Cannot reproduce the performance of the svt-small model

    Thanks for your nice work! I would like to reproduce the performance of the svt-small (alt_gvt_small) model. Below is my command:

    python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model alt_gvt_small --batch-size 256 --data-path ../data/ImageNet --dist-eval --drop-path 0.2

    The other parameters are left at their defaults, but the result only reaches 81.1%, not 81.7%.
    Could you give me some suggestions on how to reproduce the reported performance from scratch?

    opened by Yangr116 6
  • Why 'ws 1 for stand attention' in your GroupAttention code?

    I find that in your implementation of GroupAttention in gvt.py, you comment that 'ws 1 for stand attention'.

    class GroupAttention(nn.Module):
        def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ws=1, sr_ratio=1.0):
            """
            ws 1 for stand attention
            """
            super(GroupAttention, self).__init__()
            assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
    

    However, I think ws means the window size; if ws=1, then self-attention is only performed within a 1x1 window, which is not standard self-attention.

    opened by kejie-cn 5
  • Question about calling ALTGVT

    Hello, when calling ALTGVT for the classification task, I get an error saying that the GroupBlock class passes one argument too many when it overrides its parent class TimmBlock: super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, ... TypeError: __init__() takes from 3 to 10 positional arguments but 11 were given

    After checking TimmBlock, I found that it does not have a qk_scale parameter. How can this bug be resolved?
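
    Two workarounds are commonly used for this kind of mismatch; both are hedged suggestions rather than an official fix. Older timm releases (e.g. timm==0.3.2, the version DeiT-style code bases were written against) still accept qk_scale in Block.__init__, so pinning timm is the simplest option. Alternatively, for timm versions around 0.4-0.6, where qk_scale was removed but drop/attn_drop/drop_path remain, a subclass can swallow qk_scale instead of forwarding it, roughly like this hypothetical adapter:

    from timm.models.vision_transformer import Block as TimmBlock

    class PatchedGroupBlock(TimmBlock):
        """Hypothetical adapter: accepts qk_scale but does not forward it to TimmBlock."""
        def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
                     qk_scale=None, drop=0., attn_drop=0., drop_path=0., **kwargs):
            super().__init__(dim, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                             drop=drop, attn_drop=attn_drop, drop_path=drop_path)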

    opened by wfs123456 4
  • mIoU eval method (got higher mIoU than provided)

    For the ALTGVT-Large model on the ADE20K dataset, the single-scale mIoU reported on this page is 48.8. However, I used the default mmsegmentation evaluation code and got 49.07 (higher?).

    (I am using the default test_pipeline with img_scale=(2048, 512) and mode='whole'.) I used a single GPU for single-scale inference.
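
    For reference, single-scale evaluation of this kind is normally run through mmsegmentation's standard test script; CONFIG and CHECKPOINT below are placeholders, not paths from this repository.

    python tools/test.py CONFIG CHECKPOINT --eval mIoU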

    Could you please have a look at this?

    Thank you

    opened by zbwxp 3
  • Some confusion about warmup strategy

    In the log you provided, I find that the number of warm-up epochs is not 5 and that the linear warm-up starts at the second epoch. This is inconsistent with the paper:

    We use a linear warm-up in the first five epochs ...

    https://github.com/Meituan-AutoML/Twins/blob/37f9dbf1aa2181062f1ce880b952af2875f1b79f/logs/svt_s.txt#L1-L7

    Why? Thx.
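
    For context, this is what a linear warm-up over the first five epochs would look like numerically; warmup_lr and base_lr below are illustrative DeiT-style values, not a claim about the attached log.

    warmup_lr, base_lr, warmup_epochs = 1e-6, 1e-3, 5
    for epoch in range(warmup_epochs + 1):
        # the learning rate ramps linearly from warmup_lr to base_lr over five epochs
        lr = warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
        print(epoch, f"{lr:.6f}")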

    opened by TingquanGao 2
  • Two gvt.py files in the repository: what is the difference between them?

    I find there are two gvt.py files in the repository: one in the main directory and another in the segmentation directory. Carefully comparing the two files, I found that the group-attention computation in the segmentation version is different: it adds an attn_mask while the other does not. So I want to ask: what is the role of the attn_mask in the group-attention computation, and why is it needed? Which gvt.py should I use for a segmentation task?
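
    As a generic illustration (not the repository's exact implementation): for dense prediction, H and W are usually not multiples of the window size, so the feature map has to be padded before it is split into windows, and an additive mask is then needed so that attention ignores the padded positions. A minimal sketch of building such a mask:

    import torch
    import torch.nn.functional as F

    def window_pad_mask(h, w, ws):
        # pad H and W up to multiples of the window size ws, then mark the
        # padded positions with -inf so the mask can be added to attention logits
        hp, wp = (h + ws - 1) // ws * ws, (w + ws - 1) // ws * ws
        mask = torch.zeros(1, h, w)
        mask = F.pad(mask, (0, wp - w, 0, hp - h), value=float('-inf'))
        mask = mask.reshape(1, hp // ws, ws, wp // ws, ws)
        mask = mask.permute(0, 1, 3, 2, 4).reshape(-1, ws * ws)
        return mask.unsqueeze(1)  # (num_windows, 1, ws*ws), broadcast over query positions

    print(window_pad_mask(60, 60, 7).shape)  # 60 is not divisible by 7 -> torch.Size([81, 1, 49])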

    opened by JackeyGHD1 2
  • Runtime error for mmseg

    RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel; (2) making sure all forward function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).

    Hi! Thanks for open-sourcing the code. When I use the Twins backbone in mmseg, I get this error. It seems that several parameters do not contribute to the loss.
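
    Two hedged ways this is usually addressed (the first assumes an mmseg/mmdet version that reads a top-level find_unused_parameters key from the config):

    # option 1: add this line to the training config file
    find_unused_parameters = True

    # option 2: when wrapping the model manually
    # model = torch.nn.parallel.DistributedDataParallel(
    #     model, device_ids=[local_rank], find_unused_parameters=True)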

    opened by lxtGH 2
  • Question about GSA

    Hello, thank you very much for your excellent work. I have some questions about GSA. According to my personal understanding, GSA in the paper takes one representation from each window, so the sr_ratio should be the same as the window size ([7, 7, 7, 7]) when calculating Key and Value, but it is [8, 4, 2, 1] in the code. Is there anything wrong with my understanding?

    @BACKBONES.register_module()
    class alt_gvt_large(ALTGVT):
        def __init__(self, **kwargs):
            super(alt_gvt_large, self).__init__(
                patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
                norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1],
                extra_norm=True, drop_path_rate=0.3,
            )
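
    Purely to make the shapes in this config concrete (assuming the usual PVT-style stage strides of 4, 8, 16, 32 and a 224x224 input; this only restates the listed hyper-parameters and does not settle which choice is intended):

    strides, wss, sr_ratios = [4, 8, 16, 32], [7, 7, 7, 7], [8, 4, 2, 1]
    for s, ws, sr in zip(strides, wss, sr_ratios):
        hw = 224 // s  # stage feature-map side length
        print(f"stride {s}: map {hw}x{hw}, {(hw // ws) ** 2} windows of {ws}x{ws}, "
              f"GSA key/value grid {hw // sr}x{hw // sr}")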
    
    opened by kejie-cn 2
  • Error when running training on Windows 10

    Running the training script on Windows 10 gives this error:

    yapf.yapflib.verifier.InternalError: (unicode error) 'unicodeescape' codec can't decode bytes in position 9-10: truncated \uXXXX escape (, line 1)

    opened by KangolHsu 2
  • Learning rate

    https://github.com/Meituan-AutoML/Twins/blob/4700293a2d0a91826ab357fc5b9bc1468ae0e987/main.py#L263

    In the code, the lr is scaled linearly with the batch size. But why are warmup-lr and min-lr not adjusted as well?

    Hoping for some help, thanks!
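
    For anyone else hitting this, the DeiT-style convention (which the linked line appears to follow) is to scale only the base learning rate linearly with the global batch size, while warmup-lr and min-lr are treated as absolute values; the numbers below are placeholders, not the authors' exact settings:

    base_lr = 5e-4                     # value passed on the command line
    batch_per_gpu, world_size = 128, 8
    scaled_lr = base_lr * batch_per_gpu * world_size / 512.0  # lr actually used
    print(scaled_lr)                   # 0.001 for this illustrative setting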

    opened by pawopawo 1
  • FLOPs calculation is not accurate

    Hi, get_flops.py doesn't account for the FLOPs of self-attention, so it is not accurate. For Twins-PCPVT-S, get_flops.py gives:

    ==============================
    Input shape: (3, 512, 2048)
    Flops: 162.66 GFLOPs
    Params: 28.37 M
    ==============================
    

    When using fvcore.nn.flop_count (attention is included), I get:

    ==============================
    Input shape: (3, 512, 2048)
    Flops: 225.98693683200003
    Params: 28372862
    ==============================
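
    A back-of-the-envelope way to see the gap: hook-based counters typically count only module ops (the Q/K/V and output projections are Linear layers and are counted), while the two raw matrix multiplications inside attention are missed. A rough sketch of that missing term:

    def attention_matmul_macs(n_tokens, dim):
        # Q @ K^T and attn @ V each cost roughly n_tokens * n_tokens * dim MACs
        return 2 * n_tokens * n_tokens * dim

    print(attention_matmul_macs(56 * 56, 64) / 1e9, "GMACs for one such layer")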
    
    opened by tonysy 1
  • Questions about table 5

    Hi,

    In your paper's Table 5, the (G,G,G,G) row uses the number (79.8%) from the PVT paper, which uses absolute positional encoding. However, I suppose the other model variants listed in this table use CPE, so they are not directly comparable. Should the accuracy of (G,G,G,G) with CPE be 81.2%, as shown in Table 1?

    In general, I am interested in knowing whether there is a benefit to using global attention in the early layers.

    Thanks.

    opened by kaikai23 0
  • TypeError: __init__() takes from 3 to 10 positional arguments but 11 were given

    When I use the config 'mask_rcnn_alt_gvt_s_fpn_1x_coco_pvt_setting.py' for the detection project, the error in the title appears. Can you give me any relevant tips or solutions? Thanks!

    opened by 15129302710 2
  • The difference between PEG and PEG for detection

    Hi, thanks for your great work. I ran into some confusion while reading the paper. Specifically, in Supplement C (Example Code) there is a slight difference between the two presented algorithms. In Algorithm 1 (PyTorch snippet of PEG), the PEG consists of a Conv layer, while in Algorithm 2 (PyTorch snippet of PEG for detection) there are additional BN+ReLU layers. How do the two settings compare in effectiveness; would the second setting with BN+ReLU be better? Thank you.
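
    For readers comparing the two algorithms, here is a minimal sketch of the structural difference, following the CPVT formulation of PEG as a depth-wise 3x3 convolution over the reshaped token map with a residual connection; whether BN+ReLU actually helps is exactly the open question here, and the code below is illustrative rather than the paper's snippet.

    import torch
    import torch.nn as nn

    class PEG(nn.Module):
        def __init__(self, dim, with_bn_relu=False):
            super().__init__()
            layers = [nn.Conv2d(dim, dim, 3, padding=1, groups=dim)]  # depth-wise conv
            if with_bn_relu:  # the "PEG for detection" variant adds BN + ReLU
                layers += [nn.BatchNorm2d(dim), nn.ReLU(inplace=True)]
            self.proj = nn.Sequential(*layers)

        def forward(self, x, H, W):
            # x: (B, N, C) token sequence with N == H * W
            B, N, C = x.shape
            feat = x.transpose(1, 2).reshape(B, C, H, W)
            return x + self.proj(feat).flatten(2).transpose(1, 2)

    tokens = torch.randn(2, 14 * 14, 64)
    print(PEG(64, with_bn_relu=True)(tokens, 14, 14).shape)  # torch.Size([2, 196, 64])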

    opened by mt-cly 0
  • About the learning rate

    Hello, while reproducing the paper I noticed that in pvpvt_s.txt the maximum learning rate is 0.00125:

    {"train_lr": 1.000000000000015e-06, "train_loss": 6.913535571038723, "test_loss": 6.8714314655021385, "test_acc1": 0.2160000117301941, "test_acc5": 1.2940000837326049, "epoch": 0, "n_parameters": 24106216} {"train_lr": 1.000000000000015e-06, "train_loss": 6.896081995010376, "test_loss": 6.839652865021317, "test_acc1": 0.40800002346038816, "test_acc5": 1.7700001041412354, "epoch": 1, "n_parameters": 24106216} {"train_lr": 0.0002507999999999969, "train_loss": 6.628805226147175, "test_loss": 5.555466687237775, "test_acc1": 6.462000321006775, "test_acc5": 18.384000979614257, "epoch": 2, "n_parameters": 24106216} {"train_lr": 0.0005006000000000066, "train_loss": 6.272622795701027, "test_loss": 4.64223074471509, "test_acc1": 15.328000774383545, "test_acc5": 34.724001724243166, "epoch": 3, "n_parameters": 24106216} {"train_lr": 0.0007504000000000098, "train_loss": 5.958464104115963, "test_loss": 3.9152780064830073, "test_acc1": 24.868001231384277, "test_acc5": 48.68200244750977, "epoch": 4, "n_parameters": 24106216} {"train_lr": 0.0010002000000000064, "train_loss": 5.670980889737606, "test_loss": 3.3432016747969167, "test_acc1": 33.47600182495117, "test_acc5": 59.10200291748047, "epoch": 5, "n_parameters": 24106216} {"train_lr": 0.0012491503115478462, "train_loss": 5.421633140593767, "test_loss": 2.9919017029029353, "test_acc1": 39.19200199279785, "test_acc5": 65.3920035522461, "epoch": 6, "n_parameters": 24106216} {"train_lr": 0.0012487765716255204, "train_loss": 5.202176849722862, "test_loss": 2.6157535275927297, "test_acc1": 45.32800238342285, "test_acc5": 71.13600388793945, "epoch": 7, "n_parameters": 24106216}

    Is this learning rate obtained by scaling with the batch size, or was it tuned by hand? Many thanks.

    opened by Holidays1999 2
  • About the mmdet version

    I installed mmdet 2.8.0, but when loading the model with config_file = (path to the config), checkpoint_file = (path to the model), model = init_detector(config_file, checkpoint_file, device='cuda:0'), I got the following error:

    Traceback (most recent call last):
      File "newtest.py", line 17, in
        model = init_detector(config_file, checkpoint_file, device='cuda:0')
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/apis/inference.py", line 38, in init_detector
        model = build_detector(config.model, test_cfg=config.test_cfg)
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/builder.py", line 67, in build_detector
        return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/builder.py", line 32, in build
        return build_from_cfg(cfg, registry, default_args)
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmcv/utils/registry.py", line 171, in build_from_cfg
        return obj_cls(**args)
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/detectors/mask_rcnn.py", line 24, in __init__
        pretrained=pretrained)
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/detectors/two_stage.py", line 26, in __init__
        self.backbone = build_backbone(backbone)
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/builder.py", line 37, in build_backbone
        return build(cfg, BACKBONES)
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/builder.py", line 32, in build
        return build_from_cfg(cfg, registry, default_args)
      File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmcv/utils/registry.py", line 164, in build_from_cfg
        f'{obj_type} is not in the {registry.name} registry')
    KeyError: 'alt_gvt_small is not in the backbone registry'

    However, gvt.py does register it as a backbone, so I am not sure how to resolve this.
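
    One common cause of this KeyError is that the module containing the @BACKBONES.register_module() decorators was never imported, so the registration never ran. A hedged sketch of the workaround (the module name and paths below are assumptions):

    import gvt  # noqa: F401  -- importing gvt.py runs its register_module() decorators

    from mmdet.apis import init_detector

    config_file = 'path/to/mask_rcnn_alt_gvt_s_fpn_1x_coco_pvt_setting.py'  # placeholder path
    checkpoint_file = 'path/to/checkpoint.pth'                              # placeholder path
    model = init_detector(config_file, checkpoint_file, device='cuda:0')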

    opened by sayoko17 3
Related projects
Paddle pit - Rethinking Spatial Dimensions of Vision Transformers

An implementation of PiT (Rethinking Spatial Dimensions of Vision Transformers) in Paddle, based on the official arxiv original co

Hongtao Wen 4 Jan 15, 2022
Implementation of Barlow Twins paper

barlowtwins PyTorch Implementation of Barlow Twins paper: Barlow Twins: Self-Supervised Learning via Redundancy Reduction This is currently a work in

IgorSusmelj 86 Dec 20, 2022
PyTorch implementation of Barlow Twins.

Barlow Twins: Self-Supervised Learning via Redundancy Reduction PyTorch implementation of Barlow Twins. @article{zbontar2021barlow, title={Barlow Tw

Facebook Research 839 Dec 29, 2022
Barlow Twins and HSIC

Barlow Twins and HSIC Unofficial Pytorch implementation for Barlow Twins and HSIC_SSL on small datasets (CIFAR10, STL10, and Tiny ImageNet). Correspon

Yao-Hung Hubert Tsai 49 Nov 24, 2022
Exploring whether attention is necessary for vision transformers

Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet Paper/Report TL;DR We replace the attention layer in a v

Luke Melas-Kyriazi 461 Jan 7, 2023
Official PyTorch implementation of Less is More: Pay Less Attention in Vision Transformers.

Less is More: Pay Less Attention in Vision Transformers Official PyTorch implementation of Less is More: Pay Less Attention in Vision Transformers. By

null 73 Jan 1, 2023
Official code for "Focal Self-attention for Local-Global Interactions in Vision Transformers"

Focal Transformer This is the official implementation of our Focal Transformer -- "Focal Self-attention for Local-Global Interactions in Vision Transf

Microsoft 486 Dec 20, 2022
Implementation of Deformable Attention in Pytorch from the paper "Vision Transformer with Deformable Attention"

Deformable Attention Implementation of Deformable Attention from this paper in Pytorch, which appears to be an improvement to what was proposed in DET

Phil Wang 128 Dec 24, 2022
The project is an official implementation of our paper "3D Human Pose Estimation with Spatial and Temporal Transformers".

3D Human Pose Estimation with Spatial and Temporal Transformers This repo is the official implementation for 3D Human Pose Estimation with Spatial and

Ce Zheng 363 Dec 28, 2022
This repository is the code of the paper "Sparse Spatial Transformers for Few-Shot Learning".

Sparse Spatial Transformers for Few-Shot Learning This code implements the Sparse Spatial Transformers for Few-Shot Learning (SSFormers). Our code i

chx_nju 38 Dec 13, 2022
The open source code of SA-UNet: Spatial Attention U-Net for Retinal Vessel Segmentation.

SA-UNet: Spatial Attention U-Net for Retinal Vessel Segmentation(ICPR 2020) Overview This code is for the paper: Spatial Attention U-Net for Retinal V

Changlu Guo 151 Dec 28, 2022
Graph Self-Attention Network for Learning Spatial-Temporal Interaction Representation in Autonomous Driving

GSAN Introduction Code for paper GSAN: Graph Self-Attention Network for Learning Spatial-Temporal Interaction Representation in Autonomous Driving, wh

YE Luyao 6 Oct 27, 2022
A Text Attention Network for Spatial Deformation Robust Scene Text Image Super-resolution (CVPR2022)

A Text Attention Network for Spatial Deformation Robust Scene Text Image Super-resolution (CVPR2022) https://arxiv.org/abs/2203.09388 Jianqi Ma, Zheto

MA Jianqi, shiki 104 Jan 5, 2023
The implementation of "Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer"

Shuffle Transformer The implementation of "Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer" Introduction Very recently, window-

null 87 Nov 29, 2022
Revisiting Contrastive Methods for Unsupervised Learning of Visual Representations. [2021]

Revisiting Contrastive Methods for Unsupervised Learning of Visual Representations This repo contains the Pytorch implementation of our paper: Revisit

Wouter Van Gansbeke 80 Nov 20, 2022
Official Code for ICML 2021 paper "Revisiting Point Cloud Shape Classification with a Simple and Effective Baseline"

Revisiting Point Cloud Shape Classification with a Simple and Effective Baseline Ankit Goyal, Hei Law, Bowei Liu, Alejandro Newell, Jia Deng Internati

Princeton Vision & Learning Lab 115 Jan 4, 2023
Revisiting, benchmarking, and refining Heterogeneous Graph Neural Networks.

Heterogeneous Graph Benchmark Revisiting, benchmarking, and refining Heterogeneous Graph Neural Networks. Roadmap We organize our repo by task, and on

THUDM 176 Dec 17, 2022
an implementation of Revisiting Adaptive Convolutions for Video Frame Interpolation using PyTorch

revisiting-sepconv This is a reference implementation of Revisiting Adaptive Convolutions for Video Frame Interpolation [1] using PyTorch. Given two f

Simon Niklaus 59 Dec 22, 2022
Revisiting Oxford and Paris: Large-Scale Image Retrieval Benchmarking

Revisiting Oxford and Paris: Large-Scale Image Retrieval Benchmarking We revisit and address issues with Oxford 5k and Paris 6k image retrieval benchm

Filip Radenovic 188 Dec 17, 2022