A general-purpose Transformer-based vision backbone

Overview

Swin Transformer

By Ze Liu*, Yutong Lin*, Yue Cao*, Han Hu*, Yixuan Wei, Zheng Zhang, Stephen Lin and Baining Guo.

This repo is the official implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". The code will be coming soon.

Introduction

Swin Transformer is initially described in the arXiv paper and capably serves as a general-purpose backbone for computer vision. Challenges in adapting Transformers from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared with words in text. To address these differences, we propose a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while still allowing cross-window connections. This hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size. These qualities make Swin Transformer compatible with a broad range of vision tasks, including image classification (86.4% top-1 accuracy on ImageNet-1K) and dense prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation (53.5 mIoU on ADE20K val).
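
For intuition, here is a minimal sketch (illustrative shapes and window size, not the released model code) of the two operations described above: partitioning a feature map into non-overlapping windows for local self-attention, and cyclically shifting the map between consecutive blocks so that windows also connect across boundaries.

    import torch

    def window_partition(x, window_size):
        """Split a (B, H, W, C) feature map into (num_windows*B, window_size, window_size, C) windows."""
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

    x = torch.randn(1, 8, 8, 96)                            # toy feature map
    windows = window_partition(x, 4)                        # self-attention runs within each 4x4 window
    shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))   # cyclic shift used by the next block
    shifted_windows = window_partition(shifted, 4)          # shifted windows give cross-window connections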

Citing Swin Transformer

@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.

Comments
  • Hi, I am a little bit confused about cyclic shift. Can you help me understand?

    Can you explain how the cyclic shift changes the feature map, and which token positions are masked when the attention is computed? The figure in your paper is too abstract for me. In your code, you use torch.roll() to implement the cyclic shift, and then from Line 209 to Line 227 you calculate the mask. How does the mask help to compute the attention?
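
    For intuition, a toy illustration (not the repository's code; the values are illustrative) of what torch.roll does to a 4x4 map of token indices with a shift of 1 in each spatial dimension:

    import torch

    x = torch.arange(16).view(1, 4, 4, 1)                  # label each spatial position 0..15
    shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))  # cyclic shift up and to the left
    print(x[0, :, :, 0])
    # tensor([[ 0,  1,  2,  3],
    #         [ 4,  5,  6,  7],
    #         [ 8,  9, 10, 11],
    #         [12, 13, 14, 15]])
    print(shifted[0, :, :, 0])
    # tensor([[ 5,  6,  7,  4],
    #         [ 9, 10, 11,  8],
    #         [13, 14, 15, 12],
    #         [ 1,  2,  3,  0]])
    # The last row/column of the shifted map now holds tokens wrapped around from the
    # opposite border; they are not real spatial neighbours of the other tokens in their
    # window, so the attention mask assigns -100 to those pairs.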

    opened by meiguoofa 11
  • Question about the mask in window attention

    Nice work! I have been reading your code recently, but I cannot fully understand the implementation of the mask in shifted window attention.

    I drew a simple picture (not included here); the red cells mean the mask, and I chose window_size = 2 and shift_size = 1.

    I think the mask should look like my drawing, but your code generates the mask below:

    import torch
    import torch.nn as nn
    
    
    def window_partition(x, window_size):
        """
        Args:
            x: (B, H, W, C)
            window_size (int): window size
    
        Returns:
            windows: (num_windows*B, window_size, window_size, C)
        """
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
        return windows
    
    
    window_size = 2
    shift_size = 1
    H, W = 4, 4
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    
    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    attn_mask = attn_mask.unsqueeze(1).unsqueeze(0)
    
    """
    tensor([[[[[   0.,    0.,    0.,    0.],
               [   0.,    0.,    0.,    0.],
               [   0.,    0.,    0.,    0.],
               [   0.,    0.,    0.,    0.]]],
    
    
             [[[   0., -100.,    0., -100.],
               [-100.,    0., -100.,    0.],
               [   0., -100.,    0., -100.],
               [-100.,    0., -100.,    0.]]],
    
    
             [[[   0.,    0., -100., -100.],
               [   0.,    0., -100., -100.],
               [-100., -100.,    0.,    0.],
               [-100., -100.,    0.,    0.]]],
    
    
             [[[   0., -100., -100., -100.],
               [-100.,    0., -100., -100.],
               [-100., -100.,    0., -100.],
               [-100., -100., -100.,    0.]]]]])
    """
    

    I cannot understand it. Could you do me a favor and explain?
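
    For reference, a worked check of the printed output under the settings above (H = W = 4, window_size = 2, shift_size = 1), reusing the img_mask and mask_windows computed by the code:

    print(img_mask[0, :, :, 0])
    # tensor([[0., 0., 1., 2.],
    #         [0., 0., 1., 2.],
    #         [3., 3., 4., 5.],
    #         [6., 6., 7., 8.]])
    #
    # The four 2x2 windows flatten to:
    #   window 0: [0, 0, 0, 0] -> all labels equal      -> mask is all zeros
    #   window 1: [1, 2, 1, 2] -> alternating labels    -> checkerboard of -100
    #   window 2: [3, 3, 6, 6] -> row-wise labels       -> 2x2 blocks of -100
    #   window 3: [4, 5, 7, 8] -> four distinct labels  -> only the diagonal is 0
    #
    # mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) is zero exactly where two
    # tokens carry the same label, so attention is allowed only between tokens that
    # belonged to the same region before the cyclic shift; that reproduces the four
    # tensors printed above.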

    opened by MARD1NO 9
  • Eval acc is 0 when I set `--amp-opt-level` to `O0`

    Hi, I downloaded the models from get_start.md and want to evaluate on ImageNet-1K.

    • The eval acc is 0, as shown in the following log.
    [2021-04-13 18:15:43 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [0/98]	Time 5.697 (5.697)	Loss 9.3819 (9.3819)	Acc@1 0.000 (0.000)	Acc@5 0.586 (0.586)	Mem 2502MB
    [2021-04-13 18:15:47 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [10/98]	Time 0.285 (0.893)	Loss 9.3991 (9.4262)	Acc@1 0.000 (0.018)	Acc@5 0.391 (0.178)	Mem 2503MB
    [2021-04-13 18:15:50 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [20/98]	Time 0.554 (0.638)	Loss 9.4262 (9.4286)	Acc@1 0.195 (0.028)	Acc@5 0.391 (0.270)	Mem 2503MB
    [2021-04-13 18:15:53 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [30/98]	Time 0.472 (0.535)	Loss 9.3771 (9.4292)	Acc@1 0.195 (0.063)	Acc@5 0.391 (0.290)	Mem 2503MB
    [2021-04-13 18:15:57 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [40/98]	Time 0.507 (0.490)	Loss 9.4310 (9.4236)	Acc@1 0.195 (0.067)	Acc@5 0.586 (0.286)	Mem 2503MB
    [2021-04-13 18:16:00 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [50/98]	Time 0.299 (0.458)	Loss 9.4321 (9.4172)	Acc@1 0.000 (0.092)	Acc@5 0.195 (0.341)	Mem 2503MB
    [2021-04-13 18:16:03 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [60/98]	Time 0.197 (0.436)	Loss 9.4335 (9.4172)	Acc@1 0.195 (0.090)	Acc@5 0.195 (0.336)	Mem 2503MB
    [2021-04-13 18:16:07 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [70/98]	Time 0.235 (0.420)	Loss 9.4177 (9.4207)	Acc@1 0.000 (0.091)	Acc@5 0.391 (0.322)	Mem 2503MB
    [2021-04-13 18:16:10 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [80/98]	Time 0.240 (0.408)	Loss 9.3358 (9.4199)	Acc@1 0.586 (0.096)	Acc@5 0.586 (0.323)	Mem 2503MB
    [2021-04-13 18:16:13 swin_tiny_patch4_window7_224](main.py 266): INFO Test: [90/98]	Time 0.232 (0.396)	Loss 9.3683 (9.4161)	Acc@1 0.195 (0.097)	Acc@5 0.391 (0.324)	Mem 2503MB
    [2021-04-13 18:16:15 swin_tiny_patch4_window7_224](main.py 272): INFO  * Acc@1 0.096 Acc@5 0.318
    [2021-04-13 18:16:15 swin_tiny_patch4_window7_224](main.py 121): INFO Accuracy of the network on the 50000 test images: 0.1%
    
    • The shell command is as follows.
    python3.7 -m torch.distributed.launch \
        --nproc_per_node 4 \
        --master_port 12345 \
        main.py \
            --eval \
            --cfg="configs/swin_tiny_patch4_window7_224.yaml"  \
            --resume="./swin_tiny_patch4_window7_224.pth" \
            --data-path="/data/ILSVRC2012"
    
    • The only diff is shown in a screenshot (image not included here).

    • Could you please help me figure out why it does not work? Thanks!
    opened by littletomatodonkey 9
  • Datasets

    How can I get the dataset in the following layout?

    data
    └── ImageNet-Zip
        ├── train_map.txt
        ├── train.zip
        ├── val_map.txt
        └── val.zip

    We are currently using the standard folder dataset, which is slow: about 2 hours per epoch.
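
    For reference, a minimal packing sketch (an assumption about the zipped-dataset format, not an official script): it writes each split into a zip plus a map file listing a relative image path and a class index per line. Please check the repository's getting-started guide for the exact format its zipped-data loader expects; pack_split and the paths below are illustrative.

    import os
    import zipfile

    def pack_split(folder, zip_path, map_path):
        # folder is a standard ImageFolder layout: folder/<class_name>/<image>
        classes = sorted(d for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d)))
        class_to_idx = {c: i for i, c in enumerate(classes)}
        with zipfile.ZipFile(zip_path, "w") as zf, open(map_path, "w") as mf:
            for cls in classes:
                cls_dir = os.path.join(folder, cls)
                for name in sorted(os.listdir(cls_dir)):
                    rel = f"{cls}/{name}"
                    zf.write(os.path.join(cls_dir, name), arcname=rel)   # store image under its relative path
                    mf.write(f"{rel}\t{class_to_idx[cls]}\n")            # one "path<TAB>label" line per image

    pack_split("/data/ILSVRC2012/train", "train.zip", "train_map.txt")
    pack_split("/data/ILSVRC2012/val", "val.zip", "val_map.txt")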

    opened by vvhj 8
  • RuntimeError: expected a single top-level function

    File "/home/fumy/anaconda3/envs/CloserLookFewShot/lib/python3.6/site-packages/torch/jit/frontend.py", line 144, in get_jit_ast
        raise RuntimeError("expected a single top-level function")

    Has anybody else run into this issue?

    opened by fmy7834 6
  • About Training Time

    Hi, thanks a lot for your amazing work! I have tried to train Swin Transformer from scratch on ImageNet-1K, but it took me quite a long time, so I wonder: how long did it take you to train the base model on ImageNet-1K, and what kind of GPUs did you use?

    opened by wangjk666 5
  • APEX Gradient overflow

    When I train the Swin network with opt level O1, I get gradient overflow:

    Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 131072.0
    [2021-04-25 05:35:19 swin_base_patch4_window7_224](main_prune.py 310): INFO Train: [0/300][4050/5004] eta 0:17:06 lr 0.000500 time 1.1737 (1.0765) loss 3.2572 (3.3279) grad_norm 1.0323 (nan) mem 4814MB

    Is this normal?

    opened by vvhj 5
  • Training will not start after initialization

    Hi, I'm trying to reproduce the classification training results.

    I used this command to run training, on 4x A100 GPUs:

    CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env  --master_port 12345 \
        main.py --local_rank 0 \
        --data-path <path> \
        --cfg configs/swin_tiny_patch4_window7_224.yaml
    

    I got to these lines of log:

    INFO number of params: 28288354
    INFO number of GLOPS: 4.49
    All checkpoints founded in <...>: []
    INFO no checkpoint found in <...>, ignoring auto resume
    INFO Start training
    

    And it's stuck there for a really long time. I reproduced this exact behaviour multiple times, without ever getting any further. Has anyone seen this, or does anyone know how to fix it? Many thanks.

    opened by felix-do-wizardry 4
  • Inferior performance of EMA model

    I was trying to train Swin Transformer on an object detection task and used an EMA mechanism to stabilize the training process. But I found that the EMA model of Swin-T had inferior performance compared with the original model. I am wondering why this could happen, as the EMA performance is much better with other conv nets. Here is how I implement EMA:

    import torch

    def build_ema(config):
        # Build a copy of the model whose parameters are detached from autograd;
        # build_model is the user's own model constructor.
        model = build_model(config)
        for param in model.parameters():
            param.detach_()
        return model

    # EMA update, applied after each optimizer step.
    with torch.no_grad():
        for ema_v, model_v in zip(ema_model.state_dict().values(), model.module.state_dict().values()):
            ema_v.copy_(ema_v * 0.999 + (1 - 0.999) * model_v)
    
    opened by syorami 4
  • Can I reproduce the performance reported in the paper with the default configuration?

    Hello! I have tried training with the default swin_tiny_224 config, without any pre-trained model. At the 65th epoch, the top-1 accuracy reached 71.3. I wonder whether I can reach a performance like the 81.5 that the paper reports? After all, it really takes a long time to train.

    opened by starmemda 4
  • Accuracy curve for each epoch

    Thanks for providing this excellent work. I am trying to reproduce the results of the Swin-T model. Could you please provide an accuracy curve w.r.t. epochs, which could serve as a reference for me to validate the correctness of my code and experiments?

    Thanks a lot!

    opened by WangFeng18 4
  • swin transformer does not support dynamic input shape after tracing model

    Hello there, I am trying to convert the Swin model to ONNX and then TensorRT, but the problem I face is that it does not support dynamic input resolution after tracing the model using torch.jit.trace. It seems that this is because of the mask input in window attention. Do you have any idea how I can fix this problem?
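
    A minimal export sketch (an assumption, not a verified fix), where model is the loaded Swin classifier in eval mode: it declares dynamic axes in torch.onnx.export. Note that the window-attention mask is built for a fixed resolution at trace time, so the exported graph may still be tied to the dummy input's height and width unless the model is modified to recompute the mask.

    import torch

    dummy = torch.randn(1, 3, 224, 224)          # dummy input used for tracing
    torch.onnx.export(
        model, dummy, "swin.onnx",
        opset_version=13,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch", 2: "height", 3: "width"},
            "output": {0: "batch"},
        },
    )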

    opened by fatemebafghi 0
  • Focal-Unet --> a low-complexity, easy-to-use UNet implementation for small-dataset domains

    Dear researchers, please also consider checking our UNet implementation of the proposed focal modulation block for small-dataset problems, especially in medical domains. Our model shows state-of-the-art results compared with the base Swin-Unet and TransUnet models in terms of Dice score on the Synapse dataset.

    https://github.com/givkashi/Focal-Unet

    https://www.researchgate.net/publication/366423296_Focal-UNet_UNet-like_Focal_Modulation_for_Medical_Image_Segmentation

    Thanks.

    opened by mohammadrezanaderi4 0
  • Update README.md for description of FasterTransformer

    Dear authors of Swin Transformer:

    As we discussed offline, FasterTransformer now supports Swin V2, so it would be great if we could update the description of FasterTransformer in the README. Thank you.
    
    opened by Jackch-NV 0
  • CLS token

    Thank you for the great paper and code repo, super nice idea.

    You mention in the paper that you experimented with appending a CLS token and using it to perform classification. I was wondering how you treat this CLS token: does it attend to all patches, or just the patches that fall into its local area (in the Swin self-attention process)? I also cannot find where this is implemented in the code; a pointer would be helpful.

    Many thanks, Harry

    opened by harrygcoppock 1
  • Could you please provide the detailed config of SwinV2-Small used during SimMIM pretraining?

    Hello,

    I would like to resume the pretraining process of SwinV2-Small using my own custom data for a real-world application, but I haven't found the detailed pretraining configuration. Would you mind providing it?

    Any help will be highly appreciated!

    opened by assiduous006 0