Official PyTorch implementation of Domain Generalization by Mutual-Information Regularization with Pre-trained Models.
Junbum Cha, Kyungjae Lee, Sungrae Park, Sanghyuk Chun.
Install the requirements and download the datasets:

```sh
pip install -r requirements.txt
python -m domainbed.scripts.download --data_dir=/my/datasets/path
```
The environment used for the main experiments is listed below. Every main experiment is conducted on a single NVIDIA V100 GPU.

```
Environment:
    Python: 3.7.7
    PyTorch: 1.7.1
    Torchvision: 0.8.2
    CUDA: 10.1
    CUDNN: 7603
    NumPy: 1.21.4
    PIL: 7.2.0
```
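A quick way to compare your local setup against these versions (standard version attributes only; this snippet is a convenience and not part of the repository):

```python
import numpy, PIL, torch, torchvision

# Report the library versions relevant to reproducing the main experiments.
print("PyTorch:", torch.__version__)
print("Torchvision:", torchvision.__version__)
print("CUDA:", torch.version.cuda)
print("CUDNN:", torch.backends.cudnn.version())
print("NumPy:", numpy.__version__)
print("PIL:", PIL.__version__)
```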
`train_all.py` conducts multiple leave-one-out cross-validations, one for each target domain:

```sh
python train_all.py exp_name --dataset PACS --data_dir /my/dataset/path --algorithm MIRO
```
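Conceptually, each run holds out one domain as the unseen target and trains on the remaining domains. A minimal illustration of the loop, using the PACS domain names (`train_all.py` performs this internally):

```python
# Illustration only: train_all.py handles the leave-one-out loop itself.
domains = ["art_painting", "cartoon", "photo", "sketch"]  # PACS

for target in domains:
    sources = [d for d in domains if d != target]
    # Train on the source domains, then evaluate on the held-out target.
    print(f"train on {sources} -> evaluate on {target}")
```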
Run command with hyperparameters (HPs):

```sh
python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO \
    --dataset PACS \
    --lr 3e-5 \
    --resnet_dropout 0.0 \
    --weight_decay 0.0 \
    --ld 0.01 \
    --trial_seed 0
```
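Here `--ld` sets $\lambda$, the weight of the mutual-information regularizer that keeps the fine-tuned features close to the frozen pre-trained features. The sketch below shows the general shape of such a penalty; it is not the repository's exact implementation, and `mean_enc`/`var_enc` are hypothetical mean and variance encoders:

```python
import torch

def mi_penalty(feat, pre_feat, mean_enc, var_enc):
    """Sketch of a Gaussian lower-bound penalty between the current features
    `feat` and the frozen pre-trained features `pre_feat`."""
    mean = mean_enc(feat)
    var = var_enc(feat)  # assumed strictly positive, e.g. via softplus
    # Log-likelihood-style distance: larger learned variance tolerates more drift.
    return 0.5 * (torch.log(var) + (pre_feat - mean) ** 2 / var).mean()

# Total loss (sketch): loss = cross_entropy + ld * mi_penalty(...)
```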
Our searched HPs:

| | PACS | VLCS | OfficeHome | TerraIncognita | DomainNet |
|---|---|---|---|---|---|
| Learning rate | 3e-5 | 1e-5 | 3e-5 | 3e-5 | 3e-5 |
| Dropout | 0.0 | 0.5 | 0.1 | 0.0 | 0.1 |
| Weight decay | 0.0 | 1e-6 | 1e-6 | 1e-4 | 0.0 |
| $\lambda$ | 0.01 | 0.01 | 0.1 | 0.1 | 0.1 |
Set `--swad True` to combine MIRO with SWAD:

```sh
python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO \
    --dataset PACS \
    --ld 0.01 \
    --swad True \
    --trial_seed 0
```
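SWAD seeks flat minima by densely averaging model weights along the training trajectory. A minimal sketch of the running weight average it builds on (SWAD additionally selects the averaging window based on validation loss; `update_running_average` is a hypothetical helper):

```python
import copy
import torch

@torch.no_grad()
def update_running_average(avg_model, model, n_averaged):
    """Fold `model`'s weights into `avg_model` as the (n_averaged + 1)-th sample."""
    for p_avg, p in zip(avg_model.parameters(), model.parameters()):
        p_avg.mul_(n_averaged / (n_averaged + 1.0)).add_(p, alpha=1.0 / (n_averaged + 1.0))

# Usage sketch: avg_model = copy.deepcopy(model), then call once per sampled step.
```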
You can run MIRO with different backbones via the `--model` parameter:

```sh
# model is one of [resnet50, resnet50_barlowtwins, resnet50_moco, clip_resnet50, clip_vit-b16, swag_regnety_16gf]
python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO \
    --dataset PACS --model resnet50
```
The checkpoint should be prepared before running MoCo v3 (`resnet50_moco`). You can download the ResNet-50 MoCo v3 checkpoint trained for 1000 epochs here.
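If you need to inspect or load the checkpoint manually, here is a loading sketch assuming the usual MoCo v3 layout (a `state_dict` whose backbone keys carry a `module.base_encoder.` prefix; the filename is hypothetical):

```python
import torch
from torchvision.models import resnet50

model = resnet50()
ckpt = torch.load("mocov3_resnet50_1000ep.pth", map_location="cpu")  # hypothetical filename

# Keep only the backbone weights and strip the MoCo wrapper prefix.
state = {k.replace("module.base_encoder.", ""): v
         for k, v in ckpt["state_dict"].items()
         if k.startswith("module.base_encoder.")}
model.load_state_dict(state, strict=False)  # the classification head is not transferred
```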
We provide the commands to reproduce the main results of the paper (Table 1). Note that every result is averaged over three trials: run with the `trial_seed` option set to 0, 1, and 2, then average the results.
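For example (the remaining HP flags are omitted for brevity):

```sh
python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO --dataset PACS --trial_seed 0
python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO --dataset PACS --trial_seed 1
python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO --dataset PACS --trial_seed 2
```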
| | PACS | VLCS | OfficeHome | TerraIncognita | DomainNet | Avg. |
|---|---|---|---|---|---|---|
| MIRO + SWAD (B=16) | 96.8 | 81.7 | 83.3 | 64.3 | 60.7 | 77.3 |
| MIRO + SWAD (B=32) | 97.5 | 81.6 | 84.2 | 65.5 | 60.8 | 77.9 |
In the "pushing the limits" experiments of the paper, we used a batch size of 16 due to limited computational resources. We later found that a batch size of 32 improves overall performance, as shown in the table above. Note that these experiments tune $\lambda$ only, so there is room for further improvement through more intensive HP tuning.
```bibtex
@article{cha2022miro,
    title={Domain Generalization by Mutual-Information Regularization with Pre-trained Models},
    author={Junbum Cha and Kyungjae Lee and Sungrae Park and Sanghyuk Chun},
    journal={European Conference on Computer Vision (ECCV)},
    year={2022}
}
```
This project is released under the MIT license, included here. This project includes some code from facebookresearch/DomainBed (MIT license) and khanrc/swad (MIT license).