imbalanced-DL: Deep Imbalanced Learning in Python
Overview
imbalanced-DL (imported as `imbalanceddl`) is a Python package designed to make deep imbalanced learning easier for researchers and real-world users. In our experience, tackling deep imbalanced learning calls for a strategy: a single model or approach is often not enough. This package therefore provides several strategies for deep imbalanced learning. It not only implements several popular deep imbalanced learning strategies, but also provides benchmark results on several image classification tasks. Furthermore, it provides an interface for implementing additional datasets and strategies.
Strategy
This package provides baseline strategies as well as several state-of-the-art strategies:
- Empirical Risk Minimization (baseline strategy)
- Reweighting with Class Balance (CB) Loss
- Deferred Re-Weighting (DRW)
- Label Distribution Aware Margin (LDAM) Loss with DRW
- Mixup with DRW
- Remix with DRW
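The reweighting, margin, and mixing strategies listed above all depend on the per-class sample counts. As a rough guide, the following pure-Python sketch shows common formulations of class-balanced (CB) weights, a DRW-style switch from uniform to CB weights, LDAM margins proportional to n_j^(-1/4), and the Remix rule for assigning a mixed label. The function names and default values (drw_epoch=160, beta=0.9999, max_m=0.5, kappa=3.0, tau=0.5) are illustrative assumptions, not this package's API or configuration.

```python
from typing import List


def class_balanced_weights(cls_num_list: List[int], beta: float) -> List[float]:
    """CB weights: inverse effective number of samples, normalized to sum to C.

    Assumes every class has at least one sample.
    """
    # Effective number of class j: (1 - beta**n_j) / (1 - beta); the weight is its inverse.
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in cls_num_list]
    scale = len(cls_num_list) / sum(raw)
    return [w * scale for w in raw]


def drw_weights(cls_num_list: List[int], epoch: int,
                drw_epoch: int = 160, beta: float = 0.9999) -> List[float]:
    """Deferred re-weighting: uniform weights before drw_epoch, CB weights after."""
    # beta = 0 makes every CB weight equal to 1, i.e. plain unweighted training.
    return class_balanced_weights(cls_num_list, 0.0 if epoch < drw_epoch else beta)


def ldam_margins(cls_num_list: List[int], max_m: float = 0.5) -> List[float]:
    """LDAM margins m_j proportional to n_j ** (-1/4), rescaled so the largest is max_m."""
    raw = [n ** -0.25 for n in cls_num_list]
    scale = max_m / max(raw)
    return [m * scale for m in raw]


def remix_label_weight(lam: float, n_i: int, n_j: int,
                       kappa: float = 3.0, tau: float = 0.5) -> float:
    """Weight on label y_i when images are mixed as lam * x_i + (1 - lam) * x_j."""
    if n_i / n_j >= kappa and lam < tau:
        return 0.0  # x_i is from a much larger class: give the label entirely to y_j
    if n_i / n_j <= 1.0 / kappa and 1.0 - lam < tau:
        return 1.0  # x_j is from a much larger class: give the label entirely to y_i
    return lam      # otherwise keep the ordinary mixup label weight


counts = [5000, 500, 50]
print(drw_weights(counts, epoch=10))      # [1.0, 1.0, 1.0] before the DRW switch
print(drw_weights(counts, epoch=170))     # larger weights for rarer classes
print(ldam_margins(counts))               # largest margin for the rarest class
print(remix_label_weight(0.3, 5000, 50))  # 0.0
```

Using beta = 0 for the pre-switch phase keeps DRW a one-line wrapper around the CB formula instead of a separate code path.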
Environments
- This package is tested on Linux OS.
- We recommend using a separate virtual environment to avoid package dependency issues.
- For Pyenv & Virtualenv users, you can follow the steps below to create a new virtual environment, or you can skip this step.
Pyenv & Virtualenv (Optional)
- For dependency isolation, it is best to create a dedicated virtual environment.
- The following steps demonstrate how to create and manage a virtual environment.
- Install pyenv & virtualenv first, then create a virtual environment with:
```shell
pyenv virtualenv [version] [virtualenv_name]
```
- For example, if you'd like to use Python 3.6.8, you can do:
```shell
pyenv virtualenv 3.6.8 TestEnv
```
- Next, to activate the virtual environment for a project directory, you can do:
```shell
mkdir [dir_name]
cd [dir_name]
pyenv local [virtualenv_name]
```
- You now have a new, clean Python virtual environment for installing the package.
Installation
Basic Requirement
- Python >= 3.6
```shell
git clone https://github.com/ntucllab/imbalanced-DL.git
cd imbalanced-DL
python -m pip install -r requirements.txt
python setup.py install
```
Usage
We highlight three key features of imbalanced-DL as follows:
(0) Imbalanced Dataset:
- We support 5 benchmark image datasets for deep imbalanced learning.
- To create an ImbalancedDataset object, you will need to provide a `config_file` as well as the name of the dataset you would like to use.
- Specifically, inside the `config_file`, you will need to specify three key parameters for creating an imbalanced dataset:
  - `imb_type`: choose either `exp` (long-tailed imbalance) or `step` imbalance.
  - `imb_ratio`: the degree of imbalance of your data; researchers typically choose `0.1` or `0.01`.
  - `dataset_name`: one of the 5 benchmark image datasets we provide, or your own dataset.
- For an example of the `config_file`, see example/config.
- To construct your own dataset, inherit from `BaseDataset`; you can follow `torchvision.datasets.ImageFolder` to organize your dataset in PyTorch format.
```python
from imbalanceddl.dataset.imbalance_dataset import ImbalancedDataset

# config holds the parsed settings (see example/config);
# config.dataset specifies the dataset name
imbalance_dataset = ImbalancedDataset(config, dataset_name=config.dataset)
```
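The two `imb_type` options determine how many training images each class keeps. The sketch below mirrors the widely used long-tailed CIFAR construction: with `exp`, class i keeps max_count * imb_ratio ** (i / (C - 1)) images; with `step`, the first half of the classes keep max_count images and the rest keep max_count * imb_ratio. The helper name img_num_per_cls is hypothetical, and treating imb_ratio as the smallest-to-largest class ratio is an assumption; check the package's dataset classes for the exact rule.

```python
from typing import List


def img_num_per_cls(max_count: int, cls_num: int,
                    imb_type: str, imb_ratio: float) -> List[int]:
    """Return the number of training images kept for each class (hypothetical helper)."""
    if imb_type == "exp":
        # Long-tailed: counts decay exponentially from max_count to max_count * imb_ratio.
        return [int(max_count * imb_ratio ** (i / (cls_num - 1.0)))
                for i in range(cls_num)]
    if imb_type == "step":
        # Step: the first half are majority classes, the second half minority classes.
        half = cls_num // 2
        return [max_count] * half + [int(max_count * imb_ratio)] * (cls_num - half)
    raise ValueError(f"unknown imb_type: {imb_type!r}")


print(img_num_per_cls(5000, 10, "exp", 0.01))
print(img_num_per_cls(5000, 10, "step", 0.1))
```

For CIFAR-10, which has 5,000 training images per class, the `exp` call keeps 5,000 images in the first class and 50 in the last, a largest-to-smallest ratio of 100.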
(1) Strategy Trainer:
- We support 6 different strategies for deep imbalanced learning, and you can either train from scratch or evaluate with the best model after training. Evaluating with the best model gives more in-depth metrics, such as per-class accuracy, for further analysis of the selected strategy's performance. We provide one trained model in example/checkpoint_cifar10.
- Each strategy trainer is associated with a `config_file`, an `ImbalancedDataset` object, a `model`, and a `strategy_name`.
- Specifically, the `config_file` provides the training parameters; the default settings for reproducing the benchmark results can be found in example/config. You can also set these training parameters to suit your own needs.
- For `model`, we currently provide `resnet32` and `resnet18` for reproducing the benchmark results.
- We provide a `build_trainer()` function that returns the specified trainer, as follows:
```python
from imbalanceddl.strategy.build_trainer import build_trainer

# config, imbalance_dataset, and model are created beforehand;
# config.strategy specifies the strategy
trainer = build_trainer(config,
                        imbalance_dataset,
                        model=model,
                        strategy=config.strategy)
# train from scratch
trainer.do_train_val()
# evaluate with the best model
trainer.eval_best_model()
```
- Alternatively, you can directly select the specific strategy trainer you would like to use:
```python
from imbalanceddl.strategy import LDAMDRWTrainer

# pick the trainer
trainer = LDAMDRWTrainer(config,
                         imbalance_dataset,
                         model=model,
                         strategy=config.strategy)
# train from scratch
trainer.do_train_val()
# evaluate with the best model
trainer.eval_best_model()
```
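Evaluating with the best model reports metrics such as per-class accuracy. As a rough illustration of what that metric captures (not the package's own implementation or output format), this sketch computes overall top-1 accuracy and per-class accuracy from plain lists of predicted and true labels. The function names are hypothetical. On imbalanced data, a high overall accuracy can hide a minority class that is never predicted correctly.

```python
from collections import Counter
from typing import List, Sequence


def top1_accuracy(preds: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of samples whose predicted label equals the true label."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


def per_class_accuracy(preds: Sequence[int], targets: Sequence[int],
                       num_classes: int) -> List[float]:
    """Accuracy computed separately over the samples of each true class."""
    total = Counter(targets)
    correct = Counter(t for p, t in zip(preds, targets) if p == t)
    # Classes with no samples get 0.0 instead of dividing by zero.
    return [correct[c] / total[c] if total[c] else 0.0 for c in range(num_classes)]


targets = [0] * 8 + [1] * 2
preds = [0] * 10  # a model that always predicts the majority class
print(top1_accuracy(preds, targets))          # 0.8
print(per_class_accuracy(preds, targets, 2))  # [1.0, 0.0]
```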
- To construct your own strategy trainer, inherit from the `Trainer` class and implement the `get_criterion()` and `train_one_epoch()` methods. You can then either add your strategy to the `build_trainer()` function or use it directly, as demonstrated above.
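The trainer interface described above follows a template-method pattern: the base class drives the training loop, and each strategy supplies only its loss (`get_criterion()`) and its per-epoch update (`train_one_epoch()`). The self-contained sketch below illustrates that pattern with a stand-in base class and a small name-to-class registry playing the role of a `build_trainer()`-style factory. BaseTrainer, MyStrategyTrainer, TRAINERS, and build are hypothetical names, not the package's `Trainer` API, and the "training" is a placeholder so the example runs without PyTorch.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Type


class BaseTrainer(ABC):
    """Hypothetical stand-in for a strategy trainer base class."""

    def __init__(self, epochs: int) -> None:
        self.epochs = epochs
        self.history: List[float] = []

    @abstractmethod
    def get_criterion(self) -> Callable[[float], float]:
        """Return the loss function used by this strategy."""

    @abstractmethod
    def train_one_epoch(self, epoch: int) -> float:
        """Run one epoch and return its loss."""

    def do_train_val(self) -> List[float]:
        # Shared loop: strategies customize behavior only through the two hooks.
        for epoch in range(self.epochs):
            self.history.append(self.train_one_epoch(epoch))
        return self.history


class MyStrategyTrainer(BaseTrainer):
    """Toy strategy whose placeholder loss shrinks every epoch."""

    def get_criterion(self) -> Callable[[float], float]:
        return lambda x: x * x  # placeholder for e.g. a reweighted cross-entropy

    def train_one_epoch(self, epoch: int) -> float:
        criterion = self.get_criterion()
        return criterion(1.0 / (epoch + 1))


TRAINERS: Dict[str, Type[BaseTrainer]] = {"MyStrategy": MyStrategyTrainer}


def build(strategy: str, **kwargs) -> BaseTrainer:
    """Look up a trainer class by strategy name and instantiate it."""
    if strategy not in TRAINERS:
        raise ValueError(f"unknown strategy: {strategy!r}")
    return TRAINERS[strategy](**kwargs)


trainer = build("MyStrategy", epochs=3)
print(trainer.do_train_val())
```

Registering a new strategy then only means adding one entry to the mapping, while the shared loop stays untouched.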
(2) Benchmark research environment:
- To support deep imbalanced learning research, we provide example code for training with different strategies, along with benchmark results on five image datasets. To quickly start training on CIFAR-10 with the ERM strategy, you can run:
```shell
cd example
python main.py --gpu 0 --seed 1126 --c config/config_cifar10.yaml --strategy ERM
```
- With the example code, you can not only reproduce baseline results and state-of-the-art strategies such as LDAM or Remix, but also use this environment to develop your own algorithm or strategy. Feel free to add your own strategy to this package.
- For more information about the examples and usage, please see the Example README (example/README.md).
Benchmark Results
We provide benchmark results on 5 image datasets: CIFAR-10, CIFAR-100, CINIC-10, SVHN, and Tiny-ImageNet. We follow the standard procedure to generate imbalanced training sets for these 5 datasets and report their top-1 validation accuracy as a research benchmark. For example, the table below shows the results on long-tailed imbalanced CIFAR-10 for different strategies. For more detailed benchmark results, please see example/README.md.
Long-tailed Imbalanced CIFAR-10
imb_type | imb_factor | Model | Strategy | Validation Top-1 (%) |
---|---|---|---|---|
long-tailed | 100 | ResNet32 | ERM | 71.23 |
long-tailed | 100 | ResNet32 | DRW | 75.08 |
long-tailed | 100 | ResNet32 | LDAM-DRW | 77.75 |
long-tailed | 100 | ResNet32 | Mixup-DRW | 82.11 |
long-tailed | 100 | ResNet32 | Remix-DRW | 81.82 |
Test
```shell
python -m unittest -v
```
Contact
If you have any questions, please don't hesitate to email [email protected]. Thanks!
Acknowledgement
The authors thank members of the Computational Learning Lab at National Taiwan University for valuable discussions and various contributions to making this package better.