ML-Decoder: Scalable and Versatile Classification Head

Paper: https://arxiv.org/abs/2111.12933

Official PyTorch Implementation

Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben-Baruch, Asaf Noy
DAMO Academy, Alibaba Group

Abstract

In this paper, we introduce ML-Decoder, a new attention-based classification head. ML-Decoder predicts the existence of class labels via queries, and enables better utilization of spatial data compared to global average pooling. By redesigning the decoder architecture and using a novel group-decoding scheme, ML-Decoder is highly efficient and scales well to thousands of classes. Compared to using a larger backbone, ML-Decoder consistently provides a better speed-accuracy trade-off. ML-Decoder is also versatile - it can be used as a drop-in replacement for various classification heads, and it generalizes to unseen classes when operated with word queries. Novel query augmentations further improve its generalization ability. Using ML-Decoder, we achieve state-of-the-art results on several classification tasks: on MS-COCO multi-label, we reach 91.4% mAP; on NUS-WIDE zero-shot, we reach 31.1% ZSL mAP; and on ImageNet single-label, with a vanilla ResNet50 backbone we reach a new top score of 80.7%, without extra data or distillation.
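
To make the query-based group-decoding idea above concrete, here is a minimal, simplified PyTorch sketch of such a head. It is an illustration only, not the official ML-Decoder module (the official implementation is linked in the next section); the embedding size, number of groups, and attention configuration are arbitrary choices for the example.

import torch
import torch.nn as nn

class QueryGroupDecoderSketch(nn.Module):
    # Learned group queries cross-attend to the backbone's spatial embeddings;
    # each attended query is then projected to several logits ("group decoding"),
    # so a few hundred queries can cover thousands of classes.
    def __init__(self, num_classes, embed_dim=2048, num_groups=100, num_heads=8):
        super().__init__()
        self.num_classes = num_classes
        self.logits_per_group = -(-num_classes // num_groups)  # ceil division
        self.queries = nn.Parameter(torch.randn(num_groups, embed_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        # per-group projection weights: (num_groups, embed_dim, logits_per_group)
        self.group_fc = nn.Parameter(torch.randn(num_groups, embed_dim, self.logits_per_group) * 0.02)

    def forward(self, spatial_embeddings):
        # spatial_embeddings: (B, C, H, W) feature map from the backbone
        b = spatial_embeddings.shape[0]
        tokens = spatial_embeddings.flatten(2).transpose(1, 2)          # (B, H*W, C)
        queries = self.queries.unsqueeze(0).expand(b, -1, -1)           # (B, G, C)
        attended, _ = self.cross_attn(queries, tokens, tokens)          # (B, G, C)
        logits = torch.einsum('bgc,gcl->bgl', attended, self.group_fc)  # (B, G, L)
        return logits.flatten(1)[:, :self.num_classes]                  # (B, num_classes)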

ML-Decoder Implementation

The ML-Decoder implementation is available in src_files/ml_decoder/ml_decoder.py. It can easily be integrated into any backbone using this example code:

ml_decoder_head = MLDecoder(num_classes)  # initialization

spatial_embeddings = self.backbone(input_image)  # the backbone generates spatial embeddings

logits = ml_decoder_head(spatial_embeddings)  # transform spatial embeddings to logits
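
For a fuller, end-to-end illustration, the sketch below wires MLDecoder onto a timm backbone for multi-label inference. Treat it as an assumed usage pattern rather than official code: the backbone name, channel count, image size, and class count are placeholders for the example, while the import path and the initial_num_features argument follow the repository's ml_decoder.py.

import timm
import torch
from src_files.ml_decoder.ml_decoder import MLDecoder

# num_classes=0 and global_pool='' make timm return the unpooled spatial feature map
backbone = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
ml_decoder_head = MLDecoder(num_classes=80, initial_num_features=2048)  # ResNet50 outputs 2048 channels

backbone.eval()
ml_decoder_head.eval()
with torch.no_grad():
    input_image = torch.randn(1, 3, 448, 448)         # dummy batch of one image
    spatial_embeddings = backbone(input_image)         # (1, 2048, 14, 14)
    logits = ml_decoder_head(spatial_embeddings)       # (1, 80)
    probs = logits.sigmoid()                           # per-class probabilities for multi-label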

Training Code

We will share full reproduction code for the article's results.

Multi-label Training Code


Reproduction code for MS-COCO multi-label:

python train.py  \
--data=/home/datasets/coco2014/ \
--model_name=tresnet_l \
--image_size=448

Single-label Training Code

Our single-label training code uses the excellent timm repo. The reproduction code currently lives in a fork; we will work toward a full merge into the main repo.

git clone https://github.com/mrT23/pytorch-image-models.git

This is the command for A2-configuration training with ML-Decoder enabled (--use-ml-decoder-head=1):

python -u -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
./train.py \
/data/imagenet/ \
--amp \
-b=256 \
--epochs=300 \
--drop-path=0.05 \
--opt=lamb \
--weight-decay=0.02 \
--sched='cosine' \
--lr=4e-3 \
--warmup-epochs=5 \
--model=resnet50 \
--aa=rand-m7-mstd0.5-inc1 \
--reprob=0.0 \
--remode='pixel' \
--mixup=0.1 \
--cutmix=1.0 \
--aug-repeats 3 \
--bce-target-thresh 0.2 \
--smoothing=0 \
--bce-loss \
--train-interpolation=bicubic \
--use-ml-decoder-head=1

ZSL Training Code

Reproduction code for ZSL is WIP.

Citation

@misc{ridnik2021mldecoder,
      title={ML-Decoder: Scalable and Versatile Classification Head}, 
      author={Tal Ridnik and Gilad Sharir and Avi Ben-Cohen and Emanuel Ben-Baruch and Asaf Noy},
      year={2021},
      eprint={2111.12933},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Comments
  • Question about reproducing on the Open Images dataset

    Thank you for your great work! I appreciate your groundbreaking research on multi-label classification on OpenImages-v6.

    I had difficulty reproducing your 86.6 mAP on Open Images. I used get_datasets_from_csv to build the train/val datasets and passed in a JSON file with the full 9,605 classes as train/test classes. The training parameters were set according to the paper. The resulting mAP is 33, taken from the validate_multi function in train.py.

    I also validated your pretrained model on Open Images (https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ML_Decoder/tresnet_m_open_images_200_groups_86_8.pth) with validate.py; the mAP is 28, while I expected it to be 86.8.

    I load the state_dict with model.load_state_dict(state['model'], strict=True), so I think the model is loaded correctly. Do you know why the mAP is so low? Thank you!

    opened by yankuai 7
  • verified the COCO

    opened by sorrowyn 7
  • Training performance of ZSL code, final mAP with tresnet-m is 15.6

    Hi, I have tried to train the tresnet-m model on the NUS-WIDE dataset, but I got an mAP of 15.6 after 40 epochs. May I ask for your suggestions on what I am doing wrong here? I didn't change any code. An answer would help a lot. Regards

    opened by aliman80 7
  • Pretraining for ML_Decoder's backbone and the OpenImages-v6 dataset

    This is very interesting work! May I ask you two questions? Q1: You used Open Images pretraining for ML_Decoder's backbone to obtain tresnet_l_pretrain_ml_decoder: parser.add_argument('--model-path', default='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ML_Decoder/tresnet_l_pretrain_ml_decoder.pth', type=str)

    Q2: Will you release the data loading file for Open Images?

    opened by sorrowyn 7
  • About a multi-GPU training error

    Hi, while training this model on multiple GPUs, an error occurred. My code is:

    from torch.utils.data.distributed import DistributedSampler

    local_rank = torch.distributed.get_rank()
    torch.cuda.set_device(local_rank)

    model = create_model(args, load_head=True)
    model = torch.nn.DataParallel(model, device_ids=device_ids)
    model.to(device)

    for i, (inputData, target) in pbar:
        inputData = inputData.to(device)
        target = target.to(device)  # (batch, 3, num_classes)
        with autocast():  # mixed precision
            output = model(inputData).float()  # sigmoid will be done in the loss!
        loss = criterion(output, target)

    But this error happened:

    RuntimeError: Assertion `THCTensor_(checkGPU)(state, 3, input, output, weight)' failed. Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one. at /opt/conda/conda-bld/pytorch_1603728993639/work/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu:16

    I then changed

    self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()

    to

    self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda(device=int(os.environ.get('RANK', 0))).half()

    as you mentioned in TResNet, but a similar error happened:

    RuntimeError: attribute lookup is not defined on python value of type '_Environ':
    File "/home/kpl/code/multilabel/ML_Decoder/src_files/models/tresnet/layers/anti_aliasing.py", line 35
        filt = filt / torch.sum(filt)
        # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()
        self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda(device=int(os.environ.get('RANK', 0))).half()
                                                                                 ~~~~~~~~~~~~~~ <--- HERE

    You said you added an option --remove_aa_jit and that running with it should be OK, but I can't find --remove_aa_jit. Can you give some suggestions? Thanks very much.

    opened by myh12138 6
  • Release of pre-trained weights?

    Thanks so much for this amazing new model.

    I was wondering if you would be able to release some of the pre-trained weights behind these core benchmark datasets?

    I pointed some colleagues to your paper and repo to show them the SOTA in this area, but I couldn't easily point them to a Colab notebook or Hugging Face demo, or even build our own hosted demo to show them.

    Love your work, please keep up the great work!

    opened by eware-godaddy 6
  • Missing key(s) in state_dict: "head.fc.weight", "head.fc.bias"

    Hello,

    I am fascinated by your great work and I'm trying to experiment with your code a little bit.

    I want to create an instance of the TResNet class and then load the pre-trained model for the Stanford Cars dataset into the model using PyTorch. However, it seems like the state dictionary of the TResNet class is not compatible with that of the pre-trained model that you have shared in the model zoo.

    Here is the code that I use:

    model = TResNet([3, 4, 23, 3], num_classes=80, in_chans=3, first_two_layers=Bottleneck).cuda()
    
    state = torch.load(PATH_TO_THE_PRETRAINED_MODEL, map_location='cpu')
    filtered_dict = {k: v for k, v in state['model'].items() if
                                 (k in model.state_dict() and 'head.fc' not in k)}
    # here is the issue!
    model = model.load_state_dict(filtered_dict, strict=True)
    model.eval()
    

    This is the error that I get:

    RuntimeError: Error(s) in loading state_dict for TResNet: Missing key(s) in state_dict: "head.fc.weight", "head.fc.bias".

    Even if I ignore this by exception handling, I get poor results on the test images; all of the classes have a score around 55% to 62%.

    Can you please help me solve this issue? Thank you in advance.

    opened by zahragolpa 4
  • Cannot reproduce the tresnet-l mAP on COCO

    Hi, first, thanks for sharing your great work! I reran all the models in the model zoo and got the same mAP for each, except tresnet-l on COCO (85.57 in my workspace). Maybe the wrong model was uploaded to git? I notice that the name of tresnet-l is a little different from the other models.

    opened by FunnyClown 3
  • Unable to install requirements

    I was trying to run the inference on Google Colab, but ran into the following error while performing pip install -r requirements.txt:

    Collecting git+https://github.com/mapillary/inplace_abn.git@v1.0.12 (from -r requirements.txt (line 3))
      Cloning https://github.com/mapillary/inplace_abn.git (to revision v1.0.12) to /tmp/pip-req-build-nr36ha5b
      Running command git clone -q https://github.com/mapillary/inplace_abn.git /tmp/pip-req-build-nr36ha5b
      Running command git checkout -q 24fc791e6d4796a1639e7a5dce6fa67377e51a3e
    Requirement already satisfied: torch>=1.7 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 1)) (1.10.0+cu111)
    Requirement already satisfied: torchvision>=0.5.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 2)) (0.11.1+cu111)
    Requirement already satisfied: pycocotools in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 4)) (2.0.4)
    Requirement already satisfied: randaugment in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 5)) (1.0.2)
    Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.7->-r requirements.txt (line 1)) (3.10.0.2)
    Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (7.1.2)
    Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.19.5)
    Requirement already satisfied: matplotlib>=2.1.0 in /usr/local/lib/python3.7/dist-packages (from pycocotools->-r requirements.txt (line 4)) (3.2.2)
    Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 4)) (0.11.0)
    Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 4)) (2.8.2)
    Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 4)) (3.0.7)
    Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 4)) (1.3.2)
    Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 4)) (1.15.0)
    Building wheels for collected packages: inplace-abn
      Building wheel for inplace-abn (setup.py) ... error
      ERROR: Failed building wheel for inplace-abn
      Running setup.py clean for inplace-abn
    Failed to build inplace-abn
    Installing collected packages: inplace-abn
        Running setup.py install for inplace-abn ... error
    ERROR: Command errored out with exit status 1: /usr/bin/python3 -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-nr36ha5b/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-nr36ha5b/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-1diqp1b7/install-record.txt --single-version-externally-managed --compile --install-headers /usr/local/include/python3.7/inplace-abn Check the logs for full command output.
    

    It seems to be an issue with installing inplace_abn.

    opened by jmayank23 3
  • When the ML-Decoder is applied to my custom data, the loss does not decrease during training

    Thank you for your impressive project.

    When I use the ML-Decoder, training does not work properly (the train loss doesn't decrease, and accuracy does not improve).

    I used the ResNet50 model pretrained on ImageNet-21K provided by your group.

    The backbone model was created with the timm library.

     model = timm.create_model('resnet50', pretrained=False, num_classes=0)
     model = load_model_weights(model, os.path.join(os.path.split(os.path.realpath(__file__))[0],'resnet50_miil_21k.pth'))
    

    This is my Classifier.

    class Classifier(nn.Module):
        def __init__(self, cfg):
            super(Classifier, self).__init__()
            self.cfg = cfg
            self.backbone = BACKBONES[cfg.backbone](cfg)
            self.num_in_channel = self.get_num_channel() #number of output channels of backbone model
            self.ml_decoder_head = MLDecoder(self.cfg.num_class, initial_num_features=self.num_in_channel)
    
        def forward(self, x):
            # (N, C, H, W)
            x = self.backbone(x)
            x = self.ml_decoder_head(x)
            return x
    

    For logit x, nn.CrossEntropyLoss was applied.

    When I used the head by applying GAP and adding a linear layer, this issue did not occur.

    The number of classes is 10, and it is a single label.

    opened by jeongHwarr 3
  • Difference between GZSL and ZSL mAPs

    Hi, thank you for your support. I am getting the mAP values after running the ZSL training file, but how do I obtain the GZSL values? Also, could you explain how to get the F1 scores for the different K values mentioned in the paper? Can you share the computation for these, please?

    opened by aliman80 2
  • The position of the dropout layer

    https://github.com/Alibaba-MIIL/ML_Decoder/blob/8a9e984f671c9c30c98d2c45dfcaf4383381c254/src_files/ml_decoder/ml_decoder.py#L60

    Thanks for sharing your work!

    Why is the dropout layer positioned here?

    What's the effect of it?

    Is it noise for generalization?

    opened by developer0hye 0
  • Weird behaviour on the Stanford Cars dataset

    Hi, thanks for your amazing work!

    Here I encountered something really weird. I downloaded tresnet_l_stanford_card_96.41.pth and tried to validate the result on the Stanford Cars dataset. It reached 99.69 on the validation split and 96.80 on the train split, respectively. Most likely the training and validation splits were treated inversely.

    Can you please check if the order is correct?

    opened by LouieShao 0
  • About the pretrained backbone.

    Hi, thanks for your great work. The performance of the ResNet101-based ML-Decoder on MS-COCO is impressive. The training details in your paper show that the ResNet101 is pretrained on Open Images. Did you conduct an experiment with ResNet101 pretrained on ImageNet?

    opened by jasonseu 0
  • infer.py doesn't work, variable referenced before assignment

    Hi, thanks for the work! I think there is an obvious bug at line 56 in src_files/models/utils/factory.py: model.load_state_dict(state[key], strict=True). The variable key is used in the else branch without being defined, so when the load_head argument is True (as it is in infer.py), the code doesn't work.

    opened by PT0X0E 0
  • The result under imagenet1k pretrained model

    Hello, thank you very much for sharing this work. I find that all the results in the paper use an ImageNet-21K pretrained model. Considering that a lot of prior work used ImageNet-1K, have you also tried an ImageNet-1K pretrained model?

    opened by THUeeY 0
  • The experimental results of TResnet_M cannot be reproduced

    Hello, I used the pretrained TResnet_M model with an image input size of 224 and the learning rate of 2e-4 from the paper; the rest of the parameters are consistent with the training file. I only got a result of 78.86 instead of the 84.2 in the model zoo. Do I need to make any changes?

    opened by limengyang1992 0