[NeurIPS 2019] Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss

Overview

Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss

Kaidi Cao, Colin Wei, Adrien Gaidon, Nikos Arechiga, Tengyu Ma


This is the official PyTorch implementation of LDAM-DRW from the paper Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss.

Dependency

The code is built with the following libraries:

Dataset

  • Imbalanced CIFAR. The original data will be downloaded and converted by imbalance_cifar.py; a sketch of the class-count rule it uses is shown after this list.
  • The paper also reports results on Tiny ImageNet and iNaturalist 2018. We will update the code for those datasets later.
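
For reference, the long-tailed ('exp') profile follows the usual exponential decay over classes; the sketch below paraphrases that rule (the function name and rounding are illustrative, not the repo's exact code):

    import numpy as np

    def img_num_per_cls(img_max, cls_num, imb_factor):
        # Exponential (long-tailed) profile: class i keeps
        # img_max * imb_factor^(i / (cls_num - 1)) images, so the head class keeps
        # img_max images and the tail class keeps img_max * imb_factor images.
        return [int(img_max * imb_factor ** (i / (cls_num - 1.0))) for i in range(cls_num)]

    # CIFAR-10 with an imbalance ratio of 100 (imb_factor = 0.01):
    # class sizes decay from 5000 down to 50
    print(img_num_per_cls(5000, 10, 0.01))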

Training

We provide several training examples with this repo:

  • To train the ERM baseline on long-tailed imbalance with a ratio of 100 (an imb_factor of 0.01 means the rarest class has 1/100 the samples of the largest):
python cifar_train.py --gpu 0 --imb_type exp --imb_factor 0.01 --loss_type CE --train_rule None
  • To train with the LDAM loss and the DRW training schedule on long-tailed imbalance with a ratio of 100:
python cifar_train.py --gpu 0 --imb_type exp --imb_factor 0.01 --loss_type LDAM --train_rule DRW

Reference

If you find our paper and repo useful, please cite as

@inproceedings{cao2019learning,
  title={Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss},
  author={Cao, Kaidi and Wei, Colin and Gaidon, Adrien and Arechiga, Nikos and Ma, Tengyu},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
Comments
  • Questions about the hyper-parameters for LDAM loss

    It was a very interesting paper to read :)

    I have some questions regarding the hyper-parameters for LDAM loss.

    1. What is the value of C, the hyper-parameter to be tuned (according to the paper)? Is it the factor (max_m / np.max(m_list)) introduced at the line below (see also the sketch after these questions)? https://github.com/kaidic/LDAM-DRW/blob/master/losses.py#L28

    2. Is s=30 in LDAM loss also a hyper-parameter to be tuned? I could not find any explanation in the paper. Did I miss something?

    3. What was the tendency of these hyper-parameters during training? How are these hyper-parameter selections related to the imbalance level (or to different datasets)? Do the found parameters also work for the other datasets in the paper (Tiny ImageNet, iNaturalist)?
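
    For reference, a paraphrased sketch of the margin computation at that line (the function name and the default max_m = 0.5 are assumptions on my part, not the repo's exact code):

        import numpy as np

        def ldam_margins(cls_num_list, max_m=0.5):
            # per-class margins proportional to n_j^(-1/4), as advocated in the paper
            m_list = 1.0 / np.sqrt(np.sqrt(np.asarray(cls_num_list, dtype=np.float64)))
            # rescale so the largest margin equals max_m; this rescaling factor
            # (max_m / np.max(m_list)) is what effectively fixes the constant C
            return m_list * (max_m / np.max(m_list))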

    Thanks.

    opened by hyungwonchoi 3
  • Wrong implementation of focal loss

    Hi,

    I believe that you have a wrong implementation of focal loss. I hope I have not misunderstood the code. Although the wrong implementation of focal loss will not affect the method you proposed, I hope the authors will spend some time correcting it.

    You should compute -(1-p)^r * log(p) for every sample in the batch. However, after you use F.cross_entropy at line 21 of losses.py, the output is already a single "value". You then use this value as p to compute the focal loss, which is completely wrong.

    An obvious indication of the wrong implementation is that you can actually remove the .mean() at line 11 in losses.py without causing any errors. It shows that you are indeed dealing with a single value, not a vector.

    This might explain why your implementation is so different from https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py or https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py

    You can also check the previous work you've cited https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py where the key point is that they make sure "reduction=none" when using F.binary_cross_entropy_with_logits.
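
    For reference, a per-sample focal loss along the lines described above might look like this (a minimal sketch, not the repo's code):

        import torch
        import torch.nn.functional as F

        def focal_loss(logits, target, gamma=1.0):
            # per-sample cross entropy; reduction='none' keeps one loss value per sample
            ce = F.cross_entropy(logits, target, reduction='none')
            p = torch.exp(-ce)                       # probability the model assigns to the true class
            return ((1.0 - p) ** gamma * ce).mean()  # batch mean of -(1 - p)^gamma * log(p)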

    opened by AlanChou 2
  • Focal loss would lead to nan?

    Hi @kaidic

    Thanks for your fantastic work! When I tried to reproduce the focal loss result, I found that my focal loss led to NaN losses during training when gamma=0.5, while the focal loss in this repo trains fine.

    I checked the two differently designed focal losses carefully and found that their forward passes are the same, but the model parameters became different after the backward pass. I am quite confused; could you please give me some advice?

    Thanks for your contribution again!

    opened by mitming 1
  • DRW actually use Class-Balance Weight, instead of Inverse of Frequency

    Hello, thanks for the paper and the code. I just want to confirm the following code snippet:

        elif args.train_rule == 'DRW':
            train_sampler = None
            idx = epoch // 160
            betas = [0, 0.9999]
            # epoch < 160: beta = 0, so effective_num = 1 and all class weights are equal (no reweighting);
            # epoch >= 160: beta = 0.9999, so classes are reweighted by their effective number of samples
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)

    This is the implementation of the Class-Balanced weighting (slightly different from the inverse-frequency weighting reported in the paper). Is there any reason to select beta = 0.9999?
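
    For context, these weights follow the "effective number of samples" formula from the Class-Balanced loss paper (Cui et al., 2019); a minimal standalone sketch of that formula (not this repo's code):

        import numpy as np

        def class_balanced_weights(cls_num_list, beta=0.9999):
            # effective number E_j = (1 - beta^n_j) / (1 - beta); weight_j proportional to 1 / E_j
            cls_num = np.asarray(cls_num_list, dtype=np.float64)
            weights = (1.0 - beta) / (1.0 - np.power(beta, cls_num))
            # normalize so the weights sum to the number of classes, as in the snippet above
            return weights / weights.sum() * len(cls_num)

        # As beta -> 1, E_j -> n_j and the weights approach inverse class frequency;
        # beta = 0 gives uniform weights (no reweighting).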

    opened by chuong98 1
  • CE+DRW and CE+CB

    Thanks for your paper and your code; they are great work and help me a lot. Your article says that DRW reweights based on the number of samples, but your code uses the Class-Balanced (CB) weights. I want to know whether the DRW reported in your article is CE + CB or CE + 1/N?

    opened by xiaohua-chen 0
  • the learning rate in log_train times 0.1

    opened by yibuxulong 0
  • Can not achieve similar results For Tiny ImageNet

    Thanks for your paper and your code, they are great work and help me a lot. I did experiments on the Tiny ImageNet dataset following the settings described in your paper; however, I can't achieve similar results. For long-tailed 1:100 Tiny ImageNet, the top-1 validation errors I got are: ERM SGD: 80.05, LDAM SGD: 72.8. This has a big gap with the results shown in your paper, so I wonder if there is any setting or trick I have missed? In the paper you mentioned: <We perform 1 crop test with the validation images.> I wonder how this is done specifically. For ResNet-18, I use:

        backbone = models.resnet18(pretrained=True)
        backbone.avgpool = nn.AdaptiveAvgPool2d(1)
        num_ftrs = backbone.fc.in_features
        if USE_NORM:
            backbone.fc = NormedLinear(num_ftrs, 200)
        else:
            backbone.fc = nn.Linear(num_ftrs, 200)

    Is this correct? Looking forward to your reply, thank you very much!

    opened by MapleLeafKiller 0
  • About the LDAM Loss

    Thanks a lot for your code! I have read your paper and code; it's really a good idea, but I have a question about the LDAM loss. It is about the last line, where the basic cross_entropy function in PyTorch is called.

        def forward(self, x, target):
            index = torch.zeros_like(x, dtype=torch.uint8)
            index.scatter_(1, target.data.view(-1, 1), 1)

            index_float = index.type(torch.cuda.FloatTensor)
            # self.m_list[None, :] adds one dimension to the original m_list
            batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1))
            # equivalent to a transpose
            batch_m = batch_m.view((-1, 1))
            x_m = x - batch_m
            # only the target label position uses the margin-shifted logit x_m
            output = torch.where(index, x_m, x)
            return F.cross_entropy(self.s * output, target, weight=self.weight)

    Why is the output multiplied by s (here, 30)? Is it just to make the loss greater? We don't do this for the focal loss, however.

    opened by sakumashirayuki 4
  • more details about your paper

    Thanks a lot for your code! I have read your paper and code; it's really a good idea, but I have a question about Formula 8.

    [screenshot of Formula 8 from the paper]

    Why does gamma_1 equal C / n_1^(1/4) here?
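
    For context, a rough LaTeX sketch of the two-class trade-off argument from the paper that leads to this form (paraphrased, not a verbatim quotation of the derivation):

        % Each class contributes a generalization term scaling roughly as 1/(gamma_j * sqrt(n_j)).
        % Minimizing the sum under a fixed margin budget gamma_1 + gamma_2 = beta gives
        \min_{\gamma_1 + \gamma_2 = \beta}\ \frac{1}{\gamma_1\sqrt{n_1}} + \frac{1}{\gamma_2\sqrt{n_2}}
        \;\Longrightarrow\; \gamma_1^2\sqrt{n_1} = \gamma_2^2\sqrt{n_2}
        \;\Longrightarrow\; \gamma_j \propto n_j^{-1/4},\ \text{i.e.}\ \gamma_j = \frac{C}{n_j^{1/4}}.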

    Anyway, thanks a lot!

    opened by zj-jayzhang 0
  • AttributeError: 'IMBALANCECIFAR10' object has no attribute 'data'

    Hi, I meet an "AttributeError" when running "cifar_train.py". Could you please tell me how to fix it?

    Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/imbalance_cifar10/cifar-10-python.tar.gz
    212664376it [00:19, 38731722.29it/s]

        Traceback (most recent call last):
          File "/xinfu/code/long_tail/BBN/main/train.py", line 69, in <module>
            train_set = eval(cfg.DATASET.DATASET)("train", cfg)
          File "/xinfu/code/long_tail/BBN/lib/dataset/imbalance_cifar.py", line 25, in __init__
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
          File "/xinfu/code/long_tail/BBN/lib/dataset/imbalance_cifar.py", line 44, in get_img_num_per_cls
            img_max = len(self.data) / cls_num
        AttributeError: 'IMBALANCECIFAR10' object has no attribute 'data'

    opened by xinfu607 1