SamsungLabs / MTL

MIT License

[CVPR 2023] Independent Component Alignment for Multi-Task Learning



Dmitriy Senushkin, Nikolay Patakin, Arsenii Kuznetsov, Anton Konushin

Paper | Supplementary | Arxiv | Video

We introduce Aligned-MTL, a gradient optimization method for training multi-task neural networks. We address multi-task optimization issues by examining the stability of the gradient system, as measured by its condition number. Using this criterion as a design principle, Aligned-MTL eliminates instability in the training process by aligning the principal components of the gradient system.
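To make the condition-number criterion concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the repository's implementation: it assumes task gradients are stacked as columns of a matrix `G`, measures stability as the ratio of the largest to the smallest singular value, and aligns the principal components by setting every non-negligible singular value to the smallest one. The names `condition_number`, `align_gradients`, and `eps` are hypothetical.

```python
import numpy as np

def condition_number(G):
    """Ratio of largest to smallest singular value of G (columns = task gradients).

    Assumes G has full column rank, so the smallest singular value is non-zero.
    """
    s = np.linalg.svd(G, compute_uv=False)  # sorted in descending order
    return s[0] / s[-1]

def align_gradients(G, eps=1e-8):
    """Rescale all principal components of G to the smallest singular value.

    Assumption for this sketch: components whose singular value is below
    eps * (largest singular value) are treated as numerically zero and dropped.
    """
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    keep = s > eps * s[0]
    sigma = s[keep].min()
    return sigma * U[:, keep] @ Vt[keep, :]

# Two toy task gradients with very different magnitudes and a conflict
# along the first coordinate.
g1 = np.array([10.0, 1.0, 0.0])
g2 = np.array([-1.0, 0.5, 0.0])
G = np.stack([g1, g2], axis=1)  # shape: (num_params, num_tasks)

G_aligned = align_gradients(G)
update = G_aligned.sum(axis=1)  # combined direction with uniform task weights

print("condition number before:", condition_number(G))
print("condition number after: ", condition_number(G_aligned))
```

After alignment all retained singular values are equal, so the condition number of the aligned matrix is 1 up to floating-point error.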

This is the official PyTorch implementation of the paper.

Installation

In addition to PyTorch, please ensure that the following libraries are installed:

pip install scipy cvxpy matplotlib seaborn tqdm
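A quick way to confirm the environment is complete is to check that each package can be found by the import system. The sketch below is not part of the repository; the helper name `missing_packages` is hypothetical, and the import names are assumed to match the pip package names listed above (plus `torch`).

```python
from importlib.util import find_spec

# Import names assumed to match the pip package names above, plus torch.
REQUIRED = ["torch", "scipy", "cvxpy", "matplotlib", "seaborn", "tqdm"]

def missing_packages(names=REQUIRED):
    """Return the subset of top-level package names that cannot be located."""
    return [name for name in names if find_spec(name) is None]

missing = missing_packages()
print("missing:", ", ".join(missing) if missing else "none")
```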

Data preparation

1) For the NYUv2 (3-task) benchmark, please download the data from this link (~9GB, provided by the MTAN repository). Unzip it and set --data-path to the folder containing the train and val subfolders.

2) For the Cityscapes (2-task) benchmark, download the data from this link. Unzip it and set --data-path to the folder containing the train and val subfolders.

3) For the Cityscapes (3-task) benchmark, please refer to the original Cityscapes benchmark website and download the following files:
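Separately from the file list, the folder layout described in items 1 and 2 (a --data-path folder containing train and val subfolders) can be verified before training. This is a minimal sketch, not repository code; the helper name `missing_splits` is hypothetical, and the demo uses a temporary directory rather than a real dataset.

```python
import tempfile
from pathlib import Path

def missing_splits(data_path):
    """Return which of the expected subfolders ("train", "val") are absent."""
    root = Path(data_path)
    return [name for name in ("train", "val") if not (root / name).is_dir()]

# Demo on a throwaway directory that contains only a "train" subfolder.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "train").mkdir()
    demo_missing = missing_splits(tmp)

print("missing subfolders in demo:", demo_missing)
```

An empty list means the folder matches the expected layout.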

MTL methods

We provide implementations of several multi-task optimization methods. A method is selected for training by setting the --balancer parameter to one of the following values:

| Parameter value | Method | Conference |
|---|---|---|
| ls | Uniform weighting (aka linear scalarization) | - |
| uncertainty | Uncertainty weighting | CVPR 2018 |
| gradnorm | Gradient normalization | ICML 2018 |
| mgda, mgdaub | MGDA and MGDA-UB | NeurIPS 2018 |
| dwa | Dynamic Weight Average | CVPR 2019 |
| pcgrad | Projecting Conflicting Gradients | NeurIPS 2020 |
| graddrop | Gradient Sign Dropout | NeurIPS 2020 |
| imtl | Impartial Multi-Task Learning | ICLR 2021 |
| gradvac | Gradient Vaccine | ICLR 2021 |
| cagrad | Conflict-Averse Gradient descent | NeurIPS 2021 |
| nash | Nash-MTL | ICML 2022 |
| rlw | Random Loss Weighting with normal distribution | TMLR 2022 |
| amtl, amtlub | Ours, Aligned-MTL and Aligned-MTL-UB | CVPR 2023 |

Benchmarks

Our repository provides four benchmarks:

  1. cityscapes_pspnet: Cityscapes 3-task, PSPNet model. Semantic segmentation + Instance segmentation + Depth estimation
  2. cityscapes_mtan: Cityscapes 2-task, MTAN model. Semantic segmentation + Depth estimation
  3. nyuv2_pspnet: NYUv2 3-task, PSPNet model. Semantic segmentation + Depth estimation + Surface normal estimation
  4. nyuv2_mtan: NYUv2 3-task, MTAN model. Semantic segmentation + Depth estimation + Surface normal estimation

To run training and evaluation, use the following command (adjust the parameter values to your needs):

python train.py --benchmark cityscapes_pspnet --balancer amtl --data-path /path/to/cityscapes
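To compare several balancers on the same benchmark, the command above can be generated once per method. The sketch below only builds and prints the command lines using the three flags shown above; the helper `train_command` is hypothetical, and actually launching the runs (for example with `subprocess.run`) is left out.

```python
import shlex

def train_command(benchmark, balancer, data_path):
    """Build the argument list for one training run (flags as in the README)."""
    return [
        "python", "train.py",
        "--benchmark", benchmark,
        "--balancer", balancer,
        "--data-path", data_path,
    ]

# A few balancer values taken from the methods table.
balancers = ["ls", "pcgrad", "amtl"]
commands = [
    train_command("cityscapes_pspnet", b, "/path/to/cityscapes")
    for b in balancers
]

for cmd in commands:
    print(shlex.join(cmd))
```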

To compare methods on a synthetic two-task benchmark use:

python optimize_toy.py --balancer amtl --scale 0.5

The --scale parameter controls the balance between the two problems: L0 = scale * L1 + (1 - scale) * L2.
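As a small worked example of this weighting, the function below evaluates the combined objective for given loss values. It is an illustrative sketch with a hypothetical name (`scalarize`), not the code used in optimize_toy.py, and it assumes `scale` lies between 0 and 1.

```python
def scalarize(l1, l2, scale):
    """Combined objective L0 = scale * L1 + (1 - scale) * L2.

    Assumes 0 <= scale <= 1 (not enforced here).
    """
    return scale * l1 + (1.0 - scale) * l2

# With scale = 0.5 both problems contribute equally.
print(scalarize(2.0, 4.0, 0.5))
# scale = 1.0 keeps only the first problem, scale = 0.0 only the second.
print(scalarize(2.0, 4.0, 1.0), scalarize(2.0, 4.0, 0.0))
```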

Links

Our synthetic benchmark and visualisation code are based on the NashMTL repository.

Citation

If you find our code repository or paper useful, please cite us:

@InProceedings{Senushkin_2023_CVPR,
    author    = {Senushkin, Dmitry and Patakin, Nikolay and Kuznetsov, Arseny and Konushin, Anton},
    title     = {Independent Component Alignment for Multi-Task Learning},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {20083-20093}
}