
Vision Transformer with Deformable Attention

This repository contains the code for the paper Vision Transformer with Deformable Attention [arXiv].

Introduction

[Figure: Deform_Attn, an illustration of the deformable attention mechanism]

Deformable attention is proposed to model the relations among tokens effectively, guided by the important regions in the feature maps. This flexible scheme enables the self-attention module to focus on relevant regions and capture more informative features. On this basis, we present the Deformable Attention Transformer (DAT), a general backbone model with deformable attention for both image classification and dense prediction tasks.
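
For intuition, the sketch below shows the core idea in PyTorch: queries predict a small offset field, keys and values are bilinearly sampled at the shifted reference points, and standard attention follows. This is a minimal single-head illustration, not the authors' implementation; module and parameter names are ours.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DeformableAttentionSketch(nn.Module):
        def __init__(self, dim, offset_range_factor=2.0):
            super().__init__()
            self.scale = dim ** -0.5
            self.offset_range_factor = offset_range_factor
            self.proj_q = nn.Conv2d(dim, dim, 1)
            self.proj_k = nn.Conv2d(dim, dim, 1)
            self.proj_v = nn.Conv2d(dim, dim, 1)
            # lightweight offset network: one (dx, dy) offset per location
            self.offset_net = nn.Conv2d(dim, 2, 3, padding=1)

        def forward(self, x):
            B, C, H, W = x.shape
            q = self.proj_q(x)
            # reference points on a regular grid, normalized to [-1, 1]
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(-1, 1, H, device=x.device),
                torch.linspace(-1, 1, W, device=x.device))
            ref = torch.stack((ref_x, ref_y), dim=-1)      # H x W x 2, (x, y) order
            # small query-conditioned offsets shift the reference points
            off = self.offset_net(q).permute(0, 2, 3, 1).tanh()
            off = off * self.offset_range_factor / max(H, W)
            grid = (ref.unsqueeze(0) + off).clamp(-1, 1)   # B x H x W x 2
            # bilinearly sample keys/values at the deformed locations
            k = F.grid_sample(self.proj_k(x), grid, align_corners=False)
            v = F.grid_sample(self.proj_v(x), grid, align_corners=False)
            q, k, v = (t.flatten(2).transpose(1, 2) for t in (q, k, v))
            attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
            return (attn @ v).transpose(1, 2).reshape(B, C, H, W)

In the paper's design, offsets are additionally shared across groups of queries and the sampling grid can be strided; both are omitted here for brevity.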

Dependencies

  • NVIDIA GPU + CUDA 11.1
  • Python 3.8 (Anaconda is recommended)
  • PyTorch == 1.8.0
  • timm
  • einops
  • yacs
  • termcolor

TODO

  • Classification pretrained models.
  • Object Detection codebase & models.
  • Semantic Segmentation codebase & models.
  • CUDA operators to accelerate sampling operations.

Acknowledgement

This code is developed on top of Swin Transformer; we thank the authors for their efficient and neat codebase.

Citation

If you find our work useful in your research, please consider citing:

@misc{xia2022vision,
      title={Vision Transformer with Deformable Attention}, 
      author={Zhuofan Xia and Xuran Pan and Shiji Song and Li Erran Li and Gao Huang},
      year={2022},
      eprint={2201.00520},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Contact

[email protected]

Comments
  • Why set the reference point coordinates like this

        def _get_ref_points(self, H_key, W_key, B, dtype, device):
    
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
                torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device)
            )
            ref = torch.stack((ref_y, ref_x), -1)
            ref[..., 1].div_(W_key).mul_(2).sub_(1)
            ref[..., 0].div_(H_key).mul_(2).sub_(1)
            ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2
    
            return ref
    

    I don't understand this line: ref[..., 1].div_(W_key).mul_(2).sub_(1). Specifically, why use .mul_(2).sub_(1)?

    opened by simplify23 4
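
    A likely explanation, assuming these reference points are later passed to F.grid_sample, which expects coordinates normalized to [-1, 1]: linspace(0.5, W_key - 0.5, W_key) produces pixel-center coordinates, dividing by W_key maps them into (0, 1), and .mul_(2).sub_(1) rescales that interval to (-1, 1). A quick check with small example numbers:

        import torch

        W_key = 4
        x = torch.linspace(0.5, W_key - 0.5, W_key)  # pixel centers: 0.5, 1.5, 2.5, 3.5
        x = x / W_key       # -> 0.125, 0.375, 0.625, 0.875, i.e. inside (0, 1)
        x = x * 2 - 1       # -> -0.75, -0.25, 0.25, 0.75, i.e. inside (-1, 1)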
  • Misaligned classification results on ImageNet

    Thanks for the contribution. I trained the 224 × 224 ImageNet classification model, but there is an accuracy gap between my results and yours. I hope the pretrained model and the related settings can be released. Thanks.

    opened by Mollylulu 3
  • Changing model input size from 384 -> 1024

    I'd like to know which layers to change if I wanted to input an image of size 1024 instead of 384. I'd also like to know whether there are any additional concerns about using this model with a much larger input. Thanks.

    opened by tdchua 2
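
    A hedged note, not an authors' reply: since DAT builds on Swin-style window attention, each stage's feature map must stay divisible by the window size, and any position-bias tables sized for 384 inputs would need re-interpolation. A quick sanity check of the stage resolutions, assuming the usual 4/8/16/32x downsampling:

        img = 1024
        for stage, down in enumerate([4, 8, 16, 32]):
            print(stage, img // down)   # 256, 128, 64, 32: check divisibility by window_size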
  • The computational cost of deformable attention

    Hi, thanks for your excellent work. I notice that the number of sampled keys/values is the same as the number of queries. Therefore the computational cost of deformable attention is the same as that of global attention, is that right? So I'm curious why you don't use global self-attention in the last two stages.

    opened by linjing7 2
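
    For rough intuition, a back-of-the-envelope sketch, assuming the stride-r downsampled sampling grid described in the paper rather than one key per query:

        H = W = 14; C = 384; r = 2
        full = (H * W) ** 2 * C                     # global attention: (HW)^2 * C
        deform = (H * W) * ((H * W) // r ** 2) * C  # deformable: HW * (HW / r^2) * C
        print(full / deform)                        # -> 4.0, i.e. r^2 times cheaper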
  • Dimension error during training

    When using DAT I get the following error. The parameters were set according to the config and are consistent with it; I tried every option and all of them raise this error. Please help take a look:

        File "/root/BasicSR-master/basicsr/archs/discriminator_arch.py", line 202, in forward
            x_total = einops.rearrange(x, 'b c (r1 h1) (r2 w1) -> b (r1 r2) (h1 w1) c', h1=self.window_size[0], w1=self.window_size[1]) # B x Nr x Ws x C
        File "/root/miniconda3/lib/python3.8/site-packages/einops/einops.py", line 487, in rearrange
            return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)
        File "/root/miniconda3/lib/python3.8/site-packages/einops/einops.py", line 418, in reduce
            raise EinopsError(message + '\n {}'.format(e))
        einops.EinopsError: Error while processing rearrange-reduction pattern "b c (r1 h1) (r2 w1) -> b (r1 r2) (h1 w1) c".
            Input tensor shape: torch.Size([128, 96, 32, 32]). Additional info: {'h1': 7, 'w1': 7}.
            Shape mismatch, can't divide axis of length 32 in chunks of 7

    opened by JusticeLin 1
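
    The error message itself points at the cause: a 32 × 32 feature map cannot be tiled by 7 × 7 windows, so the input resolution must keep every stage's feature map divisible by the window size. A minimal reproduction with a compatible size:

        import einops
        import torch

        x = torch.randn(128, 96, 28, 28)   # 28 is divisible by the window size 7
        out = einops.rearrange(x, 'b c (r1 h1) (r2 w1) -> b (r1 r2) (h1 w1) c', h1=7, w1=7)
        print(out.shape)                   # torch.Size([128, 16, 49, 96])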
  • some questions about the reference points and offset network

    Really nice work! I have some questions about the code. Looking at your implementation of conv_offset, I see that you use a stride of 1, so the reference points actually cover the whole map. But the paper says there is a stride of r. If no stride larger than 1 is used, the complexity is the same as standard MHSA, or even larger. I think there may be something wrong here.

    opened by LixDemon 1
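
    For reference, a sketch (not the authors' code; names are illustrative) of an offset network whose stride r shrinks the reference grid, so that only (H / r) × (W / r) keys/values are sampled:

        import torch.nn as nn

        def make_offset_net(dim, r):
            # the stride-r depthwise conv reduces the offset/reference grid
            # to (H // r) x (W // r), i.e. n_sample = H * W / r**2
            return nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=r, stride=r, groups=dim),
                nn.GELU(),
                nn.Conv2d(dim, 2, kernel_size=1),
            )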
  • What does -1 mean in the strides and groups variables?

    Hi, thanks for your nice work. I'm very confused by the negative values (-1) in strides and groups, for example strides=[-1, -1, 1, 1] or groups=[-1, -1, 3, 6]. This also triggers errors when creating weight tensors, because negative dimensions are not allowed:

    RuntimeError: Trying to create tensor with negative dimension -96: [-96, 1, 9, 9]

    It would be appreciated if you could explain what -1 means here.

    Thanks in advance!

    opened by Hua-YS 1
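
    A plausible reading, judging from the default configs (an assumption, not an authors' statement): -1 marks stages that use plain (shift-)window attention rather than deformable attention, so the deformable module, and hence those weight tensors, should never be constructed for them. A toy illustration:

        def pick_attention(stage_idx, groups):
            # -1 entries are placeholders for stages without deformable attention
            return 'window' if groups[stage_idx] < 0 else 'deformable'

        groups = [-1, -1, 3, 6]
        print([pick_attention(i, groups) for i in range(4)])
        # -> ['window', 'window', 'deformable', 'deformable']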
  •  Error: Trying to create tensor with negative dimension -96: [-96, 1, 9, 9]

    @Vladimir2506 @Panxuran @LeapLabTHU

    I tried to use your basic DAT module and I am getting the error below:

    Trying to create tensor with negative dimension -96: [-96, 1, 9, 9].

    It is because of

    https://github.com/LeapLabTHU/DAT/blob/1029c76003b346ddcc80de6293ae9c7e2b6c3565/models/dat_blocks.py#L163, which in turn traces back to https://github.com/LeapLabTHU/DAT/blob/1029c76003b346ddcc80de6293ae9c7e2b6c3565/models/dat.py#L98

    Kindly help

    opened by ChidanandKumarKS 1
  • Different depthwise convolution kernel sizes?

    opened by xskxzr 1
  • RuntimeError: Trying to create tensor with negative dimension -96: [-96, 1, 9, 9]

    Hello, I tested your model directly in the dat.py file with the following code, but it raised an error:

        model = DAT()
        x = torch.randn((1, 3, 48, 48)).cuda(0)
        y = model(x)
        print(y.shape)

    The error is: RuntimeError: Trying to create tensor with negative dimension -96: [-96, 1, 9, 9]

    What is causing this? Thanks!

    opened by nullxjx 1
  • displacement problem

    Hello, why is the displacement calculated as below?

        displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2)
                        - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)

    Why not the following?

        displacement = (pos.reshape(B * self.n_groups, H * W, 2).unsqueeze(2)
                        - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1))

    opened by ZhuangLii 1
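
    A likely explanation (our reading, not an authors' reply): q_grid holds the query reference locations and pos the deformed sampling positions, both in normalized [-1, 1] coordinates, so the displacement has to run from queries to sampled keys, and the raw difference spans [-2, 2]; .mul(0.5) rescales it back into [-1, 1], e.g. for interpolating a relative-position bias. A one-line illustration:

        q, p = -1.0, 1.0       # extreme query and key positions in [-1, 1]
        print((q - p) * 0.5)   # -1.0 rather than -2.0: back inside [-1, 1]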
  • Unused parameters in class TransformerStage: "ns_per_pt" and "sr_ratio"

    Hello, I found the unused parameters "ns_per_pt" and "sr_ratio" and was wondering what they are used for. Thank you very much!

        class TransformerStage(nn.Module):
            def __init__(self, fmap_size, window_size, ns_per_pt, dim_in, dim_embed,
                         depths, stage_spec, n_groups, use_pe, sr_ratio, heads, stride,
                         offset_range_factor, stage_idx, dwc_pe, no_off, fixed_pe,
                         attn_drop, proj_drop, expansion, drop, drop_path_rate, use_dwc_mlp):

    opened by fppccc 0
  • Controlling the number of keys per query

    In the appendix comparing DAT with Deformable DETR, you mentioned changing the number of keys in Stage 3 and Stage 4. I was wondering where in the code that can be changed. Thank you.

    opened by tdchua 0
  • evaluate.sh: line 6: path-to-imagenet: No such file or directory

    When running bash evaluate.sh 1 configs/dat_tiny.yaml ./dat_tiny_in1k_224.pth, I get the following error:

    evaluate.sh: line 6: path-to-imagenet: No such file or directory
    
    opened by hygxy 0
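
    A hedged note: the message suggests that the literal placeholder string path-to-imagenet inside evaluate.sh is being used as a path; replacing it with the actual ImageNet directory should resolve this.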
  • Facing a negative dimension issue when running on CIFAR10

    Hi, I am Lukas Wang, a master's student at Columbia. I am planning to review cutting-edge ViT-based models on medium-size datasets and found your work really interesting! I was trying to run the code on the CIFAR10 dataset for testing, but the following error came up: RuntimeError: Trying to create tensor with negative dimension -96: [-96, 1, 9, 9]

    I have noticed that the groups variable is set to groups=[-1, -1, 3, 6] by default in the DAT model, while the construction of DAttentionBaseline in dat_blocks.py then computes a negative value for the first two stages. Could you please check out this issue? I'd really appreciate your help!

    opened by lukaswangbk 0
  • Deformable Attention Journal Paper not referenced

    Dear authors, way back in May 2021 I had already published a journal article on Deformable Attention: https://pubmed.ncbi.nlm.nih.gov/34022421/

    It was published even earlier, in August 2020, on medRxiv: https://www.medrxiv.org/content/10.1101/2020.08.25.20181834v1

    I would have expected you to at least cite my paper in yours.

    I am surprised that you did not do a thorough search of prior art and report all prior work in this space of Deformable Attention (and that the CVPR reviewers also did not thoroughly check whether a proper prior-art search was done).

    Could you please cite my paper in any further publications on Deformable Attention?

    Best regards, Kumar

    opened by kumartr 0
  • About Low Accuracy

    Hi, thanks for your excellent work. But when I use your DAT model and pretrained weights for image segmentation tasks, the results are not ideal. What I do is take the features from each layer and then recover the image size through simple deconvolution upsampling and skip-connection operations. The code is as follows: (code attached as an image)

    If possible, please tell me where the error is. I hope you can publish the segmentation model as soon as possible, and I look forward to your reply. Thank you.

    opened by z1zzzz1 0
Related projects

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch

Transformer in Transformer Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image c

Phil Wang 272 Dec 23, 2022
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Phil Wang 12.6k Jan 9, 2023
DPT: Deformable Patch-based Transformer for Visual Recognition (ACM MM2021)

DPT This repo is the official implementation of DPT: Deformable Patch-based Transformer for Visual Recognition (ACM MM2021). We provide code and model

CASIA-IVA-Lab 111 Dec 21, 2022
Implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT : Cross-Attention Multi-Scale Vision Transformer for Image Classification This is an unofficial PyTorch implementation of CrossViT: Cross-Att

Rishikesh (ऋषिकेश) 103 Nov 25, 2022
Official implement of "CAT: Cross Attention in Vision Transformer".

CAT: Cross Attention in Vision Transformer This is official implement of "CAT: Cross Attention in Vision Transformer". Abstract Since Transformer has

null 100 Dec 15, 2022
The code for our paper CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention.

CrossFormer This repository is the code for our paper CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention. Introduction Existin

cheerss 238 Jan 6, 2023
Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. ArXiv If

International Business Machines 168 Dec 29, 2022
The official implementation of ELSA: Enhanced Local Self-Attention for Vision Transformer

ELSA: Enhanced Local Self-Attention for Vision Transformer By Jingkai Zhou, Pich

DamoCV 87 Dec 19, 2022
Attention Probe: Vision Transformer Distillation in the Wild

Attention Probe: Vision Transformer Distillation in the Wild Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang In ICASSP 2022 This code is

Wang jiahao 3 Oct 31, 2022
Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

Attention Probe: Vision Transformer Distillation in the Wild Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang In ICASSP 2022 This code is

IIGROUP 6 Sep 21, 2022
Multi-Scale Vision Longformer: A New Vision Transformer for High-Resolution Image Encoding

Vision Longformer This project provides the source code for the vision longformer paper. Multi-Scale Vision Longformer: A New Vision Transformer for H

Microsoft 209 Dec 30, 2022
This is the code for Deformable Neural Radiance Fields, a.k.a. Nerfies.

Deformable Neural Radiance Fields This is the code for Deformable Neural Radiance Fields, a.k.a. Nerfies. Project Page Paper Video This codebase conta

Google 1k Jan 9, 2023
The pytorch implementation of DG-Font: Deformable Generative Networks for Unsupervised Font Generation

DG-Font: Deformable Generative Networks for Unsupervised Font Generation The source code for 'DG-Font: Deformable Generative Networks for Unsupervised

null 130 Dec 5, 2022
[CVPRW 2021] Code for Region-Adaptive Deformable Network for Image Quality Assessment

RADN [CVPRW 2021] Code for Region-Adaptive Deformable Network for Image Quality Assessment [Paper on arXiv] Overview Update [2021/5/7] add codes for W

IIGROUP 53 Dec 28, 2022
Deformable DETR is an efficient and fast-converging end-to-end object detector.

Deformable DETR: Deformable Transformers for End-to-End Object Detection.

null 2k Jan 5, 2023
Official implementation of NPMs: Neural Parametric Models for 3D Deformable Shapes - ICCV 2021

NPMs: Neural Parametric Models Project Page | Paper | ArXiv | Video NPMs: Neural Parametric Models for 3D Deformable Shapes Pablo Palafox, Aljaz Bozic

PabloPalafox 109 Nov 22, 2022
PyTorch implementation of Deformable Convolution

Deformable Convolutional Networks in PyTorch This repo is an implementation of Deformable Convolution. Ported from author's MXNet implementation. Buil

null 411 Dec 16, 2022
Code for ICCV 2021 paper: ARAPReg: An As-Rigid-As Possible Regularization Loss for Learning Deformable Shape Generators..

ARAPReg Code for ICCV 2021 paper: ARAPReg: An As-Rigid-As Possible Regularization Loss for Learning Deformable Shape Generators.. Installation The cod

Bo Sun 132 Nov 28, 2022