Dmitry Senushkin, Nikolay Patakin, Arseny Kuznetsov, Anton Konushin
Paper | Supplementary | arXiv | Video
We introduce Aligned-MTL, a gradient-based optimization method for training multi-task neural networks. We address multi-task optimization issues by examining the stability of a gradient system as measured by its condition number. Using this criterion as a design principle, Aligned-MTL eliminates instability in the training process by aligning the principal components of the gradient system.

This is the official PyTorch implementation of the paper.
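As a rough intuition for the method, here is a minimal sketch of the core idea only, not the repository's implementation; the tensor shapes, the `eps` threshold, and the uniform task weighting are assumptions:

```python
import torch

def aligned_combine(grads, eps=1e-8):
    """Sketch of the Aligned-MTL idea: rescale the principal components of
    the per-task gradient matrix so that the gradient system is perfectly
    conditioned (all non-zero singular values equal the smallest one).

    grads: (T, d) tensor, one flattened gradient per task.
    Returns a single (d,) combined update direction.
    """
    G = grads.t()                           # (d, T): one column per task gradient
    M = G.t() @ G                           # (T, T) Gram matrix, eigvals = sigma^2
    lam, V = torch.linalg.eigh(M)           # eigenvalues in ascending order
    sigma = lam.clamp_min(0.0).sqrt()
    keep = sigma > eps * sigma.max()        # drop numerically zero components
    sigma_min = sigma[keep].min()
    # Rescale each principal component to the smallest singular value,
    # driving the condition number of the gradient system to 1.
    B = V[:, keep] @ torch.diag(sigma_min / sigma[keep]) @ V[:, keep].t()
    return (G @ B).sum(dim=1)               # combined update, uniform task weights

# Toy usage: 3 tasks, 5 shared parameters.
g = torch.randn(3, 5)
update = aligned_combine(g)
```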
Besides `torch`, please ensure that the following libraries are installed:

```
pip install scipy cvxpy matplotlib seaborn tqdm
```
1) For the NYUv2 (3-task) benchmark, please download the data from this link (~9 GB, provided by the MTAN repository). Unzip it and pass `--data-path` pointing to the folder containing the `train` and `val` subfolders.
2) For the Cityscapes (2-task) benchmark: link. Unzip it and pass `--data-path` pointing to the folder containing the `train` and `val` subfolders.
3) For the Cityscapes (3-task) benchmark, please refer to the original Cityscapes benchmark website and download the following files:

- `leftImg8bit_trainvaltest.zip` (11.6 GB)
- `disparity_trainvaltest.zip`
- `gtFine_trainvaltest.zip`

Unzip them into the same folder (as a result, it should contain the `leftImg8bit`, `disparity`, and `gtFine` subfolders).
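`--data-path` should then point at that folder. A quick sanity check (a hypothetical helper, not part of this repository; adjust the path to your setup):

```python
from pathlib import Path

# Hypothetical layout check, not part of the repository: verify that
# --data-path points at a folder containing the three unzipped archives.
data_path = Path("/path/to/cityscapes")
for sub in ("leftImg8bit", "disparity", "gtFine"):
    assert (data_path / sub).is_dir(), f"missing '{sub}/' under {data_path}"
print("Cityscapes 3-task layout looks good")
```

The same idea applies to the NYUv2 and Cityscapes 2-task layouts, with `train` and `val` in place of the three folders above.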
We provide implementations of several multi-task optimization methods. A gradient optimization method can be selected during training by setting the `--balancer` parameter to one of the following values:
| Parameter value | Method | Conference |
|---|---|---|
| `ls` | Uniform weighting (aka linear scalarization) | - |
| `uncertainty` | Uncertainty weighting | CVPR 2018 |
| `gradnorm` | Gradient normalization | ICML 2018 |
| `mgda`, `mgdaub` | MGDA and MGDA-UB | NeurIPS 2018 |
| `dwa` | Dynamic Weight Average | CVPR 2019 |
| `pcgrad` | Projecting Conflicting Gradients | NeurIPS 2020 |
| `graddrop` | Gradient Sign Dropout | NeurIPS 2020 |
| `imtl` | Impartial Multi-Task Learning | ICLR 2021 |
| `gradvac` | Gradient Vaccine | ICLR 2021 |
| `cagrad` | Conflict-Averse Gradient Descent | NeurIPS 2021 |
| `nash` | Nash-MTL | ICML 2022 |
| `rlw` | Random Loss Weighting with normal distribution | TMLR 2022 |
| `amtl`, `amtlub` | Ours, Aligned-MTL and Aligned-MTL-UB | CVPR 2023 |
Our repository provides four benchmarks:

- `cityscapes_pspnet`: Cityscapes 3-task. PSPNet model. Semantic segmentation + instance segmentation + depth estimation.
- `cityscapes_mtan`: Cityscapes 2-task. MTAN model. Semantic segmentation + depth estimation.
- `nyuv2_pspnet`: NYUv2 3-task. PSPNet model. Semantic segmentation + depth estimation + surface normal estimation.
- `nyuv2_mtan`: NYUv2 3-task. MTAN model. Semantic segmentation + depth estimation + surface normal estimation.

To run the training and evaluation code, use the following (change parameter values according to your needs):
```
python train.py --benchmark cityscapes_pspnet --balancer amtl --data-path /path/to/cityscapes
```
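Any benchmark from the list above can be combined with any balancer from the table; for example, to train MTAN on NYUv2 with MGDA (the data path is a placeholder):

```
python train.py --benchmark nyuv2_mtan --balancer mgda --data-path /path/to/nyuv2
```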
To compare methods on a synthetic two-task benchmark, use:

```
python optimize_toy.py --balancer amtl --scale 0.5
```
The `--scale` parameter controls the balance between the two problems, i.e. `L0 = scale * L1 + (1 - scale) * L2`.
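A quick numeric illustration of this weighting (plain Python, arbitrary loss values):

```python
# With --scale 0.5 both task losses contribute equally to the objective.
scale = 0.5
l1, l2 = 2.0, 4.0
l0 = scale * l1 + (1.0 - scale) * l2
print(l0)  # 0.5 * 2.0 + 0.5 * 4.0 = 3.0
```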
Our synthetic benchmark and visualisation code are based on the NashMTL repository.
If you find our code repository or paper useful, please cite us:

```
@InProceedings{Senushkin_2023_CVPR,
    author    = {Senushkin, Dmitry and Patakin, Nikolay and Kuznetsov, Arseny and Konushin, Anton},
    title     = {Independent Component Alignment for Multi-Task Learning},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {20083-20093}
}
```