LibMTL: A PyTorch Library for Multi-Task Learning

Overview


LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). See the latest documentation for detailed introductions and API instructions.

Star us on GitHub — it motivates us a lot!

Features

  • Unified: LibMTL provides a unified code base and a consistent evaluation procedure, covering data processing, metric objectives, and hyper-parameters, on several representative MTL benchmark datasets. This allows quantitative, fair, and consistent comparisons between different MTL algorithms.
  • Comprehensive: LibMTL supports 84 MTL models, obtained by combining 7 architectures with 12 loss weighting strategies. LibMTL also provides a fair comparison on 3 computer vision datasets.
  • Extensible: LibMTL follows modular design principles, which allow users to flexibly and conveniently add customized components or make personalized modifications. Users can therefore quickly develop novel loss weighting strategies and architectures, or apply existing MTL algorithms to new application scenarios.

Overall Framework

  • Config Module: Responsible for all configuration parameters used when running the framework, including the parameters of the optimizer and learning rate scheduler, the hyper-parameters of the MTL model, and training settings such as batch size, number of epochs, and random seed.
  • Dataloaders Module: Responsible for data pre-processing and loading.
  • Model Module: Responsible for inheriting from the architecture and weighting classes and instantiating an MTL model. Note that the architecture determines the forward process of the MTL model and the weighting strategy determines the backward process.
  • Losses Module: Responsible for computing the loss for each task.
  • Metrics Module: Responsible for evaluating the MTL model and calculating the metric scores for each task.
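
The Model Module description can be pictured with a small sketch. It assumes, for illustration only, that the model class is created by inheriting from one architecture class and one weighting class. `ToyArchitecture`, `ToyWeighting`, and `build_mtl_model` are hypothetical names, not LibMTL classes, and plain floats stand in for tensors.

```python
class ToyArchitecture:
    """Hypothetical architecture: defines the forward pass."""

    task_names = ("segmentation", "depth")

    def forward(self, x):
        shared = 2.0 * x  # stand-in for a shared encoder
        # One output per task, all computed from the shared feature.
        return {task: shared + i for i, task in enumerate(self.task_names)}


class ToyWeighting:
    """Hypothetical weighting strategy: defines how task losses are combined."""

    def backward(self, losses):
        # Equal weighting: every task loss contributes with weight 1.
        return sum(losses.values())


def build_mtl_model(architecture_cls, weighting_cls):
    """Create a model class that inherits from both components."""

    class MTLModel(architecture_cls, weighting_cls):
        pass

    return MTLModel()


model = build_mtl_model(ToyArchitecture, ToyWeighting)
outputs = model.forward(1.5)
total_loss = model.backward({"segmentation": 0.4, "depth": 0.6})
print(outputs, total_loss)
```

Swapping either argument of `build_mtl_model` changes only the forward or only the backward behaviour, which is the separation the module list describes.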

Supported Algorithms

LibMTL currently supports the following algorithms:

  • 12 loss weighting strategies.

    Weighting Strategy | Venue | Comments
    Equal Weighting (EW) | - | Implemented by us
    Gradient Normalization (GradNorm) | ICML 2018 | Implemented by us
    Uncertainty Weights (UW) | CVPR 2018 | Implemented by us
    MGDA | NeurIPS 2018 | Referenced from official PyTorch implementation
    Dynamic Weight Average (DWA) | CVPR 2019 | Referenced from official PyTorch implementation
    Geometric Loss Strategy (GLS) | CVPR 2019 workshop | Implemented by us
    Projecting Conflicting Gradient (PCGrad) | NeurIPS 2020 | Implemented by us
    Gradient sign Dropout (GradDrop) | NeurIPS 2020 | Implemented by us
    Impartial Multi-Task Learning (IMTL) | ICLR 2021 | Implemented by us
    Gradient Vaccine (GradVac) | ICLR 2021 Spotlight | Implemented by us
    Conflict-Averse Gradient descent (CAGrad) | NeurIPS 2021 | Referenced from official PyTorch implementation
    Random Loss Weighting (RLW) | arXiv | Implemented by us

  • 7 architectures.

    Architecture | Venue | Comments
    Hard Parameter Sharing (HPS) | ICML 1993 | Implemented by us
    Cross-stitch Networks (Cross_stitch) | CVPR 2016 | Implemented by us
    Multi-gate Mixture-of-Experts (MMoE) | KDD 2018 | Implemented by us
    Multi-Task Attention Network (MTAN) | CVPR 2019 | Referenced from official PyTorch implementation
    Customized Gate Control (CGC) | ACM RecSys 2020 Best Paper | Implemented by us
    Progressive Layered Extraction (PLE) | ACM RecSys 2020 Best Paper | Implemented by us
    DSelect-k | NeurIPS 2021 | Referenced from official TensorFlow implementation

  • 84 combinations of different architectures and loss weighting strategies.
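
The count of 84 is simply every architecture paired with every weighting strategy. A minimal sketch that enumerates the pairs is given here; the strings are the abbreviations listed in this section and are used only as labels, so they are not guaranteed to match the values accepted by the command-line options.

```python
from itertools import product

# Abbreviations copied from the lists in this section (labels only).
WEIGHTINGS = [
    "EW", "GradNorm", "UW", "MGDA", "DWA", "GLS",
    "PCGrad", "GradDrop", "IMTL", "GradVac", "CAGrad", "RLW",
]
ARCHITECTURES = ["HPS", "Cross_stitch", "MMoE", "MTAN", "CGC", "PLE", "DSelect-k"]

# Every (architecture, weighting) pair is one MTL model.
combinations = list(product(ARCHITECTURES, WEIGHTINGS))

print(len(ARCHITECTURES), "x", len(WEIGHTINGS), "=", len(combinations))  # 7 x 12 = 84
```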

Installation

The simplest way to install LibMTL is using pip.

pip install -U LibMTL

More details about environment configuration are presented in the Docs.

Quick Start

We use the NYUv2 dataset as an example to show how to use LibMTL.

Download Dataset

The NYUv2 dataset we used is pre-processed by mtan. You can download this dataset here.

Run a Model

The complete training code for the NYUv2 dataset is provided in examples/nyu. The file train_nyu.py is the main file for training on the NYUv2 dataset.

You can find the command-line arguments by running the following command.

python train_nyu.py -h

For instance, running the following command will train an MTL model with EW and HPS on the NYUv2 dataset.

python train_nyu.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step

More details are presented in the Docs.
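
To make the shape of the command line above concrete, here is a minimal argparse sketch that accepts the same five flags. It is not the parser used by train_nyu.py: the defaults, help strings, and the `build_parser` helper are assumptions made for illustration, and `python train_nyu.py -h` remains the authoritative list of options.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags shown in the example command."""
    parser = argparse.ArgumentParser(description="Illustrative NYUv2 training options")
    parser.add_argument("--weighting", default="EW", help="loss weighting strategy, e.g. EW")
    parser.add_argument("--arch", default="HPS", help="architecture, e.g. HPS")
    parser.add_argument("--dataset_path", required=True, help="path to the NYUv2 data")
    parser.add_argument("--gpu_id", default="0", help="GPU to use")
    parser.add_argument("--scheduler", default=None, help="learning rate scheduler, e.g. step")
    return parser


# Parse the same arguments as the example command instead of reading sys.argv.
args = build_parser().parse_args(
    ["--weighting", "EW", "--arch", "HPS",
     "--dataset_path", "/path/to/nyuv2", "--gpu_id", "0", "--scheduler", "step"]
)
print(args.weighting, args.arch, args.dataset_path, args.gpu_id, args.scheduler)
```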

Citation

If you find LibMTL useful for your research or development, please cite the following:

@misc{LibMTL,
 author = {Baijiong Lin and Yu Zhang},
 title = {LibMTL: A PyTorch Library for Multi-Task Learning},
 year = {2021},
 publisher = {GitHub},
 journal = {GitHub repository},
 howpublished = {\url{https://github.com/median-research-group/LibMTL}}
}

Contributors

LibMTL is developed and maintained by Baijiong Lin and Yu Zhang.

Contact Us

If you have any questions or suggestions, please feel free to contact us by raising an issue or sending an email to [email protected].

Acknowledgements

We would like to thank the authors who released the following public repositories (listed alphabetically): CAGrad, dselect_k_moe, MultiObjectiveOptimization, and mtan.

License

LibMTL is released under the MIT license.

Comments
  • About MGDA implementation, some details I want to confirm.

    Hi, thanks for your wonderful project. I have some questions about the MGDA weighting method that I would like to confirm. Could you please help?

    1. What is self.rep_tasks?
    2. What is rep_grad?

    My own guess is that the first is the representation produced by the shared representation layers, and the second is a flag for whether to use the gradients of the representations.

    If so, my third question is: what is the purpose of the variable rep_grad in MGDA?

    Is it used to implement MGDA-UB? I noticed that _compute_grad() in abstract_weighting.py saves the gradients of self.rep_tasks, which is why I made this assumption.

    I am a little confused about these technical details and hope you can help. Thanks again!

    opened by A11en0 12
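
As background for the question above, MGDA searches for the minimum-norm convex combination of the task gradients. The sketch below shows the closed-form two-task case with plain Python lists. It is a generic illustration, not LibMTL's implementation, and `min_norm_two_tasks` is a hypothetical helper. Under the reading of rep_grad given in the issue (an assumption), the choice between gradients of the shared parameters and gradients of the shared representation would change only the vectors passed in, not the solver.

```python
def min_norm_two_tasks(g1, g2):
    """Return alpha in [0, 1] minimising ||alpha * g1 + (1 - alpha) * g2||^2."""
    diff = [b - a for a, b in zip(g1, g2)]  # g2 - g1
    denom = sum(d * d for d in diff)        # ||g1 - g2||^2
    if denom == 0.0:
        return 0.5  # identical gradients: every alpha gives the same vector
    # Unconstrained minimiser, then clipped to the interval [0, 1].
    alpha = sum(d * b for d, b in zip(diff, g2)) / denom
    return min(1.0, max(0.0, alpha))


def combine(g1, g2, alpha):
    """Weighted sum of the two gradient vectors."""
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(g1, g2)]


# Orthogonal gradients of equal norm: both tasks get the same weight.
alpha_orth = min_norm_two_tasks([1.0, 0.0], [0.0, 1.0])
# Parallel gradients: all weight goes to the shorter one.
alpha_par = min_norm_two_tasks([1.0, 0.0], [2.0, 0.0])

print(alpha_orth, combine([1.0, 0.0], [0.0, 1.0], alpha_orth))
print(alpha_par, combine([1.0, 0.0], [2.0, 0.0], alpha_par))
```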
  • trainer can't work

    When I try to run this trainer, the .next() call fails:

    -> "AttributeError: 'dict_keyiterator' object has no attribute 'next'"

    I use: python=3.7 torch=11.3

    Has this method been removed in Python 3? Why not use next(iter)?

    Otherwise, what should I do to fix it?

    opened by Sakura4036 10
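
The error in this issue comes from the Python 2 iterator method `.next()`, which Python 3 iterators do not have; the built-in `next()` works on both. The following sketch reproduces the situation with a plain dictionary. The `loaders` dictionary and its contents are made up for illustration and are not the trainer's real data structures.

```python
# Hypothetical stand-in for a dictionary of per-task data loaders.
loaders = {"segmentation": [1, 2, 3], "depth": [4, 5, 6]}

task_iter = iter(loaders)  # a dict_keyiterator, as named in the error message

# Python 2 style, raises AttributeError on Python 3:
#     task_iter.next()
has_old_method = hasattr(task_iter, "next")

# Portable replacement: the built-in next().
first_task = next(task_iter)

# The same applies to iterating over batches.
batch_iter = iter(loaders[first_task])
first_batch = next(batch_iter)

# A default value avoids StopIteration once the iterator is exhausted.
exhausted = next(iter([]), None)

print(has_old_method, first_task, first_batch, exhausted)
```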
  • DataLoader errors when I set num_workers>1

    I have found that someone says setting num_workers=0 will work, but it's too slow. My system is Ubuntu.

    Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7c408fce60>
    Traceback (most recent call last):
      File "/home/user/miniconda3/envs/IB/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
        self._shutdown_workers()
      File "/home/user/miniconda3/envs/IB/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
        if w.is_alive():
      File "/home/user/miniconda3/envs/IB/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
        assert self._parent_pid == os.getpid(), 'can only test a child process'
    AssertionError: can only test a child process

    opened by alphabet-lgtm 8
  •  office-31 demo doesn't work

    I ran the office-31 demo and found some errors in this example. The version on pip is not the same as the version published on GitHub. The operating environment is as follows:

    • python = 3.8
    • pytorch = 1.12

    In train.py, .next() fails.

    After I modified this call, a new problem appeared at data = data.to(self.device, non_blocking=True)

    I want to know whether this is caused by the version or by something else.

    opened by PussInCode 6
  • for GNN

    Hi, I know this is somewhat off-topic, but I am curious whether I can apply multi-task learning to graph neural networks. What I understand from HPS is that we share the encoder/decoder across the layers. Should I create an encoder on top of the graph layers? I am a bit stuck in this experiment, so any suggestion would be helpful. Thanks.

    opened by aozorahime 5
  • Support for Nash-MTL

    Thanks for this great repo, it is very useful!

    Could you please add support for the Nash-MTL method described in the paper "Multi-Task Learning as a Bargaining Game"?

    Paper: https://arxiv.org/abs/2202.01017
    Official code: https://github.com/AvivNavon/nash-mtl

    Thanks

    opened by AvivNavon 5
  • some problem about metrics.py

    I used the L1Metric class and got an error:

    AttributeError: 'L1Metric' object has no attribute 'abs_record'

    class L1Metric(AbsMetric):
        r"""Calculate the Mean Absolute Error (MAE).
        """
        def __init__(self):
            super(L1Metric, self).__init__()
            
        def update_fun(self, pred, gt):
            r"""
            """
            abs_err = torch.abs(pred - gt)
            self.record.append(abs_err)
            self.bs.append(pred.size()[0])
            
        def score_fun(self):
            r"""
            """
            records = np.array(self.abs_record)
            batch_size = np.array(self.bs)
            return [(records*batch_size).sum()/(sum(batch_size))]
    

    The L1Metric class inherits from the AbsMetric class, but AbsMetric has no attribute 'abs_record', so I guess there may be a problem here. Of course, this attribute may also come from somewhere else.

    class AbsMetric(object):
        r"""An abstract class for the performance metrics of a task. 
    
        Attributes:
            record (list): A list of the metric scores in every iteration.
            bs (list): A list of the number of data in every iteration.
        """
        def __init__(self):
            self.record = []
            self.bs = []
        
        @property
        def update_fun(self, pred, gt):
            r"""Calculate the metric scores in every iteration and update :attr:`record`.
    
            Args:
                pred (torch.Tensor): The prediction tensor.
                gt (torch.Tensor): The ground-truth tensor.
            """
            pass
        
        @property
        def score_fun(self):
            r"""Calculate the final score (when an epoch ends).
    
            Return:
                list: A list of metric scores.
            """
            pass
        
        def reinit(self):
            r"""Reset :attr:`record` and :attr:`bs` (when an epoch ends).
            """
            self.record = []
            self.bs = []
    
    opened by PussInCode 3
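
The traceback in this issue is consistent with `score_fun` reading `self.abs_record` while `update_fun` appends to `self.record`. A minimal sketch of a consistent version is shown here, under two loudly stated assumptions: plain Python lists stand in for tensors so that the example runs without PyTorch or NumPy, and each batch is reduced to its mean absolute error before being recorded. The `@property` decorators of the quoted base class are also left out, since a property cannot be called with arguments. This is an illustration, not the library's actual fix.

```python
class AbsMetric:
    """Simplified base class: keeps per-batch scores and batch sizes."""

    def __init__(self):
        self.record = []
        self.bs = []

    def update_fun(self, pred, gt):
        raise NotImplementedError

    def score_fun(self):
        raise NotImplementedError

    def reinit(self):
        self.record = []
        self.bs = []


class L1Metric(AbsMetric):
    """Mean Absolute Error, averaged over all samples seen in an epoch."""

    def update_fun(self, pred, gt):
        # Mean absolute error of this batch (lists stand in for tensors).
        abs_err = sum(abs(p - g) for p, g in zip(pred, gt)) / len(pred)
        self.record.append(abs_err)
        self.bs.append(len(pred))

    def score_fun(self):
        # Read from self.record, the attribute that update_fun fills.
        total = sum(r * b for r, b in zip(self.record, self.bs))
        return [total / sum(self.bs)]


metric = L1Metric()
metric.update_fun([1.0, 2.0], [0.0, 4.0])  # batch MAE 1.5, batch size 2
metric.update_fun([3.0], [3.0])            # batch MAE 0.0, batch size 1
score = metric.score_fun()
print(score)
```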
  • DDP mode and amp

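    Hello, thank you very much for your work. I ran into the following problem when training on multiple GPUs and would like to ask about it: https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/abstract_weighting.py#L43

    Configuration: MGDA + hardParam

    I also have one more question:

    1. Which weighting types in the project are compatible with half-precision AMP training? If they are compatible, should the weighting be applied before or after loss scaling?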
    opened by apxlwl 3
  • Unable to save the trained model

    When I try to save the full trained model using the following command:

    torch.save(model, "<path>")

    it throws this error:

    AttributeError: Can't pickle local object 'Trainer._prepare_model.<locals>.MTLmodel'

    opened by shubham166 2
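
The error message says that the model class is defined inside a function (`Trainer._prepare_model.<locals>.MTLmodel`), and pickle cannot serialise instances of such local classes. The sketch below reproduces the effect with the standard library only; `prepare_model` and `LocalModel` are hypothetical stand-ins, not LibMTL code. The usual PyTorch counterpart of the workaround is to save the weights with `torch.save(model.state_dict(), path)` and restore them later with `model.load_state_dict(torch.load(path))`.

```python
import pickle


def prepare_model():
    """Hypothetical factory that defines the model class locally."""

    class LocalModel:
        def __init__(self):
            self.weights = {"layer.weight": [0.1, 0.2]}

        def state_dict(self):
            # Plain data only: no reference to the local class.
            return dict(self.weights)

    return LocalModel()


model = prepare_model()

# Pickling the whole object fails because the class cannot be found by name.
try:
    pickle.dumps(model)
    pickled_whole_object = True
except (AttributeError, pickle.PicklingError):
    pickled_whole_object = False

# Pickling only the state works, since it contains plain data.
restored_state = pickle.loads(pickle.dumps(model.state_dict()))

print(pickled_whole_object, restored_state)
```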
  • Trainer class use fails with the error "No module named 'torchvision.models.utils'"

    Full stack trace:

    Traceback (most recent call last):
      File "src/main/pipelines/train_nsfw_mtl.py", line 11, in <module>
        from LibMTL import Trainer
      File "/azureml-envs/azureml_8a26314e09753d45d0790003a01faf79/lib/python3.8/site-packages/LibMTL/__init__.py", line 2, in <module>
        from . import model
      File "/azureml-envs/azureml_8a26314e09753d45d0790003a01faf79/lib/python3.8/site-packages/LibMTL/model/__init__.py", line 1, in <module>
        from LibMTL.model.resnet import resnet18
      File "/azureml-envs/azureml_8a26314e09753d45d0790003a01faf79/lib/python3.8/site-packages/LibMTL/model/resnet.py", line 3, in <module>
        from torchvision.models.utils import load_state_dict_from_url
    ModuleNotFoundError: No module named 'torchvision.models.utils'

    The issue is fixed by using 'torch.hub' instead.

    opened by shubham166 2
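
One way to express the fix mentioned in this issue so that it tolerates both old and new installations is a fallback import. The sketch below uses a hypothetical helper, `import_first`, and demonstrates it on standard-library modules so that it runs anywhere; the commented call shows how it might be applied to `load_state_dict_from_url`, which `torch.hub` provides.

```python
import importlib


def import_first(candidates):
    """Return the first attribute importable from a list of (module, attribute) pairs."""
    errors = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}.{attr}: {exc}")
    raise ImportError("none of the candidates could be imported: " + "; ".join(errors))


# Demonstration with the standard library: the first module does not exist,
# so the helper falls back to math.sqrt.
sqrt = import_first([("no_such_module_xyz", "sqrt"), ("math", "sqrt")])
print(sqrt(9.0))

# Possible use for the case in this issue (requires PyTorch to be installed):
# load_state_dict_from_url = import_first([
#     ("torch.hub", "load_state_dict_from_url"),
#     ("torchvision.models.utils", "load_state_dict_from_url"),
# ])
```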
  • uw initialization

    Hi, I found that the value -0.5 is used to initialize the parameter in line 19 of uw.py. My question is why this value is not 0, since the variable loss_scale is equivalent to log \sigma in the original paper.

    self.loss_scale = nn.Parameter(torch.tensor([-0.5]*self.task_num, device=self.device))
    
    opened by antct 2
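
For reference, the uncertainty-weighting objective of the UW paper scales each regression loss by 1/(2σ²) and adds a log σ regulariser. The sketch below computes the initial loss weight implied by the values -0.5 and 0 under two possible readings of the learned parameter s: s = log σ, as the issue assumes, and s = log σ². Which reading uw.py uses is not shown in the issue, so both are assumptions and the function names are hypothetical.

```python
import math


def weight_if_log_sigma(s):
    """Loss weight 1 / (2 * sigma^2) when s = log(sigma)."""
    return 0.5 * math.exp(-2.0 * s)


def weight_if_log_sigma_squared(s):
    """Loss weight 1 / (2 * sigma^2) when s = log(sigma^2)."""
    return 0.5 * math.exp(-s)


for s in (-0.5, 0.0):
    print(
        f"s = {s:+.1f}: "
        f"weight (s = log sigma) = {weight_if_log_sigma(s):.4f}, "
        f"weight (s = log sigma^2) = {weight_if_log_sigma_squared(s):.4f}"
    )
```

Under either reading, s = 0 gives an initial weight of 0.5, while s = -0.5 gives a larger initial weight, so the choice of initial value changes how strongly each task loss counts at the start of training.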