SakurajimaMaiii / TSD

[CVPR 2023] Feature Alignment and Uniformity for Test Time Adaptation
https://arxiv.org/abs/2303.10902
MIT License

Feature Alignment and Uniformity for Test Time Adaptation

This repo is the official PyTorch implementation of Feature Alignment and Uniformity for Test Time Adaptation (CVPR 2023) by Shuai Wang, Daoan Zhang, Zipei Yan, Jianguo Zhang and Rui Li.
The paper is available on arXiv, CVF Open Access, and IEEE Xplore.
This codebase is mainly based on T3A and DeepDG.

💻 Dependencies

We use Python 3.8.13 with the following packages:

torch==1.12.0
torchvision==0.13.0
numpy==1.22.3
tqdm==4.65.0
timm==0.6.12
scikit-learn==1.2.2
pillow==9.0.1

To use EfficientNet backbones, make sure torchvision>=0.11.0 is installed.
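To catch version mismatches before training, the following is a minimal sketch (not part of this repo) that compares installed packages against the pins listed above and checks the torchvision>=0.11.0 requirement for EfficientNet. It uses only the standard library; the helper names are hypothetical, and the simple version parser ignores local tags such as `+cu113` and pre-release suffixes.

```python
# Hypothetical helper script (not shipped with this repo): compare installed
# package versions with the pins listed above, using only the standard library.
from importlib import metadata

PINNED = {
    "torch": "1.12.0",
    "torchvision": "0.13.0",
    "numpy": "1.22.3",
    "tqdm": "4.65.0",
    "timm": "0.6.12",
    "scikit-learn": "1.2.2",
    "pillow": "9.0.1",
}


def version_tuple(version):
    """'0.13.0+cu113' -> (0, 13, 0); local tags and pre-release suffixes are ignored."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def installed_version(name):
    """Installed version string, or None if the distribution is missing."""
    # Older importlib.metadata lookups may not treat '-' and '_' as equal.
    for candidate in (name, name.replace("-", "_")):
        try:
            return metadata.version(candidate)
        except metadata.PackageNotFoundError:
            continue
    return None


def check_versions(pinned=PINNED):
    """Return {package: (installed or None, pinned)} for every mismatch."""
    mismatches = {}
    for name, wanted in pinned.items():
        have = installed_version(name)
        if have is None or version_tuple(have) != version_tuple(wanted):
            mismatches[name] = (have, wanted)
    return mismatches


def efficientnet_supported():
    """EfficientNet backbones require torchvision>=0.11.0."""
    have = installed_version("torchvision")
    return have is not None and version_tuple(have) >= (0, 11, 0)


if __name__ == "__main__":
    for name, (have, want) in check_versions().items():
        print(f"{name}: installed={have}, pinned={want}")
    print("EfficientNet supported:", efficientnet_supported())
```

Reported mismatches are not necessarily fatal; they only flag where your environment differs from the one used for the paper.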

Dataset

Download the datasets used in our paper:
PACS
OfficeHome
VLCS
DomainNet
Then organize them into the following directory structure:

|-your_data_dir
  |-PACS
    |-art_painting
    |-cartoon
    |-photo
    |-sketch
  |-OfficeHome
    |-Art
    |-Clipart
    |-Product
    |-RealWorld
  |-VLCS
    |-Caltech101
    |-LabelMe
    |-SUN09
    |-VOC2007
  |-DomainNet
    |-clipart
    |-infograph
    |-painting
    |-quickdraw
    |-real
    |-sketch
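To confirm that a downloaded dataset matches the tree above before launching a long run, here is a minimal sketch; the `EXPECTED_LAYOUT` mapping is copied from the tree, while `missing_domains` is a hypothetical helper, not part of this repo.

```python
# Hypothetical layout checker (not part of this repo); the expected folders
# are copied from the directory tree above.
import sys
from pathlib import Path

EXPECTED_LAYOUT = {
    "PACS": ["art_painting", "cartoon", "photo", "sketch"],
    "OfficeHome": ["Art", "Clipart", "Product", "RealWorld"],
    "VLCS": ["Caltech101", "LabelMe", "SUN09", "VOC2007"],
    "DomainNet": ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"],
}


def missing_domains(data_dir, datasets=None):
    """Return every expected domain folder that is not a directory under data_dir."""
    root = Path(data_dir)
    names = list(EXPECTED_LAYOUT) if datasets is None else datasets
    missing = []
    for name in names:
        for domain in EXPECTED_LAYOUT[name]:
            path = root / name / domain
            if not path.is_dir():
                missing.append(str(path))
    return missing


if __name__ == "__main__":
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "your_data_dir"
    for path in missing_domains(data_dir):
        print("missing:", path)
```

Pass a subset such as `missing_domains("your_data_dir", ["PACS"])` if you only downloaded some of the datasets.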

Train source model

Please use train.py to train the source model. For example:

cd code/
python train.py --dataset PACS \
                --data_dir your_data_dir \
                --opt_type Adam \
                --lr 5e-5 \
                --max_epoch 50

To train on another dataset, replace --dataset PACS with, e.g., office-home, VLCS, or DomainNet.
Use --net to select a different backbone, e.g. resnext50 or ViT-B16.
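If you want source models for several datasets, a small wrapper can generate one train.py invocation per dataset. The sketch below assumes only the flags shown in the example above; `train_command` is a hypothetical helper, and omitting `--net` leaves train.py's own default backbone in place.

```python
# Hypothetical wrapper (not part of this repo) that builds train.py argument
# lists using only the flags shown in the README example.
import shlex

DATASETS = ["PACS", "office-home", "VLCS", "DomainNet"]


def train_command(dataset, data_dir="your_data_dir", net=None,
                  opt_type="Adam", lr=5e-5, max_epoch=50):
    """Return the argument list for one train.py run."""
    cmd = ["python", "train.py",
           "--dataset", dataset,
           "--data_dir", data_dir,
           "--opt_type", opt_type,
           "--lr", str(lr),
           "--max_epoch", str(max_epoch)]
    if net is not None:  # leave --net out to keep train.py's default backbone
        cmd += ["--net", net]
    return cmd


if __name__ == "__main__":
    for dataset in DATASETS:
        # Print the shell-quoted command; to launch it instead, pass the list
        # to subprocess.run(cmd, cwd="code", check=True).
        print(shlex.join(train_command(dataset, net="resnext50")))
```

Building argument lists (rather than one shell string) avoids quoting problems when paths contain spaces.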

Test time adaptation

python unsupervise_adapt.py --dataset PACS \
                            --data_dir your_data_dir \
                            --adapt_alg TSD \
                            --pretrain_dir your_pretrain_model_dir \
                            --lr 1e-4

Change --adapt_alg TSD to run a different test-time adaptation method, e.g. T3A, SHOT-IM, or Tent.
--pretrain_dir is the path to the trained source model, e.g. ./train_outputs/model.pkl.
Empirically, setting --lr to 1e-4 or 1e-5 gives good performance. You can also tune it on a validation set held out from the training (source) domains.
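Tuning --lr on source-domain validation data amounts to a small grid search: adapt once per candidate learning rate and keep the one with the best validation accuracy. The sketch below assumes you supply an `evaluate(lr)` callable that runs adaptation and returns that accuracy; `select_lr`, `adapt_command`, and `evaluate` are hypothetical names, and this repo may implement model selection differently.

```python
# Hypothetical learning-rate selection sketch (not part of this repo).
from typing import Callable, Dict, Iterable, Tuple


def select_lr(candidates: Iterable[float],
              evaluate: Callable[[float], float]) -> Tuple[float, Dict[float, float]]:
    """Return the lr with the highest validation accuracy, plus all scores.

    `evaluate(lr)` is user-supplied: it should run test-time adaptation with
    that lr and return accuracy on held-out source-domain data.
    Ties keep the earlier candidate.
    """
    scores: Dict[float, float] = {}
    best_lr, best_acc = None, float("-inf")
    for lr in candidates:
        acc = evaluate(lr)
        scores[lr] = acc
        if acc > best_acc:
            best_lr, best_acc = lr, acc
    if best_lr is None:
        raise ValueError("no candidate learning rates given")
    return best_lr, scores


def adapt_command(lr, adapt_alg="TSD", dataset="PACS",
                  data_dir="your_data_dir",
                  pretrain_dir="./train_outputs/model.pkl"):
    """Argument list mirroring the unsupervise_adapt.py example above."""
    return ["python", "unsupervise_adapt.py",
            "--dataset", dataset,
            "--data_dir", data_dir,
            "--adapt_alg", adapt_alg,
            "--pretrain_dir", pretrain_dir,
            "--lr", str(lr)]
```

For example, `select_lr([1e-4, 1e-5], evaluate)` compares the two values suggested above, and `adapt_command(best_lr)` produces the final run with the chosen rate.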

📝 Citation

If this repo is useful for your research, please consider citing our paper:

@inproceedings{wang2023feature,
  title={Feature alignment and uniformity for test time adaptation},
  author={Wang, Shuai and Zhang, Daoan and Yan, Zipei and Zhang, Jianguo and Li, Rui},
  booktitle={CVPR},
  year={2023}
}

✉️ Contact

For questions, please contact bit.ybws@gmail.com.