median-research-group / LibMTL

A PyTorch Library for Multi-Task Learning
MIT License

LibMTL


LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). See the latest documentation for detailed introductions and API instructions.

⭐ Star us on GitHub: it motivates us a lot!

News


Features

Overall Framework

(Figure: overall framework of LibMTL)

Each module is introduced in Docs.

Supported Algorithms

LibMTL currently supports the following algorithms:

| Optimization Strategies | Venues | Arguments |
|---|---|---|
| Equal Weighting (EW) | - | --weighting EW |
| Gradient Normalization (GradNorm) | ICML 2018 | --weighting GradNorm |
| Uncertainty Weights (UW) | CVPR 2018 | --weighting UW |
| MGDA (official code) | NeurIPS 2018 | --weighting MGDA |
| Dynamic Weight Average (DWA) (official code) | CVPR 2019 | --weighting DWA |
| Geometric Loss Strategy (GLS) | CVPR 2019 Workshop | --weighting GLS |
| Projecting Conflicting Gradient (PCGrad) | NeurIPS 2020 | --weighting PCGrad |
| Gradient Sign Dropout (GradDrop) | NeurIPS 2020 | --weighting GradDrop |
| Impartial Multi-Task Learning (IMTL) | ICLR 2021 | --weighting IMTL |
| Gradient Vaccine (GradVac) | ICLR 2021 | --weighting GradVac |
| Conflict-Averse Gradient descent (CAGrad) (official code) | NeurIPS 2021 | --weighting CAGrad |
| Nash-MTL (official code) | ICML 2022 | --weighting Nash_MTL |
| Random Loss Weighting (RLW) | TMLR 2022 | --weighting RLW |
| MoCo | ICLR 2023 | --weighting MoCo |
| Aligned-MTL (official code) | CVPR 2023 | --weighting Aligned_MTL |
| STCH (official code) | ICML 2024 | --weighting STCH |
| ExcessMTL (official code) | ICML 2024 | --weighting ExcessMTL |
| DB-MTL | arXiv | --weighting DB_MTL |
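
A weighting strategy decides how the per-task losses (or gradients) are combined into one update. As an illustration, here is a framework-free sketch of the Dynamic Weight Average rule from the DWA paper (CVPR 2019). Each task k gets the weight w_k = K * exp(r_k / T) / sum_i exp(r_i / T), where r_k = L_k(t-1) / L_k(t-2) is the ratio of its last two epoch losses. This is not LibMTL's implementation. The function name `dwa_weights` and the `loss_history` layout are hypothetical, and T = 2 is assumed as the default temperature.

```python
import math


def dwa_weights(loss_history, temperature=2.0):
    """Dynamic Weight Average, framework-free sketch (hypothetical helper, not LibMTL's API).

    loss_history: list of per-epoch lists of average task losses, oldest first,
                  e.g. [[L_1(t-2), ..., L_K(t-2)], [L_1(t-1), ..., L_K(t-1)]].
    Returns one weight per task; the weights always sum to K (the number of tasks).
    """
    num_tasks = len(loss_history[-1])
    # With fewer than two recorded epochs there is no loss ratio yet: use equal weights.
    if len(loss_history) < 2:
        return [1.0] * num_tasks
    prev, prev2 = loss_history[-1], loss_history[-2]
    # r_k = L_k(t-1) / L_k(t-2): a ratio near 1 means task k's loss has stopped falling.
    ratios = [cur / old for cur, old in zip(prev, prev2)]
    exps = [math.exp(r / temperature) for r in ratios]
    total = sum(exps)
    # w_k = K * exp(r_k / T) / sum_i exp(r_i / T): slower-improving tasks get more weight.
    return [num_tasks * e / total for e in exps]
```

The training loss for the next epoch would then be `sum(w * l for w, l in zip(weights, task_losses))`. A larger temperature pushes the weights toward equal weighting (EW).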

| Architectures | Venues | Arguments |
|---|---|---|
| Hard Parameter Sharing (HPS) | ICML 1993 | --arch HPS |
| Cross-stitch Networks (Cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| Multi-gate Mixture-of-Experts (MMoE) | KDD 2018 | --arch MMoE |
| Multi-Task Attention Network (MTAN) (official code) | CVPR 2019 | --arch MTAN |
| Customized Gate Control (CGC), Progressive Layered Extraction (PLE) | ACM RecSys 2020 | --arch CGC, --arch PLE |
| Learning to Branch (LTB) | ICML 2020 | --arch LTB |
| DSelect-k (official code) | NeurIPS 2021 | --arch DSelect_k |
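
An architecture decides which parameters the tasks share. The sketch below illustrates the gating idea behind MMoE (KDD 2018). A pool of experts is shared by all tasks, and each task's own softmax gate mixes the expert outputs into that task's input. It is written in plain Python rather than PyTorch and is not LibMTL's implementation. Each expert is simplified to a single linear layer followed by ReLU, and the helper names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def mmoe_forward(x, experts, gates):
    """Multi-gate Mixture-of-Experts forward pass, framework-free sketch.

    x:       input feature vector (length input_dim)
    experts: list of (hidden_dim x input_dim) matrices; expert e computes ReLU(E_e x)
    gates:   one (num_experts x input_dim) matrix per task; gate t computes softmax(G_t x)
    Returns one mixed hidden vector per task, to be fed to that task's own tower/decoder.
    """
    # Shared experts: computed once, reused by every task.
    expert_outs = [[max(0.0, v) for v in matvec(E, x)] for E in experts]
    dim = len(expert_outs[0])
    task_inputs = []
    for G in gates:
        g = softmax(matvec(G, x))  # one mixing weight per expert, sums to 1
        mixed = [sum(g[e] * expert_outs[e][d] for e in range(len(experts)))
                 for d in range(dim)]
        task_inputs.append(mixed)
    return task_inputs
```

With a single gate, this reduces to a one-gate mixture of experts. With one expert and no gating, it degenerates to hard parameter sharing (HPS), where every task sees the same shared representation.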

Supported Benchmark Datasets

| Datasets | Problems | Number of Tasks | Tasks | Multi-input | Supported Backbone |
|---|---|---|---|---|---|
| NYUv2 | Scene Understanding | 3 | Semantic Segmentation + Depth Estimation + Surface Normal Prediction | | ResNet50 / SegNet |
| Office-31 | Image Recognition | 3 | Classification | | ResNet18 |
| Office-Home | Image Recognition | 4 | Classification | | ResNet18 |
| QM9 | Molecular Property Prediction | 11 (default) | Regression | | GNN |
| PAWS-X | Paraphrase Identification | 4 (default) | Classification | | BERT |

Installation

  1. Create a virtual environment

    conda create -n libmtl python=3.8
    conda activate libmtl
    pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
  2. Clone the repository

    git clone https://github.com/median-research-group/LibMTL.git
  3. Install LibMTL

    cd LibMTL
    pip install -r requirements.txt
    pip install -e .
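
After installation, a quick way to confirm that the environment can locate the packages is to ask Python's import system for them without importing them. This is a minimal sketch. It assumes the library is importable under the top-level name `LibMTL`, and the helper name `check_install` is hypothetical. It only checks that each package can be found, not that PyTorch was built with CUDA support.

```python
import importlib.util


def check_install(packages=("torch", "torchvision", "LibMTL")):
    """Return {package_name: True/False} depending on whether Python can locate it.

    importlib.util.find_spec returns None for a missing top-level package,
    so nothing is actually imported or executed here.
    """
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    for name, found in check_install().items():
        print(f"{name}: {'found' if found else 'missing'}")
```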

Quick Start

We use the NYUv2 dataset as an example to show how to use LibMTL.

Download Dataset

The NYUv2 dataset we use is the version pre-processed by mtan. You can download it here.

Run a Model

The complete training code for the NYUv2 dataset is provided in examples/nyu, where main.py is the entry point for training.

You can list all command-line arguments by running the following command.

python main.py -h

For instance, running the following command will train an MTL model with EW and HPS on the NYUv2 dataset.

python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATH
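
Because the weighting strategy and the architecture are both plain command-line flags, comparing several combinations amounts to repeating the command above with different values. The sketch below builds those commands using only the flags shown in the example. It is a hypothetical helper, not part of LibMTL. It assumes it is run from examples/nyu. The per-run save directory layout (`<save_root>/<weighting>_<arch>`) is an illustrative choice, and it is not stated here whether main.py creates that directory itself.

```python
import shlex
import subprocess


def build_command(weighting, arch, dataset_path, save_path, gpu_id=0):
    """Assemble one main.py invocation using the flags from the example above."""
    return [
        "python", "main.py",
        "--weighting", weighting,
        "--arch", arch,
        "--dataset_path", dataset_path,
        "--gpu_id", str(gpu_id),
        "--scheduler", "step",
        "--mode", "train",
        "--save_path", save_path,
    ]


def sweep(weightings, archs, dataset_path, save_root, dry_run=True):
    """Build (and optionally run, one after another) every weighting x arch combination."""
    commands = []
    for weighting in weightings:
        for arch in archs:
            # Hypothetical layout: a separate output location per (weighting, arch) pair.
            cmd = build_command(weighting, arch, dataset_path, f"{save_root}/{weighting}_{arch}")
            commands.append(cmd)
            if dry_run:
                print(shlex.join(cmd))  # show the shell command without running it
            else:
                subprocess.run(cmd, check=True)  # stop at the first failing run
    return commands
```

For example, `sweep(["EW", "DWA", "UW"], ["HPS", "MTAN"], "/path/to/nyuv2", "runs")` prints six commands. Passing `dry_run=False` runs them sequentially on the same GPU.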

More details are available in the Docs.

Citation

If you find LibMTL useful for your research or development, please cite the following:

@article{lin2023libmtl,
  title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
  author={Baijiong Lin and Yu Zhang},
  journal={Journal of Machine Learning Research},
  volume={24},
  number={209},
  pages={1--7},
  year={2023}
}

Contributor

LibMTL is developed and maintained by Baijiong Lin.

Contact Us

If you have any questions or suggestions, please feel free to contact us by raising an issue or sending an email to bj.lin.email@gmail.com.

Acknowledgements

We would like to thank the authors who released the following public repositories (listed alphabetically): CAGrad, dselect_k_moe, MultiObjectiveOptimization, mtan, MTL, nash-mtl, pytorch_geometric, and xtreme.

License

LibMTL is released under the MIT license.