
CSRA

This is the official code of the ICCV 2021 paper:
Residual Attention: A Simple but Effective Method for Multi-Label Recognition


Demo, training, and validation code have been released (including ViT on Wider-Attribute)!

This package was developed by Mr. Ke Zhu (http://www.lamda.nju.edu.cn/zhuk/), and we have just finished the implementation code of the ViT models. If you have any questions about the code, please feel free to contact Mr. Ke Zhu ([email protected]). The package is free for academic use; you run it at your own risk. For other purposes, please contact Prof. Jianxin Wu (mail to [email protected]).

Requirements

  • Python 3.7
  • PyTorch 1.6
  • torchvision 0.7.0
  • pycocotools 2.0
  • tqdm 4.49.0
  • Pillow 7.2.0

Dataset

We expect the VOC2007, COCO2014, and Wider-Attribute datasets to have the following structure:

Dataset/
|-- VOCdevkit/
|---- VOC2007/
|------ JPEGImages/
|------ Annotations/
|------ ImageSets/
......
|-- COCO2014/
|---- annotations/
|---- images/
|------ train2014/
|------ val2014/
......
|-- WIDER/
|---- Annotations/
|------ wider_attribute_test.json
|------ wider_attribute_trainval.json
|---- Image/
|------ train/
|------ val/
|------ test/
...

Then run the following commands to generate the JSON files (used by our implementation) for these datasets:

python utils/prepare/voc.py  --data_path  Dataset/VOCdevkit
python utils/prepare/coco.py --data_path  Dataset/COCO2014
python utils/prepare/wider.py --data_path Dataset/WIDER

which will automatically generate JSON files in ./data/voc07, ./data/coco, and ./data/wider.
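
As a quick sanity check, you can list the generated files and count their entries with a small snippet like the one below. This is our own illustrative code, not part of the repository; it only assumes the generated files under those folders are valid JSON and says nothing about their internal fields.

import json
from pathlib import Path

# Print every generated JSON file and how many top-level entries it contains.
for data_dir in ["data/voc07", "data/coco", "data/wider"]:
    for json_file in sorted(Path(data_dir).glob("*.json")):
        with open(json_file) as f:
            entries = json.load(f)
        print(f"{json_file}: {len(entries)} top-level entries")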

Demo

We provide prediction demos of our models. The demo images (picked from VOC2007) have already been placed in ./utils/demo_images/. You can simply run demo.py using our CSRA models pretrained on VOC2007:

CUDA_VISIBLE_DEVICES=0 python demo.py --model resnet101 --num_heads 1 --lam 0.1 --dataset voc07 --load_from OUR_VOC_PRETRAINED.pth --img_dir utils/demo_images

which will produce output like this:

utils/demo_images/000001.jpg prediction: dog,person,
utils/demo_images/000004.jpg prediction: car,
utils/demo_images/000002.jpg prediction: train,
...
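
If you want to reuse a pretrained CSRA model in your own script, the demo essentially boils down to the sketch below. This is our own illustration, not the exact contents of demo.py; the 448x448 resize, the ImageNet normalization, and the 0.5 sigmoid threshold are assumptions.

import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((448, 448)),  # resolution used by the VOC2007/COCO models
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

@torch.no_grad()
def predict_labels(model, img_path, class_names, threshold=0.5):
    # Return the class names whose sigmoid score exceeds the threshold.
    model.eval()
    x = preprocess(Image.open(img_path).convert("RGB")).unsqueeze(0)
    scores = torch.sigmoid(model(x)).squeeze(0)
    return [name for name, s in zip(class_names, scores) if s.item() > threshold]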

Validation

We provide pretrained models on Google Drive for validation. A ResNet-101 trained on ImageNet with CutMix augmentation can be downloaded here.

Dataset Backbone Head nums mAP(%) Resolution Download
VOC2007 ResNet-101 1 94.7 448x448 download
VOC2007 ResNet-cut 1 95.2 448x448 download
COCO ResNet-101 4 83.3 448x448 download
COCO ResNet-cut 6 85.6 448x448 download
Wider VIT_B16_224 1 89.0 224x224 download
Wider VIT_L16_224 1 90.2 224x224 download

For VOC2007, run the following validation example:

CUDA_VISIBLE_DEVICES=0 python val.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20  --load_from MODEL.pth

For COCO2014, run the following validation example:

CUDA_VISIBLE_DEVICES=0 python val.py --num_heads 4 --lam 0.5 --dataset coco --num_cls 80  --load_from MODEL.pth

For Wider-Attribute with ViT models, run the following:

CUDA_VISIBLE_DEVICES=0 python val.py --model vit_B16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14  --load_from ViT_B16_MODEL.pth
CUDA_VISIBLE_DEVICES=0 python val.py --model vit_L16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14  --load_from ViT_L16_MODEL.pth

To provide pretrained ViT models on the Wider-Attribute dataset, we retrained them recently, so their performance differs slightly (~0.1% mAP) from what is presented in our paper. The structure of the ViT models is the original ViT version (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, link), and the implementation code of the ViT models is derived from http://github.com/rwightman/pytorch-image-models/.

Training

VOC2007

You can run either of the two commands below:

CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20
CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --cutmix CutMix_ResNet101.pth

Note that the first command uses the official ResNet-101 backbone, while the second uses a ResNet-101 pretrained on ImageNet with CutMix augmentation (link), which is expected to give better performance.
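
If you only need those CutMix-pretrained weights as an initialization in your own code, a loading sketch such as the following may help. The checkpoint layout is our assumption, which is why the nested "model" lookup and strict=False are used.

import torch
import torchvision

backbone = torchvision.models.resnet101(pretrained=False)
state = torch.load("CutMix_ResNet101.pth", map_location="cpu")
state = state.get("model", state)  # some checkpoints nest the weights under a "model" key
missing, unexpected = backbone.load_state_dict(state, strict=False)  # tolerate classifier-head mismatches
print("missing keys:", missing)
print("unexpected keys:", unexpected)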

MS-COCO

Run ResNet-101 with 4 heads:

CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 4 --lam 0.5 --dataset coco --num_cls 80

Run ResNet-101 (pretrained with CutMix) with 6 heads:

CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 6 --lam 0.4 --dataset coco --num_cls 80 --cutmix CutMix_ResNet101.pth

Feel free to adjust hyper-parameters such as the number of attention heads (--num_heads) or the lambda (--lam). Still, the default values in the commands above are expected to give the best results.

Wider-Attribute

Run VIT_B16_224 with 1 head:

CUDA_VISIBLE_DEVICES=0 python main.py --model vit_B16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14

Run VIT_L16_224 with 1 head:

CUDA_VISIBLE_DEVICES=0,1 python main.py --model vit_L16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14

Note that the VIT_L16_224 model consumes more GPU memory, so we use 2 GPUs to train it.
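
For reference, the usual PyTorch pattern behind that two-GPU command looks like the sketch below (an illustration of the standard DataParallel idiom, not necessarily the exact code in main.py):

import torch

def to_multi_gpu(model: torch.nn.Module) -> torch.nn.Module:
    # Move the model to the default GPU, then replicate it so each batch is split
    # across all GPUs made visible by CUDA_VISIBLE_DEVICES.
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    return model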

Notice

To avoid confusion, please note that the 4 lines of code in Figure 1 of the paper are only used at test time (without training), which is our motivation. When our model is trained and tested end-to-end, multi-head attention (H=1, H=2, H=4, etc.) is used with different T values. Also, when H=1 and T=infinity, the implementation of the multi-head attention is exactly the same as Figure 1.
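
For illustration, here is a minimal sketch of both cases written in terms of the per-location class scores (our own paraphrase with illustrative names, not the repository's exact module; see the repository code for the real implementation):

import torch

def csra_logits_T_inf(score_maps, lam=0.1):
    # score_maps: (B, num_cls, H*W) class scores at every spatial location.
    # T = infinity: the attention term degenerates to spatial max pooling (Figure 1 case).
    base = score_maps.mean(dim=2)         # global average pooling -> base logits
    attention = score_maps.max(dim=2)[0]  # spatial max pooling -> residual attention term
    return base + lam * attention

def csra_logits(score_maps, lam=0.1, T=1.0):
    # Finite temperature T: a softmax over spatial locations replaces the hard max,
    # so each class softly attends to its most relevant positions.
    weights = torch.softmax(score_maps * T, dim=2)
    attention = (weights * score_maps).sum(dim=2)
    return score_maps.mean(dim=2) + lam * attention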

Acknowledgement

We thank Lin Sui (http://www.lamda.nju.edu.cn/suil/) for his initial contribution to this project.

Comments
  • How to load your pre-trained model and train on my dataset?

    Hi,

    Thank you so much for your great work. I'm working on a project with multi-label classification, so I wonder how I can apply your pretrained model for image feature extraction. What I need is to extract the features of an image. Could you please give me some hints?

    Best regards, Hui

    opened by YHDASHEN 5
  • Is the code consistent with the description in the paper?

    According to Formula 5 and Formula 6 in the paper, the class-specific residual attention (CSRA) feature f should be sent to the classifier to obtain the final logits, but in your code you use f directly as the final logits. What's the difference?

    opened by AlphaPlusTT 4
  • Combine with ASL

    Thanks for your good work. When I combine MHA with ASL (https://github.com/Alibaba-MIIL/ASL), I find that the results decrease. Did you try combining your work with ASL?

    opened by catherinezll95 2
  • Load Model Issue?

    When I try to load the "vit_L16_224_coco_head8_86.5.pth" model in val.py, I get the following error:

    Error(s) in loading state_dict for VIT_CSRA:
        Missing key(s) in state_dict: "classifier.multi_head.0.head.weight".
        Unexpected key(s) in state_dict: "head1.weight", "head1.bias", "head2.weight", "head2.bias", "head3.weight", "head3.bias", "head4.weight", "head4.bias", "head5.weight", "head5.bias", "head6.weight", "head6.bias", "head7.weight", "head7.bias", "head8.weight", "head8.bias", "head.weight", "head.bias".
        size mismatch for pos_embed: copying a param with shape torch.Size([1, 785, 1024]) from checkpoint, the shape in current model is torch.Size([1, 197, 1024]).
    File "C:\Users\osivaz61\Desktop\projects\python\retina\diseaseDetection\CSRA-master\val.py", line 79, in main
        model.load_state_dict(torch.load(args.load_from))
    File "C:\Users\osivaz61\Desktop\projects\python\retina\diseaseDetection\CSRA-master\val.py", line 97, in main()

    As far as I understand, the vit_csra.py file is not updated. Can you share the updated code?

    Thanks

    opened by osivaz61 1
  • some questions about val.py

    Thanks for sharing your code. I have some questions about your project. In val.py, the definition of the following was empty: parser.add_argument("--load_from", default="models_local/resnet101_voc07_head1_lam0.1_94.7.pth", type=str)

    And I cannot find the code that saves the model. How should I use val.py in your project, and could you explain the model save path clearly?

    opened by yangt1013 1
  • Transformer on WiderAttribute prediction

    @Kevinz-code thanks for providing the source code and for the great work. I have a few queries, mentioned below:

    1. The ViT model predicts the attributes of a person (e.g., male, trousers, ...), right? What accuracy are you getting on this?
    2. I looked into the ViT implementation and compared it with the "rwightman" implementation; you have used the same implementation, and the only change is in the classifier part, i.e., the MHA --> CSRA part, right?
    3. Is CSRA customized only for the ViT transformer, or can it be implemented with other transformers like CrossViT?
    4. The image sizes of my custom dataset range from 8056 to 124128; should I change the patch size from 16 to 8/4, since resizing small-resolution images creates a pixelation/blur effect?
    5. Did you keep the same training params for all 3 training models?

    Please do share your thoughts. Thanks in advance.

    opened by abhigoku10 1
  • How to use VIT-224 pretrain weight with 448 input size

    Thanks for your excellent work! I noticed you use VIT-L16-224 to train on COCO, with an input size of 448. May I ask how to use the VIT-224 pretrained weights with a 448 input size?

    opened by zizizihua 0
  • AssertionError: Input image size (448*448) doesn't match model (224*224)

    Hello, I'm training VIT_B and get the above error. It is known that the required input size in ViT is 224, and the dataset used is wider_attribute; how should I modify it?

    opened by sure7018 0
  • MobileNet implementation of CSRA

    I'm trying to implement CSRA using MobileNet as the backbone, but I'm running into some trouble. This is kind of related to #5. First of all, from the paper it was not clear to me whether CSRA is to be applied before, after, or instead of the classifier.

    ~~Now, I have a question: which version of MobileNet was CSRA implemented in? In my case, I'm trying to use MobileNetV3Large~~ It's stated in the paper that it's MobileNetV2.

    In my use case, I would like to use the MobileNetV3 classification head, except with a different number of target classes. Where is CSRA supposed to be placed?

    This is the structure of the MobileNetV3 classifier: (attached screenshot)

    Is CSRA supposed to replace the Avg Pool on the (7,7,960) tensor? To replace the 1x1 Conv after the (1,1,1280) tensor? Or to take place after the last 1x1 Conv?

    I think most of the confusion comes from Fig 1 and Fig 2 in the CSRA paper.

    • In Fig 1, the output of the backbone is run through the classifier, then through CSRA. It is stated that Fig 1 is a special case of CSRA, but it still remains confusing.

    • In Fig 2, f seems to act directly as the class scores, while the text preceding Eq 6 states "Finally, all these class-specific feature vectors are sent to the classifier to obtain the final logits". It is not clear in Fig 2 that the result of the CSRA module is sent to the classifier, and it adds to the confusion about where the CSRA module is supposed to be placed.

    opened by ghylander 4
  • Cross Validation

    Hi,

    I have two questions:

    First: Why didn't you use k-fold cross validation? Second: What is the reason for using a different learning rate for the classifier? Is it for faster convergence?

    I am trying to adapt CSRA to EfficientNetB3 on my multi-label dataset. Although I have tried various head counts and lambda values, I am getting worse results than the baseline model. What is your opinion? Is there something else to try?

    Also, there is class imbalance in my dataset. Is data augmentation needed to address the class imbalance? Is CSRA a method that is affected by data augmentation?

    Thanks

    opened by osivaz61 2
  • Details about baseline resnet-101 in paper

    Hi, thanks for your excellent work! But I'm confused about the details of the baseline-model settings in your paper.

    Take training ResNet-101 without CutMix on COCO2014 as an example:

    With the following training configuration as the baseline setting, I get 81.3 mAP after 7 epochs (30 in total, still in the training process...), which is much higher than in your paper (79.4 mAP): python main.py --num_heads 4 --lam 0 --dataset coco --num_cls 80 --checkpoint coco14/resnet101

    So, what are the correct settings to reproduce the baseline result in your paper? Thanks again.

    opened by zhifanlight 1
  • Some problems about vision transformer

    Hello, after reading the code together with your paper, I have the following questions (about vit_csra):

    In the code, the class token is not used in the input of the last CSRA module, so why is the class token set up in the code of "VIT_CSRA"? Has the last MLP head used for classification in the vision transformer been removed directly?

    opened by sure7018 1
  • Question about Attention Image or Heatmap Generation

    Hi Kevinz, thanks for your awesome work. I'd like to do a visual analysis to get a better understanding of the CSRA. Could you please give me some advice on how to visualize the attention score (or heatmap, attention image)? Thank you very much!

    opened by wzjahucm 1