Official implementation of paper "Query2Label: A Simple Transformer Way to Multi-Label Classification".

Overview

Introduction

This is the official implementation of the paper "Query2Label: A Simple Transformer Way to Multi-Label Classification".

Abstract

This paper presents a simple and effective approach to solving the multi-label classification problem. The proposed approach leverages Transformer decoders to query the existence of a class label. The use of Transformer is rooted in the need of extracting local discriminative features adaptively for different labels, which is a strongly desired property due to the existence of multiple objects in one image. The built-in cross-attention module in the Transformer decoder offers an effective way to use label embeddings as queries to probe and pool class-related features from a feature map computed by a vision backbone for subsequent binary classifications. Compared with prior works, the new framework is simple, using standard Transformers and vision backbones, and effective, consistently outperforming all previous works on five multi-label classification data sets, including MS-COCO, PASCAL VOC, NUS-WIDE, and Visual Genome. Particularly, we establish 91.3% mAP on MS-COCO. We hope its compact structure, simple implementation, and superior performance serve as a strong baseline for multi-label classification tasks and future studies.
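The core idea in the abstract (label embeddings act as queries that attend over backbone features, and each pooled feature feeds its own binary classifier) can be sketched in a few lines. This is a minimal illustration only, not the code in this repo: it uses a single attention head in pure Python and omits the learned projections, multi-head attention, feed-forward layers, and normalization of a standard Transformer decoder. The names `query_labels`, `class_weights`, and `class_biases` are hypothetical.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def query_labels(label_queries, features, class_weights, class_biases):
    """Single-head cross-attention pooling followed by per-label binary classifiers.

    label_queries: K vectors of dim d (one embedding per label; hypothetical values)
    features:      N spatial feature vectors of dim d from a vision backbone
    class_weights: K vectors of dim d; class_biases: K floats
    Returns K probabilities, one per label.
    """
    d = len(label_queries[0])
    probs = []
    for q, w, b in zip(label_queries, class_weights, class_biases):
        # Attention of this label's query over all spatial positions.
        attn = softmax([dot(q, f) / math.sqrt(d) for f in features])
        # Label-specific feature: attention-weighted sum of the spatial features.
        pooled = [sum(a * f[i] for a, f in zip(attn, features)) for i in range(d)]
        # Independent binary classification for this label.
        logit = dot(w, pooled) + b
        probs.append(1.0 / (1.0 + math.exp(-logit)))
    return probs

# Toy example: 3 spatial positions, 2 labels, feature dim 2.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
label_queries = [[4.0, 0.0], [0.0, 4.0]]
class_weights = [[1.0, 0.0], [0.0, 1.0]]
class_biases = [0.0, 0.0]
probs = query_labels(label_queries, features, class_weights, class_biases)
print(probs)
```

Each label gets its own attention map, so different labels can pool features from different image regions, which is the property the abstract highlights for images containing multiple objects.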

fig

Results on MS-COCO:

fig

Quick start

  1. (optional) Star this repo.

  2. Clone this repo:

git clone git@github.com:SlongLiu/query2labels.git
cd query2labels
  3. Install CUDA, PyTorch and torchvision.

Please make sure they are compatible. We tested our models in two environments; other configurations may also work:

cuda==11, torch==1.9.0, torchvision==0.10.0, python==3.7.3
or
cuda==10.2, torch==1.6.0, torchvision==0.7.0, python==3.7.3
  4. Install other needed packages.
pip install -r requirments.txt
  5. Data preparation.

Download MS-COCO 2014 and modify the paths in lib/dataset/cocodataset.py (lines 24 and 25).

  6. Download pretrained models.

You can download pretrained models from this link. See more details below.

  7. Run!
python q2l_infer.py -a modelname --config /path/to/json/file --resume /path/to/pkl/file [other args]
e.g.
python q2l_infer.py -a 'Q2L-R101-448' --config "pretrained/Q2L-R101-448/config_new.json" -b 16 --resume 'pretrained/Q2L-R101-448/checkpoint.pkl'
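When evaluating several pretrained models, it can help to assemble the inference command programmatically. The sketch below builds the same argument list as the example above. It assumes that every downloaded model is unpacked to `pretrained/<modelname>/` with `config_new.json` and `checkpoint.pkl` inside, as in the Q2L-R101-448 example; that layout and the helper name `build_infer_command` are assumptions, not something the repo guarantees.

```python
import shlex

def build_infer_command(model_name, batch_size=16, pretrained_dir="pretrained"):
    """Assemble the q2l_infer.py argument list for one downloaded model.

    Assumes (hypothetically) the layout <pretrained_dir>/<model_name>/
    containing config_new.json and checkpoint.pkl.
    """
    model_dir = f"{pretrained_dir}/{model_name}"
    return [
        "python", "q2l_infer.py",
        "-a", model_name,
        "--config", f"{model_dir}/config_new.json",
        "-b", str(batch_size),
        "--resume", f"{model_dir}/checkpoint.pkl",
    ]

cmd = build_infer_command("Q2L-R101-448")
# Print a shell-safe version of the command for inspection.
print(" ".join(shlex.quote(part) for part in cmd))
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)` from the repo root, which avoids shell quoting issues.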

Pretrained models

Modelname            mAP     link (Tsinghua-cloud)
Q2L-R101-448         84.9    this link
Q2L-R101-576         86.5    this link
Q2L-TResL-448        87.3    this link
Q2L-TResL_22k-448    89.2    this link
Q2L-SwinL-384        90.5    this link
Q2L-CvT_w24-384      91.3    this link
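The mAP column is the mean over classes of the per-class average precision. As a rough illustration of the metric, the sketch below computes a non-interpolated AP per class and averages it. It is not the evaluation code in this repo, which may differ in details such as tie handling or interpolation; the function names are hypothetical.

```python
def average_precision(scores, targets):
    """AP for one class: mean of precision@k taken at the rank k of each positive."""
    ranked = sorted(zip(scores, targets), key=lambda st: st[0], reverse=True)
    hits = 0
    precisions = []
    for rank, (_, target) in enumerate(ranked, start=1):
        if target == 1:
            hits += 1
            precisions.append(hits / rank)
    # A class with no positives contributes 0.0 in this sketch.
    return sum(precisions) / len(precisions) if precisions else 0.0

def mean_average_precision(score_rows, target_rows):
    """mAP in percent; rows are images, columns are classes."""
    num_classes = len(score_rows[0])
    aps = [
        average_precision([row[c] for row in score_rows],
                          [row[c] for row in target_rows])
        for c in range(num_classes)
    ]
    return 100.0 * sum(aps) / num_classes

# Toy example: 3 images, 2 classes.
scores = [[0.9, 0.1], [0.8, 0.7], [0.3, 0.2]]
targets = [[1, 0], [0, 1], [1, 0]]
print(round(mean_average_precision(scores, targets), 2))
```

Because AP depends only on the ranking of scores within each class, no decision threshold is needed to compute mAP.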

Training

Training scripts will be available later.

BibTex

@misc{liu2021query2label,
      title={Query2Label: A Simple Transformer Way to Multi-Label Classification}, 
      author={Shilong Liu and Lei Zhang and Xiao Yang and Hang Su and Jun Zhu},
      year={2021},
      eprint={2107.10834},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

We thank the authors of ASL, TResNet, detr, CvT, and Swin-Transformer for their great work and code. Thanks to @mrT23 for sharing training tricks and providing a useful script for training.

Comments
  • Apply to custom datasets

    Thank you very much for your work. I would like to apply it to my own dataset, but it is not friendly to beginners, and I don't know where to start.

    opened by Breeze-Zero 5
  • Question about the SLCutoutPIL.

    Hi, thanks for your impressive work. I have a question about SLCutoutPIL with the training parameters --cutout --n_holes 1 --cut_fact 0.5. According to your code train_data_transform_list.insert(1, SLCutoutPIL(n_holes=args.n_holes, length=args.length)), the cutout in your augmentation actually fills only a rectangular region with height=1 and width=1, which is a very small region. Is that right?

    opened by jasonseu 4
  • Pre-trained weights over Pascal VOC 2012

    Hello. Thank you very much for making the code available.

    I was wondering if you still have the pre-trained weights over Pascal VOC 2012 dataset, and if those could be added here as well. As far as I can tell, the ones in the README.md refer to the COCO dataset, right? Moreover, their links seem to be broken.

    Cheers,

    opened by lucasdavid 2
  • A training error

    Really an excellent paper, but I hit a training error: subprocess.CalledProcessError: Command '['/home/anaconda3/envs/ql/bin/python,, '-u', 'main_mlc.py', '--local_rank=3', '--backbone', 'resnet101', '--dataname', 'voc07', '--batch-size', '64', '--print-freq', '100', '--world-size', '1', '--rank', '0', '--dist-url', 'tcp://127.0.0.1:3717', '--gamma_pos', '0', '--gamma_neg', '2', '--dtgfl', '--epochs', '80', '--lr', '1e-4', '--optim', 'AdamW', '--pretrained', '--num_class', '20', '--img_size', '448', '--weight-decay', '1e-2', '--cutout', '--n_holes', '1', '--cut_fact', '0.5', '--hidden_dim', '2048', '--dim_feedforward', '8192', '--enc_layers', '1', '--dec_layers', '2', '--nheads', '4', '--early-stop', '--amp']' returned non-zero exit status 1. How can I fix it? Thanks.

    opened by fsted 2
  • Do you have any code or suggestions for visualization?

    Hi, I am a student currently working on a deep learning project. Do you have any code or suggestions for the visualization in Figure 1, especially the part in the center? How did you implement it? The figure is shown below. Looking forward to your reply. Thanks in advance. image

    opened by wzn123 2
  • Cannot train

    I really appreciate your work and I want to train query2label on a custom dataset. However, when I run the training code, it shows the following output:

    /usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 30 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
      cpuset_checked))
    [11/03 06:48:25.411]: lr:4.000000000000002e-06
    ^C
    

    Then the code stopped without reporting any error. I don't know how to fix it. I hope you can reply soon.

    opened by chauminhnguyen 2
  • training log

    Hi, query2label is an interesting project. I'm working on re-implementing it. However, I don't know if I got it right until I receive the final evaluation result. Therefore, would you mind releasing the training log? It will be very helpful. Thanks.

    opened by WasedaMagina 2
  • AttributeError: 'Namespace' object has no attribute 'arch' in main_mlc.py

    Thanks for your work. I found the following error:

    [05/31 10:42:36.162]: 0 | Set best mAP 70.74814334530387 in ep 0
    [05/31 10:42:36.162]: | best regular mAP 70.74814334530387 in ep 0
    Traceback (most recent call last):
      File "main_mlc.py", line 727, in <module>
        main()
      File "main_mlc.py", line 224, in main
        return main_worker(args, logger)
      File "main_mlc.py", line 407, in main_worker
        'arch': args.arch,
    AttributeError: 'Namespace' object has no attribute 'arch'

    There is no arch argument in main_mlc.py or its argparse inputs, but it can be found in q2l_infer.py.

    So, is there a code error in this commit? I hit it when retraining the model on the COCO dataset.

    (By the way, there is another error: training with Q2L-SwinL-384 raises a position error in the Transformer part, while resnet101 is fine.)

    opened by macqueen09 1
  • About AsymmetricLossOptimized

    Thanks for your great work! I have a question about AsymmetricLossOptimized. What is the meaning of the line _loss = _loss / y.size(1) * 1000?
    Where does the 1000 come from?

    opened by LeeYN-43 1
  • backbone about ViT

    Hi, I found your research "query2labels", which is really excellent work! I wonder whether Q2L can also improve the multi-label classification performance of ViT (Vision Transformer). For example, is there a specific figure for the improvement? And when Q2L is used with ViT, are the spatial features extracted by ViT the class token of the outputs (the default classification setting in the original paper) or all outputs?

    Looking forward to your reply, best wishes!

    opened by xinyu1205 1
  • TResNet pre-trained on ImageNet-22k

    Thank you very much for your work. I would like to build on it in my own work, but I can't find the TResNet backbones pre-trained on the ImageNet-22k dataset with an input resolution of 224. I would appreciate it if you could share them.

    opened by sunfeng2016 1
  • Has anybody tried the HARRISON dataset yet? I have no idea why it just stops at epoch 0. Help!

    When I try the CvT-w24 model on the HARRISON dataset, which contains about 50000 images from Instagram and the 997 kinds of real tags people post with them, it just can't train well. It stopped at epoch 0 and started testing. Does anybody know whether this model can fit the HARRISON dataset, given that there is a difference between categories and real tags (including #happy, #love, and so on, which are not objects like those in COCO)?

    opened by jerry-zuo 0
  • Inconsistent number of samples

    We integrated q2l into our codebase and it only works with swin_small at 224 resolution. All other backbones (densenet161, resnet50, swin_base) fail at 224 resolution with a variation of the following error, non-deterministically at different points, such as epoch 2, 5, or 12.

    Data=0.0004 s | (2515/2515) | 100.00% | Loss=0.0253 [########################################]> pred_scores shape: (90517, 17)
    Traceback (most recent call last):
      File "src/main/train.py", line 283, in <module>
        main(None, args)
      File "src/main/train.py", line 154, in main
        cfg, model, train_dataloader, criterion, optimizer, device, phase="train", scaler=scaler)
      File "src/main/helper/epoch.py", line 218, in epoch
        metrics_dict = calculate_metrics(cfg, gt_dict, pred_dict, phase, dataloader)
      File "src/main/helper/postprocess.py", line 79, in calculate_metrics
        gt_label, pred_score, "score", cfg, phase
      File "src/main/helper/postprocess.py", line 40, in evaluate_classif
        return_dict["{}_{}".format(phase, metric)] = callable_metric(gt, preds)
      File "/opt/venv/lib/python3.7/site-packages/sklearn/metrics/_ranking.py", line 572, in roc_auc_score
        sample_weight=sample_weight,
      File "/opt/venv/lib/python3.7/site-packages/sklearn/metrics/_base.py", line 75, in _average_binary_score
        return binary_metric(y_true, y_score, sample_weight=sample_weight)
      File "/opt/venv/lib/python3.7/site-packages/sklearn/metrics/_ranking.py", line 342, in _binary_roc_auc_score
        fpr, tpr, _ = roc_curve(y_true, y_score, sample_weight=sample_weight)
      File "/opt/venv/lib/python3.7/site-packages/sklearn/metrics/_ranking.py", line 963, in roc_curve
        y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
      File "/opt/venv/lib/python3.7/site-packages/sklearn/metrics/_ranking.py", line 733, in _binary_clf_curve
        check_consistent_length(y_true, y_score, sample_weight)
      File "/opt/venv/lib/python3.7/site-packages/sklearn/utils/validation.py", line 334, in check_consistent_length
        % [int(l) for l in lengths]
    ValueError: Found input variables with inconsistent numbers of samples: [90517, 90185]
    

    The above error occurred with

    backbone: swin_base
    resolution: 448
    dropout: 0
    encoder_layer: 0
    decoder_layer: 1
    hidden_dim: 1024
    dim_feedforward: 1024
    nheads: 4
    loss: asl
    neg: 4
    pos: 2
    num_classes=17
    

    We set hidden_dim and dim_feedforward (more specifically, the shape of the query embeddings fed to the decoder) equal to the last-layer dimension of the backbone. For example, with the 1024 of swin_base above, the query embeddings have shape (17, 1024). Has anyone come across this kind of behaviour?

    opened by saishkomalla 0
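  • About the experimental results of Swin-L[37]

    Hello! I would like to ask about the Swin-L[37] results reported in Q2L: Q1: Which pooling operation does your Swin-L[37] use, average pooling or max pooling? Q2: Which loss function does your Swin-L[37] use, BCE or a simplified version of ASL? Q3: Could you provide the code for computing evaluation metrics such as CP/CR/CF1?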

    WechatIMG4806

    opened by mymuli 1
  • Code for visualization of multi-head attention maps

    Hi Liu, thanks for your great work and for sharing clean code. Could you share your visualization code for the multi-head attention maps? I am new to Transformers and find it difficult to implement this. Hoping for your help.

    Best Regards! Tan

    opened by myt889 1
  • Doubt in evaluation metric.

    My doubt is:

    1. You did not set a threshold for the confidence score, and I cannot find any code that determines one. How do I decide which labels are true when I want a prediction for an image?

    Thank you very much.

    opened by Chanfeechen 1
Owner
Shilong Liu
A spicy chicken. www.lsl.zone
Shilong Liu
Official implementation of AAAI-21 paper "Label Confusion Learning to Enhance Text Classification Models"

Description: This is the official implementation of our AAAI-21 accepted paper Label Confusion Learning to Enhance Text Classification Models. The str

null 101 Nov 25, 2022
Official PyTorch implementation for paper Context Matters: Graph-based Self-supervised Representation Learning for Medical Images

Context Matters: Graph-based Self-supervised Representation Learning for Medical Images Official PyTorch implementation for paper Context Matters: Gra

null 49 Nov 23, 2022
Official implementation of the ICLR 2021 paper

You Only Need Adversarial Supervision for Semantic Image Synthesis Official PyTorch implementation of the ICLR 2021 paper "You Only Need Adversarial S

Bosch Research 272 Dec 28, 2022
Official implementation of the paper Image Generators with Conditionally-Independent Pixel Synthesis https://arxiv.org/abs/2011.13775

CIPS -- Official Pytorch Implementation of the paper Image Generators with Conditionally-Independent Pixel Synthesis Requirements pip install -r requi

Multimodal Lab @ Samsung AI Center Moscow 201 Dec 21, 2022
Official pytorch implementation of paper "Image-to-image Translation via Hierarchical Style Disentanglement".

HiSD: Image-to-image Translation via Hierarchical Style Disentanglement Official pytorch implementation of paper "Image-to-image Translation

null 364 Dec 14, 2022
Official pytorch implementation of paper "Inception Convolution with Efficient Dilation Search" (CVPR 2021 Oral).

IC-Conv This repository is an official implementation of the paper Inception Convolution with Efficient Dilation Search. Getting Started Download Imag

Jie Liu 111 Dec 31, 2022
Official implementation of our paper "LLA: Loss-aware Label Assignment for Dense Pedestrian Detection" in Pytorch.

LLA: Loss-aware Label Assignment for Dense Pedestrian Detection This project provides an implementation for "LLA: Loss-aware Label Assignment for Dens

null 35 Dec 6, 2022
This project is the official implementation of our accepted ICLR 2021 paper BiPointNet: Binary Neural Network for Point Clouds.

BiPointNet: Binary Neural Network for Point Clouds Created by Haotong Qin, Zhongang Cai, Mingyuan Zhang, Yifu Ding, Haiyu Zhao, Shuai Yi, Xianglong Li

Haotong Qin 59 Dec 17, 2022
Official implementation of our CVPR2021 paper "OTA: Optimal Transport Assignment for Object Detection" in Pytorch.

OTA: Optimal Transport Assignment for Object Detection This project provides an implementation for our CVPR2021 paper "OTA: Optimal Transport Assignme

null 217 Jan 3, 2023
This is the official PyTorch implementation of the paper "TransFG: A Transformer Architecture for Fine-grained Recognition" (Ju He, Jie-Neng Chen, Shuai Liu, Adam Kortylewski, Cheng Yang, Yutong Bai, Changhu Wang, Alan Yuille).

TransFG: A Transformer Architecture for Fine-grained Recognition Official PyTorch code for the paper: TransFG: A Transformer Architecture for Fine-gra

Ju He 307 Jan 3, 2023
Official implementation for NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning The predictive learning of spatiotemporal sequences aims to generate future

THUML: Machine Learning Group @ THSS 243 Dec 26, 2022
[PyTorch] Official implementation of CVPR2021 paper "PointDSC: Robust Point Cloud Registration using Deep Spatial Consistency". https://arxiv.org/abs/2103.05465

PointDSC repository PyTorch implementation of PointDSC for CVPR'2021 paper "PointDSC: Robust Point Cloud Registration using Deep Spatial Consistency",

null 153 Dec 14, 2022
This is an official implementation of our CVPR 2021 paper "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression" (https://arxiv.org/abs/2104.02300)

Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression Introduction In this paper, we are interested in the bottom-up paradigm of estima

HRNet 367 Dec 27, 2022
The official pytorch implementation of our paper "Is Space-Time Attention All You Need for Video Understanding?"

TimeSformer This is an official pytorch implementation of Is Space-Time Attention All You Need for Video Understanding?. In this repository, we provid

Facebook Research 1k Dec 31, 2022
PixelPick This is an official implementation of the paper "All you need are a few pixels: semantic segmentation with PixelPick."

PixelPick This is an official implementation of the paper "All you need are a few pixels: semantic segmentation with PixelPick." [Project page] [Paper

Gyungin Shin 59 Sep 25, 2022
Official implementation of GraphMask as presented in our paper Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking.

GraphMask This repository contains an implementation of GraphMask, the interpretability technique for graph neural networks presented in our ICLR 2021

Michael Schlichtkrull 29 Sep 2, 2022
The official implementation of our CVPR 2021 paper - Hybrid Rotation Averaging: A Fast and Robust Rotation Averaging Approach

Graph Optimizer This repo contains the official implementation of our CVPR 2021 paper - Hybrid Rotation Averaging: A Fast and Robust Rotation Averagin

Chenyu 109 Dec 23, 2022
Official Pytorch Implementation of: "ImageNet-21K Pretraining for the Masses"(2021) paper

ImageNet-21K Pretraining for the Masses Paper | Pretrained models Official PyTorch Implementation Tal Ridnik, Emanuel Ben-Baruch, Asaf Noy, Lihi Zelni

null 574 Jan 2, 2023
The official PyTorch implementation of the paper: *Xili Dai, Xiaojun Yuan, Haigang Gong, Yi Ma. "Fully Convolutional Line Parsing." *.

F-Clip — Fully Convolutional Line Parsing This repository contains the official PyTorch implementation of the paper: *Xili Dai, Xiaojun Yuan, Haigang

Xili Dai 115 Dec 28, 2022