CoaT: Co-Scale Conv-Attentional Image Transformers

This repository contains the official code and pretrained models for CoaT: Co-Scale Conv-Attentional Image Transformers. It introduces (1) a co-scale mechanism to realize fine-to-coarse, coarse-to-fine and cross-scale attention modeling and (2) an efficient conv-attention module to realize relative position encoding in the factorized attention.
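
The two mechanisms can be sketched in NumPy. This is a minimal single-head sketch based on our reading of the paper, not the repository's implementation. Factorized attention computes (Q / sqrt(C)) (softmax(K)^T V) with the softmax taken over the N tokens, so only a C x C context matrix is formed instead of an N x N attention map. The relative position term is approximated by multiplying Q element-wise with a depthwise convolution of V laid out on its H x W token grid. The helper names, the single 3x3 kernel, and the random inputs are illustrative assumptions; the official module may, for example, use different kernel sizes for different heads.

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along one axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def factorized_attention(q, k, v):
    """Single-head factorized attention; q, k, v have shape (N, C)."""
    _, c = q.shape
    k_soft = softmax(k, axis=0)        # normalize each key channel over the N tokens
    context = k_soft.T @ v             # (C, C) summary, built in O(N * C^2)
    return (q / np.sqrt(c)) @ context  # (N, C); no N x N matrix is ever formed


def depthwise_conv2d(x, kernel):
    """Per-channel 2D cross-correlation with zero 'same' padding.

    x has shape (H, W, C); kernel has shape (kh, kw, C) with odd kh and kw.
    """
    kh, kw, _ = kernel.shape
    h, w, _ = x.shape
    xp = np.pad(x, ((kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)))
    out = np.zeros_like(x)
    for i in range(kh):
        for j in range(kw):
            out += xp[i:i + h, j:j + w, :] * kernel[i, j, :]
    return out


def conv_attention(q, k, v, height, width, kernel):
    """Factorized attention plus a convolutional relative-position term."""
    n, c = q.shape
    assert n == height * width, "tokens must form a height x width grid"
    conv_v = depthwise_conv2d(v.reshape(height, width, c), kernel).reshape(n, c)
    return factorized_attention(q, k, v) + q * conv_v  # element-wise Q * conv(V)


rng = np.random.default_rng(0)
H, W, C = 4, 5, 8
q, k, v = (rng.standard_normal((H * W, C)) for _ in range(3))
kernel = rng.standard_normal((3, 3, C))
print(conv_attention(q, k, v, H, W, kernel).shape)  # (20, 8)
```

Because softmax(K)^T V is computed first, the cost grows linearly with the number of tokens N rather than quadratically, which is the motivation for the factorization.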

[Figure: Model Accuracy]

For more details, please refer to CoaT: Co-Scale Conv-Attentional Image Transformers by Weijian Xu*, Yifan Xu*, Tyler Chang, and Zhuowen Tu.


04/23/2021: Pre-trained checkpoint for CoaT-Lite Mini is released.
04/22/2021: Code and pre-trained checkpoint for CoaT-Lite Tiny are released.


Environment Preparation

  1. Set up a new conda environment and activate it.

    # Create an environment with Python 3.8.
    conda create -n coat python=3.8
    conda activate coat
  2. Install required packages.

    # Install PyTorch 1.7.1 w/ CUDA 11.0.
    pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f
    # Install timm 0.3.2.
    pip install timm==0.3.2
    # Install einops.
    pip install einops
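
To confirm that the pinned versions actually landed in the active environment, a small standard-library check such as the following can be run inside `coat`. The script, its `EXPECTED` table, and the `check_versions` helper are our own illustration, not part of the repository; `None` marks a package whose version is not pinned above.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install step above; None means "any version".
# pip reports the CUDA local tag (e.g. "+cu110") as part of the version string.
EXPECTED = {
    "torch": "1.7.1+cu110",
    "torchvision": "0.8.2+cu110",
    "torchaudio": "0.7.2",
    "timm": "0.3.2",
    "einops": None,
}


def check_versions(expected):
    """Map each package name to 'ok', 'missing', or a mismatch message."""
    report = {}
    for name, wanted in expected.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        if wanted is None or found == wanted:
            report[name] = "ok"
        else:
            report[name] = f"found {found}, expected {wanted}"
    return report


if __name__ == "__main__":
    for name, status in check_versions(EXPECTED).items():
        print(f"{name:12s} {status}")
```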

Code and Dataset Preparation

  1. Clone the repo.

    git clone
    cd CoaT
  2. Download the ImageNet dataset (ILSVRC 2012) and extract it.

    # Create dataset folder.
    mkdir -p ./data/ImageNet
    # Download the dataset (not shown here) and copy the files (assume the download path is in $DATASET_PATH).
    cp $DATASET_PATH/ILSVRC2012_img_train.tar $DATASET_PATH/ILSVRC2012_img_val.tar $DATASET_PATH/ILSVRC2012_devkit_t12.tar.gz ./data/ImageNet
    # Extract the dataset.
    python -c "from torchvision.datasets import ImageNet; ImageNet('./data/ImageNet', split='train')"
    python -c "from torchvision.datasets import ImageNet; ImageNet('./data/ImageNet', split='val')"
    # After the extraction, you should observe `train` and `val` folders under ./data/ImageNet.
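
As a sanity check after extraction, the sketch below counts class folders and image files in each split. It assumes the one-folder-per-class layout that torchvision's `ImageNet` class produces; for ILSVRC 2012 you would expect 1000 class folders in both `train` and `val`, with 1,281,167 training and 50,000 validation images. The `summarize_split` helper is hypothetical.

```python
from pathlib import Path

# Extensions counted as images; ILSVRC 2012 files use the ".JPEG" suffix.
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def summarize_split(split_dir):
    """Return (number of class folders, number of image files) for one split."""
    split_dir = Path(split_dir)
    class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
    n_images = sum(
        1
        for d in class_dirs
        for f in d.iterdir()
        if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
    )
    return len(class_dirs), n_images


if __name__ == "__main__":
    root = Path("./data/ImageNet")
    for split in ("train", "val"):
        if (root / split).is_dir():
            print(split, summarize_split(root / split))
```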

Evaluate Pre-trained Checkpoint

We provide CoaT-Lite checkpoints pre-trained on the ImageNet dataset.

Name            Acc@1 (%)  Acc@5 (%)  #Params  SHA-256 (first 8 chars)  URL
CoaT-Lite Tiny  77.5       93.8       5.7M     e88e96b0                 model, log
CoaT-Lite Mini  79.1       94.5       11M      6b4a8ae5                 model, log
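
The #Params column counts trainable weights. To give a rough sense of where such counts come from, the sketch below tallies one generic pre-norm transformer block of width C and MLP ratio r: two LayerNorms, a joint qkv projection, an output projection, and a two-layer MLP, all with biases. This is an assumption-laden illustration; it ignores CoaT's convolutional position encodings, conv-attention kernels, patch embeddings, and classifier head, so it will not reproduce the totals in the table.

```python
def block_params(dim, mlp_ratio=4):
    """Parameters in one generic pre-norm transformer block (weights + biases)."""
    hidden = int(dim * mlp_ratio)
    norms = 2 * (2 * dim)                 # two LayerNorms, each with scale and shift
    qkv = dim * 3 * dim + 3 * dim         # joint query/key/value projection
    proj = dim * dim + dim                # attention output projection
    mlp = dim * hidden + hidden + hidden * dim + dim  # fc1 + fc2
    return norms + qkv + proj + mlp


def closed_form(dim, mlp_ratio=4):
    # Same count, simplified: (4 + 2r) C^2 + (9 + r) C
    return (4 + 2 * mlp_ratio) * dim ** 2 + (9 + mlp_ratio) * dim


print(block_params(256))  # 789760
print(closed_form(256))   # 789760
```

Because the count grows with C^2, halving a stage's width cuts its per-block parameters roughly fourfold, whereas removing blocks or lowering the MLP ratio reduces them roughly linearly.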

The following commands show how to download and evaluate a pre-trained checkpoint, using CoaT-Lite Tiny as an example.

# Download the pretrained checkpoint.
mkdir -p ./output/pretrained
wget -P ./output/pretrained
sha256sum ./output/pretrained/coat_lite_tiny_e88e96b0.pth  # Make sure it matches the SHA-256 hash (first 8 characters) in the table.

# Evaluate.
# Usage: bash ./scripts/ [model name] [output folder] [checkpoint path]
bash ./scripts/ coat_lite_tiny coat_lite_tiny_pretrained ./output/pretrained/coat_lite_tiny_e88e96b0.pth
# The final lines of the output should be similar to "Acc@1 77.504 Acc@5 93.814".
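
The `sha256sum` comparison can also be scripted, which is handy when checking several checkpoints. A minimal sketch using Python's `hashlib`; the helper names are ours, and the expected prefix is the one listed in the table for CoaT-Lite Tiny.

```python
import hashlib
from pathlib import Path


def sha256_prefix(path, length=8, chunk_size=1 << 20):
    """Hash a file in chunks and return the first `length` hex characters."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()[:length]


def verify_checkpoint(path, expected_prefix):
    """True if the file's SHA-256 starts with the prefix listed in the table."""
    return sha256_prefix(path, len(expected_prefix)) == expected_prefix.lower()


if __name__ == "__main__":
    ckpt = Path("./output/pretrained/coat_lite_tiny_e88e96b0.pth")
    if ckpt.exists():
        print("hash ok" if verify_checkpoint(ckpt, "e88e96b0") else "hash MISMATCH")
```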

Train

The following command trains a CoaT model, using CoaT-Lite Tiny on 8 GPUs as an example.

# Usage: bash ./scripts/ [model name] [output folder]
bash ./scripts/ coat_lite_tiny coat_lite_tiny
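
A note on batch size, under an assumption worth checking against the training code: if the batch size set by the script is per GPU (as in DeiT-style distributed launches), 256 images on each of 8 GPUs gives a global batch of 2048 per optimizer step. DeiT also scales its base learning rate linearly with the global batch (lr x global_batch / 512); whether CoaT's script does the same is not stated here. The helpers below only do this arithmetic, and the 5e-4 base rate is an illustrative value.

```python
def global_batch_size(per_gpu_batch, num_gpus, grad_accum_steps=1):
    """Images contributing to one optimizer step across all GPUs."""
    return per_gpu_batch * num_gpus * grad_accum_steps


def linear_scaled_lr(base_lr, global_batch, reference_batch=512):
    """Linear learning-rate scaling rule (DeiT-style convention, assumed here)."""
    return base_lr * global_batch / reference_batch


gb = global_batch_size(per_gpu_batch=256, num_gpus=8)
print(gb)                          # 2048
print(linear_scaled_lr(5e-4, gb))  # 0.002
```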

Evaluate Trained Checkpoint

The following command evaluates the checkpoint saved after training, again using CoaT-Lite Tiny as an example.

# Usage: bash ./scripts/ [model name] [output folder] [checkpoint path]
bash ./scripts/ coat_lite_tiny coat_lite_tiny_eval ./output/coat_lite_tiny/checkpoints/checkpoint0299.pth
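
If the console output of an evaluation run is saved to a file (for example by piping the command through `tee`), the final accuracies can be pulled out for comparison with the table. A minimal sketch, assuming the output contains lines shaped like "Acc@1 77.504 Acc@5 93.814"; the regex and the `last_accuracy` and `gap_to_table` helpers are hypothetical.

```python
import re

# Matches e.g. "Acc@1 77.504 Acc@5 93.814"; text before or after it is ignored.
ACC_PATTERN = re.compile(r"Acc@1\s+(\d+(?:\.\d+)?)\s+Acc@5\s+(\d+(?:\.\d+)?)")


def last_accuracy(log_text):
    """Return (top1, top5) as floats from the last match in the log, or None."""
    matches = ACC_PATTERN.findall(log_text)
    if not matches:
        return None
    top1, top5 = matches[-1]
    return float(top1), float(top5)


def gap_to_table(measured, reported):
    """Signed differences (measured - reported) for top-1 and top-5."""
    return tuple(round(m - r, 3) for m, r in zip(measured, reported))


sample = "Acc@1 77.504 Acc@5 93.814"
acc = last_accuracy(sample)
print(acc)  # (77.504, 93.814)
print(gap_to_table(acc, (77.5, 93.8)))
```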

Citation

Weijian Xu, Yifan Xu, Tyler Chang, and Zhuowen Tu. Co-Scale Conv-Attentional Image Transformers.


This repository is released under the Apache License 2.0. The license text can be found in the LICENSE file.


Thanks to DeiT and pytorch-image-models for a clear and data-efficient implementation of ViT. Thanks to lucidrains' implementation of Lambda Networks and CPVT.

  • Re-producing issue

    To check reproducibility, I trained the coat_lite_mini model (reported 79.1/94.5) and got 78.85/94.42 using this command:

    bash scripts/ coat_lite_mini coat_lite_mini

    with the default settings (batch size 256, 8 TITAN RTX GPUs).

    Is such a small difference (79.1 vs. 78.9) negligible?

    My environment:

    sys.platform            linux
    Python                  3.7.9 (default, Aug 31 2020, 12:42:55) [GCC 7.3.0]
    numpy                   1.19.2
    Compiler                GCC 7.5
    CUDA compiler           CUDA 10.1
    detectron2 arch flags   7.5
    DETECTRON2_ENV_MODULE
    PyTorch                 1.7.0
    PyTorch debug build     True
    GPU available           True
    GPU 0,1,2,3,4,5,6,7     TITAN RTX (arch=7.5)
    CUDA_HOME               /usr/local/cuda-10.1
    Pillow                  8.0.1
    torchvision             0.8.0
    torchvision arch flags  3.5, 5.0, 6.0, 7.0, 7.5
    fvcore                  0.1.2.post20201218
    cv2                     Not found

    PyTorch built with:

    • GCC 7.3
    • C++ Version: 201402
    • Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
    • Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
    • OpenMP 201511 (a.k.a. OpenMP 4.5)
    • NNPACK is enabled
    • CPU capability usage: AVX2
    • CUDA Runtime 10.2
    • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
    • CuDNN 7.6.5
    • Magma 2.5.2
    • Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,
    opened by youngwanLEE
  • About new detection result in the camera-ready paper in ICCV

    Hi, first of all, congratulations on your acceptance at ICCV 👍.

    I have read your updated paper (arXiv v2) and have some questions.

    In Table 3, which shows Mask R-CNN results under the MMDetection framework:

    (1) Is the FPN 1x result trained with multi-scale training or single-scale training?

    (2) Do you plan to release the new MMDetection implementation?

    Thanks in advance :)

    opened by youngwanLEE
  • About AMP and batch size

    Hi, I'm very impressed by your excellent work! Thanks for sharing your code.

    I have questions about the training protocol.

    In your paper,

    "We train all models with a global batch size of 2048 with the NVIDIA Automatic Mixed Precision(AMP) enabled."

    However, the training script sets a batch size of 256 instead of 2048.

    I have two questions:

    1. Can I reproduce the reported accuracy in this repo using this command (batch size 256 instead of 2048)?

    2. Does this repo use AMP?

    Thanks in advance :)

    opened by youngwanLEE
  • COAT for multilabel classification

    @yix081 @xwjabc Thanks for sharing the code base. I have a few questions about the problem I am working on: gender_age classification of a person, i.e., a multi-label recognition problem.

    1. My input image size varies from 80x56 to 256x128. For such inputs, should I change the patch size from 4 to 16? If so, which other parameters should I change?
    2. Since it is a multi-label classification problem, should I change the line `self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()`?
    3. Should I freeze the transformer layers and train only the last layer? Thanks in advance.
    opened by abhigoku10
  • Will CoaT Small be available?

    Dear Weijian, I recently read your CoaT paper; it's really excellent work! I would like to do further research based on CoaT Small. However, it is not mentioned in your paper or repo, so I wonder whether you have implemented CoaT Small and, if so, whether the model will be made available. Thanks in advance!

    opened by encounter1997
  • Some questions in the paper

    Hi, I would like to know whether EV and EV(hat) in the paper are equivalent or only approximately equal. Likewise, are the EV(hat)^l terms in the second half of Formula 7 and in Formula 8 equivalent or approximate? Thank you, I look forward to your answer.

    opened by liqiokkk
  • Do you have plans to release the pre-trained checkpoint for CoaT

    Hello, this is nice work. I see that you have released the pre-trained checkpoints for CoaT-Lite. Do you have plans to release the pre-trained checkpoints for CoaT? Thanks!

    opened by laojiangwei
  • Segmentation architecture

    @yix081 @xwjabc Thanks for sharing the code base. I have the following questions:

    1. Can this architecture be adapted to perform segmentation tasks, i.e., semantic segmentation? If so, how?
    2. Can this architecture be adapted to perform object detection?

    Please share your thoughts. Thanks in advance.

    opened by abhigoku10
  • Viz of attention maps

    @yix081 @xwjabc Thanks for your work; it has helped me a lot. I have a few questions:

    1. Can we visualize the attention maps, like Grad-CAM / CAM, to see what the model is learning or has learned? Do you have code for this, or can you suggest how to do it?
    2. CoaT-Lite has only serial blocks while CoaT has serial + parallel blocks, yet CoaT-Lite has more parameters than CoaT. Is there a specific reason for this?
    3. How can I reduce the number of parameters in CoaT-Lite/CoaT to under 3M? A drop in accuracy is acceptable. Thanks in advance.
    opened by abhigoku10
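
On the multi-label classification question in the issue list above, a common setup (not something the CoaT authors prescribe) keeps a single linear head with one output per label and replaces softmax cross-entropy with an independent sigmoid per label, which is what PyTorch's `nn.BCEWithLogitsLoss` computes. The NumPy sketch below shows that loss and a 0.5 decision threshold; the label layout, the threshold, and the helper names are illustrative assumptions.

```python
import numpy as np


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all labels, computed stably from raw logits.

    max(x, 0) - x * y + log(1 + exp(-|x|)) equals
    -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))] without overflow.
    """
    x = np.asarray(logits, dtype=float)
    y = np.asarray(targets, dtype=float)
    per_label = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return per_label.mean()


def predict_labels(logits, threshold=0.5):
    """Independent sigmoid per label; a label is on when its probability >= threshold."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
    return (probs >= threshold).astype(int)


# Hypothetical label layout: [is_male, age_child, age_adult, age_senior], two images.
logits = np.array([[2.0, -1.5, 0.3, -3.0],
                   [-0.7, -0.2, 1.2, -2.0]])
targets = np.array([[1, 0, 1, 0],
                    [0, 0, 1, 0]])
print(predict_labels(logits))
print(bce_with_logits(logits, targets))
```

In PyTorch the same idea is `nn.BCEWithLogitsLoss()` applied to the head's raw outputs with float targets. To train only the head, one common idiom is to set `requires_grad = False` on every parameter whose name does not start with `head.`, assuming the classifier is the `self.head` module quoted in the question.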
