CrossFormer

This repository is the code for our paper CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention.

Introduction

Existing vision transformers fail to build attention among objects/features of different scales (cross-scale attention), yet this ability is very important for visual tasks. CrossFormer is a versatile vision transformer that solves this problem. Its core designs are a Cross-scale Embedding Layer (CEL) and Long Short Distance Attention (L/SDA), which work together to enable cross-scale attention.

CEL blends each input embedding with features of multiple scales. L/SDA splits all embeddings into several groups, and self-attention is computed only within each group (in the paper's figure, embeddings with the same color of border belong to the same group).
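For illustration only, the two designs can be sketched roughly as follows in PyTorch; the module/function names, the channel split, and the padding scheme are assumptions for this sketch, not the repository's actual implementation:

import torch
import torch.nn as nn

class CrossScaleEmbeddingSketch(nn.Module):
    # CEL sketch: sample patches with several kernel sizes but a single stride,
    # then concatenate along channels so every embedding mixes multiple scales.
    def __init__(self, in_chans=3, embed_dim=96, kernel_sizes=(4, 8, 16, 32), stride=4):
        super().__init__()
        dims = [embed_dim // len(kernel_sizes)] * len(kernel_sizes)
        dims[0] += embed_dim - sum(dims)  # make the channel split add up to embed_dim
        self.projs = nn.ModuleList([
            nn.Conv2d(in_chans, d, kernel_size=k, stride=stride, padding=(k - stride) // 2)
            for k, d in zip(kernel_sizes, dims)])

    def forward(self, x):  # x: (B, C, H, W) -> (B, embed_dim, H / stride, W / stride)
        return torch.cat([proj(x) for proj in self.projs], dim=1)

def sda_groups(x, G):
    # SDA sketch: adjacent, non-overlapping G x G windows form the attention groups.
    B, H, W, C = x.shape
    x = x.view(B, H // G, G, W // G, G, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, G * G, C)

def lda_groups(x, interval):
    # LDA sketch: embeddings a fixed interval I apart belong to the same group,
    # so each group spans the whole feature map at a coarse stride.
    B, H, W, C = x.shape
    x = x.view(B, H // interval, interval, W // interval, interval, C)
    return x.permute(0, 2, 4, 1, 3, 5).reshape(-1, (H // interval) * (W // interval), C)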

Further, we also propose a dynamic position bias (DPB) module, which makes the effective yet otherwise inflexible relative position bias applicable to variable image sizes.
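The idea behind DPB can be sketched as a small MLP that maps each relative coordinate to a per-head bias; the hidden width and the exact layer layout below are assumptions, not the paper's precise module:

import torch
import torch.nn as nn

class DynamicPositionBiasSketch(nn.Module):
    # DPB sketch: instead of a fixed table of relative position biases (which only
    # fits one group size), an MLP maps each relative offset (dx, dy) to a
    # per-head bias, so variable group / image sizes can be handled.
    def __init__(self, num_heads, hidden_dim=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_heads))

    def forward(self, rel_coords):
        # rel_coords: (L, 2) integer offsets (dx, dy), each in [-(G-1), G-1]
        # returns: (L, num_heads) biases to be added to the attention logits
        return self.mlp(rel_coords.float())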

Experiments are done on four representative visual tasks, i.e., image classification, object detection, and instance/semantic segmentation. The results show that CrossFormer outperforms existing vision transformers on these tasks, especially on dense prediction tasks (i.e., object detection and instance/semantic segmentation). We think this is because image classification only pays attention to one object and large-scale features, while dense prediction tasks rely more on cross-scale attention.

Prerequisites

  1. Libraries (Python3.6-based)
pip3 install numpy scipy Pillow pyyaml torch==1.7.0 torchvision==0.8.1 timm==0.3.2
  2. Dataset: ImageNet (the expected directory layout is sketched after this list)

  3. Requirements for detection/instance segmentation and semantic segmentation are listed in detection/README.md and segmentation/README.md
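For the ImageNet dataset, an ImageFolder-style layout with a train and a validation directory (which the training commands below assume) would look roughly like this; the class folder and file names are placeholders only:

path_to_imagenet/
├── train/
│   ├── class_0/
│   │   ├── image_0.jpeg
│   │   └── ...
│   └── ...
└── validation/
    ├── class_0/
    │   └── ...
    └── ...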

Getting Started

Training

## There should be two directories under the path_to_imagenet: train and validation

## CrossFormer-T
python -u -m torch.distributed.launch --nproc_per_node 8 main.py --cfg configs/tiny_patch4_group7_224.yaml \
--batch-size 128 --data-path path_to_imagenet --output ./output

## CrossFormer-S
python -u -m torch.distributed.launch --nproc_per_node 8 main.py --cfg configs/small_patch4_group7_224.yaml \
--batch-size 128 --data-path path_to_imagenet --output ./output

## CrossFormer-B
python -u -m torch.distributed.launch --nproc_per_node 8 main.py --cfg configs/base_patch4_group7_224.yaml \
--batch-size 128 --data-path path_to_imagenet --output ./output

## CrossFormer-L
python -u -m torch.distributed.launch --nproc_per_node 8 main.py --cfg configs/large_patch4_group7_224.yaml \
--batch-size 128 --data-path path_to_imagenet --output ./output
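
## Training with fewer GPUs has not been verified here; assuming main.py supports the
## --accumulation-steps option (as the argparse options quoted in the comments section suggest),
## a single-GPU run could approximate the 8-GPU setting roughly like this (illustrative values):
python -u -m torch.distributed.launch --nproc_per_node 1 main.py --cfg configs/small_patch4_group7_224.yaml \
--batch-size 128 --accumulation-steps 8 --data-path path_to_imagenet --output ./output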

Testing

## Take CrossFormer-T as an example
python -u -m torch.distributed.launch --nproc_per_node 1 main.py --cfg configs/tiny_patch4_group7_224.yaml \
--batch-size 128 --data-path path_to_imagenet --eval --resume path_to_crossformer-t.pth
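
## Throughput measurement (assumes main.py supports the --throughput flag shown in the
## argparse options quoted in the comments section; unverified sketch)
python -u -m torch.distributed.launch --nproc_per_node 1 main.py --cfg configs/tiny_patch4_group7_224.yaml \
--batch-size 128 --data-path path_to_imagenet --throughput --resume path_to_crossformer-t.pth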

Training scripts for object detection: detection/README.md.

Training scripts for semantic segmentation: segmentation/README.md.

Results

Image Classification

Models trained on ImageNet-1K and evaluated on its validation set. The input image size is 224 x 224.

| Architectures | Params | FLOPs | Accuracy | Models |
| ------------- | ------ | ----- | -------- | ------ |
| ResNet-50 | 25.6M | 4.1G | 76.2% | - |
| RegNetY-8G | 39.0M | 8.0G | 81.7% | - |
| CrossFormer-T | 27.8M | 2.9G | 81.5% | Google Drive/BaiduCloud, key: nkju |
| CrossFormer-S | 30.7M | 4.9G | 82.5% | Google Drive/BaiduCloud, key: fgqj |
| CrossFormer-B | 52.0M | 9.2G | 83.4% | Google Drive/BaiduCloud, key: 7md9 |
| CrossFormer-L | 92.0M | 16.1G | 84.0% | TBD |

More results compared with other vision transformers can be seen in the paper.

Object Detection & Instance Segmentation

Models trained on COCO 2017. Backbones are initialized with weights pre-trained on ImageNet-1K.

| Backbone | Detection Head | Learning Schedule | Params | FLOPs | box AP | mask AP |
| -------- | -------------- | ----------------- | ------ | ----- | ------ | ------- |
| ResNet-101 | RetinaNet | 1x | 56.7M | 315.0G | 38.5 | - |
| CrossFormer-S | RetinaNet | 1x | 40.8M | 282.0G | 44.4 | - |
| CrossFormer-B | RetinaNet | 1x | 62.1M | 389.0G | 46.2 | - |
| ResNet-101 | Mask-RCNN | 1x | 63.2M | 336.0G | 40.4 | 36.4 |
| CrossFormer-S | Mask-RCNN | 1x | 50.2M | 301.0G | 45.4 | 41.4 |
| CrossFormer-B | Mask-RCNN | 1x | 71.5M | 407.9G | 47.2 | 42.7 |

More results and pretrained models for object detection: detection/README.md.

Semantic Segmentation

Models trained on ADE20K. Backbones are initialized with weights pre-trained on ImageNet-1K.

| Backbone | Segmentation Head | Iterations | Params | FLOPs | IOU | MS IOU |
| -------- | ----------------- | ---------- | ------ | ----- | --- | ------ |
| CrossFormer-S | FPN | 80K | 34.3M | 209.8G | 46.4 | - |
| CrossFormer-B | FPN | 80K | 55.6M | 320.1G | 48.0 | - |
| CrossFormer-L | FPN | 80K | 95.4M | 482.7G | 49.1 | - |
| ResNet-101 | UPerNet | 160K | 86.0M | 1029G | 44.9 | - |
| CrossFormer-S | UPerNet | 160K | 62.3M | 979.5G | 47.6 | 48.4 |
| CrossFormer-B | UPerNet | 160K | 83.6M | 1089.7G | 49.7 | 50.6 |
| CrossFormer-L | UPerNet | 160K | 125.5M | 1257.8G | 50.4 | 51.4 |

MS IOU means IOU with multi-scale testing.

More results and pretrained models for semantic segmentation: segmentation/README.md.

Citing Us

@article{crossformer2021,
  title     = {CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention},
  author    = {Wenxiao Wang and Lu Yao and Long Chen and Deng Cai and Xiaofei He and Wei Liu},
  journal   = {CoRR},
  volume    = {abs/2108.00154},
  year      = {2021},
}

Acknowledgement

Part of the code in this repository is adapted from Swin Transformer.

Comments
  • Some questions about your paper and code

    Hi, I'm very interested in your work on multi-scale attention in Transformers, but I have some questions:

    1. In Appendix 2 (DPB), why do the parameters i and j range from 0 to 2G-1 instead of 0 to G-1? Besides, the input of the DPB module is (1-G+i, 1-G+j); what is the reason for this setting? Why not just use i and j as inputs?

    2. When I debug your code, I added some parameters because I only have one 3090 with 24G of memory, like this:

    parser = argparse.ArgumentParser('CrossFormer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", default='/configs/small_patch4_group7_224.yaml', help='path to config file')
    parser.add_argument("--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+')
    # easy config modification
    parser.add_argument('--batch-size', type=int, default=32, help="batch size for single GPU")
    parser.add_argument('--data-set', type=str, default='flower', help='dataset to use')
    parser.add_argument('--data-path', type=str, help='path to dataset', default='/media/data2/huzhen/flower_data')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, full: cache all data, part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--resume', help='resume from checkpoint', default='')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='native', choices=['native', 'O0', 'O1', 'O2'], help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--output', default='./Flower_weights', type=str, metavar='PATH',
                        help='root of output folder, the full path is /<model_name>/ (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')
    parser.add_argument('--num_workers', type=int, default=8, help="")
    parser.add_argument('--mlp_ratio', type=int, default=4, help="")
    parser.add_argument('--warmup_epochs', type=int, default=20, help="#epoches for warm up")
    parser.add_argument("--local_rank", type=int, required=True, default=0, help='local rank for DistributedDataParallel')
    parser.add_argument('--device', default='cuda:2', help='device to use for training / testing')

    args, unparsed = parser.parse_known_args()
    

    but it reports an error: Exception occurred: SystemExit 2. The above is my parameter setting. Is there a problem? I sincerely hope I can receive your help!

    opened by Huzhen757 6
  • KeyError: "EncoderDecoder: 'CrossFormer_L is not in the backbone registry'"

    When I run the program, there is an error:

    KeyError: "EncoderDecoder: 'CrossFormer_L is not in the backbone registry'"

    How can I register it?

    opened by yangyang117 5
  • Questions about the design of LSDA and CEL

    Hi, thank you very much for your work. I have a few small questions:

    1. It seems that LDA and SDA are used alternately. Would adjusting the ratio or order of S and L affect the results, e.g., SLS or SSLLL within one stage?
    2. It looks like G is always 7. In the pyramid structure, how would gradually changing G to 7, 5, 3, 1 (vanilla attention) affect performance?
    3. Placing CEL at the very beginning can be understood as extracting multi-scale information. But as the layers get deeper, the spatial meaning of the H*W dimension probably becomes weaker and weaker, so the multi-scale information extracted there can hardly be interpreted as multi-scale spatial information anymore. Would dropping the later [2, 4] CELs have a significant impact?
    4. Leaving aside code cleanliness, isn't kernel=32 too large? Replacing it with a stack of kernel=3 convolutions should not hurt, right?

    Thanks a lot.

    opened by kpmokpmo 4
  • Validation accuracy stays at 0.09% during training

    Dear authors,

    I'm interested in your paper and performed training from scratch on ImageNet. However, the validation accuracy stays at * Acc@1 0.090 throughout training.

    Do you have any idea why this happens? When I train Swin Transformer, it works.

    I use Pytorch 1.7.1 and 1.6.0, no mixed precision, 100 epochs.

    --amp-opt-level O0 --output ./output --opts TRAIN.EPOCHS 100

    Thanks, Eddie

    opened by edizhuang 3
  • Some questions about your last paper

    Dear author, I recently read your article 'Accelerate CNNs from Three Dimensions: A Comprehensive Pruning Framework', and I am particularly interested in it. Do you have an open-source plan? Thank you for your answer.

    opened by 764483 2
  • Something wrong with the pre-trained model crossformer-b.pth

    Hi, thanks for your great work. I am using your crossformer_base as the backbone network for downstream tracking tasks. But when I load your pre-trained model, a number of unexpected key(s) appear. My loading code is as follows:

    ckpt = torch.load(ckpt_path, map_location='cpu')
    missing_keys, unexpected_keys = backbone.body.load_state_dict(ckpt['model'], strict=False)

    The result is as follows: unexpected keys: ['norm.weight', 'norm.bias', 'head.weight', 'head.bias', 'layers.0.blocks.0.attn.biases', 'layers.0.blocks.0.attn.relative_position_index', 'layers.0.blocks.1.attn.biases', .....

    opened by hongsheng-Z 2
  • Does CrossFormer require a fixed input size?

    Hi there and thanks for the nice work. I'm currently trying to use CrossFormer_B as my backbone in detection/instance_segmentation. I've noticed that we need to define the img_size in the backbone configs. However, defining that can be limiting in the sense that we usually use cropping augmentations during training, or multi-scale inference at test time. Is there any way to keep these methods working with the current implementation?

    I'll copy the related part of my config file down here:

    model = dict(
        type='CascadeRCNN',
        pretrained=None,
        backbone=dict(
            type='CrossFormer',
            img_size=[3840, 1920],
            patch_size=[4, 8, 16, 32],
            in_chans=3,
            num_classes=7,
            embed_dim=96,
            depths=[2, 2, 18, 2],
            num_heads=[3, 6, 12, 24],
            group_size=[7, 7, 7, 7],
            crs_interval=[8, 4, 2, 1],
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.0,
            drop_path_rate=0.3,
            patch_norm=True,
            use_checkpoint=False,
            merge_size=[[2, 4], [2, 4], [2, 4]]),

    ... ...

    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
        dict(
            type='Resize',
            img_scale=[(3840, 1080), (3840, 1560)],
            multiscale_mode='range',
            keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.0),
        dict(
            type='Normalize',
            mean=[123.675, 116.28, 103.53],
            std=[58.395, 57.12, 57.375],
            to_rgb=True),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'])
    ]

    ... ...

    Thanks,

    opened by amobiny 1
  • Semantic FPN CrossFormer-S: the weight file is not uploaded completely

    main()
      File "./test.py", line 127, in main
        checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
      File "/home/wangnan/anaconda3/envs/yolo-v5/lib/python3.6/site-packages/mmcv/runner/checkpoint.py", line 522, in load_checkpoint
        checkpoint = _load_checkpoint(filename, map_location, logger)
      File "/home/wangnan/anaconda3/envs/yolo-v5/lib/python3.6/site-packages/mmcv/runner/checkpoint.py", line 466, in _load_checkpoint
        return CheckpointLoader.load_checkpoint(filename, map_location, logger)
      File "/home/wangnan/anaconda3/envs/yolo-v5/lib/python3.6/site-packages/mmcv/runner/checkpoint.py", line 243, in load_checkpoint
        return checkpoint_loader(filename, map_location)
      File "/home/wangnan/anaconda3/envs/yolo-v5/lib/python3.6/site-packages/mmcv/runner/checkpoint.py", line 260, in load_from_local
        checkpoint = torch.load(filename, map_location=map_location)
      File "/home/wangnan/anaconda3/envs/yolo-v5/lib/python3.6/site-packages/torch/serialization.py", line 594, in load
        return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
      File "/home/wangnan/anaconda3/envs/yolo-v5/lib/python3.6/site-packages/torch/serialization.py", line 853, in _load
        result = unpickler.load()
      File "/home/wangnan/anaconda3/envs/yolo-v5/lib/python3.6/site-packages/torch/serialization.py", line 845, in persistent_load
        load_tensor(data_type, size, key, _maybe_decode_ascii(location))
      File "/home/wangnan/anaconda3/envs/yolo-v5/lib/python3.6/site-packages/torch/serialization.py", line 833, in load_tensor
        storage = zip_file.get_storage_from_record(name, size, dtype).storage()
    RuntimeError: [enforce fail at inline_container.cc:145] . PytorchStreamReader failed reading file data/2154620144: invalid header or archive is corrupted

    opened by 754467737 7
  • CrossFormer for small object detection

    Hello, thank you for your work. I used CrossFormer for small object detection on my own dataset, and the results were very poor. Is there any way to improve the accuracy of small object detection? Thanks very much!

    opened by 1411030449 0
  • Question about LDA and SDA for irregular feature maps

    Thank you very much for your careful reply about LDA and SDA, but I have another question about LDA and SDA for irregular feature maps.

    In your paper, LDA and SDA are used with regular input image sizes, like 224x224 or 384x384. So the group size defaults to 7 and I is set to (8, 4, 2, 1). For Stage-1, I = 8, because G x I needs to match the feature map width/height (56x56).

    However, for an irregular feature map size, for example 80 x 134, it seems that the group size and interval can no longer be designed as described in the paper. If the group size is 7, the feature map needs to be padded to 84x140 to apply the group size; the feature map is then reshaped to [W_nG, G, H_nG, G], and SDA can be executed normally. But for the following LDA, how should the interval I be set? Besides, the feature map is irregular, so the interval I differs between width and height. How can I reasonably set the interval I?

    Can you give me some advice about this question? Thanks!

    opened by Huzhen757 0