deeplearning-wisc / hypo


HYPO: Hyperspherical Out-of-Distribution Generalization

This codebase provides a PyTorch implementation for the paper HYPO: Hyperspherical Out-of-Distribution Generalization by Haoyue Bai, Yifei Ming, Julian Katz-Samuels, and Yixuan Li (ICLR 2024).

Remark: The current codebase is available for preview purposes only and is still under development. We are actively working to eliminate hard-coded paths, remove unused arguments, and streamline model and dataset loading. Please stay tuned for upcoming updates.

Abstract

Out-of-distribution (OOD) generalization is critical for machine learning models deployed in the real world. However, achieving this can be fundamentally challenging, as it requires the ability to learn invariant features across different domains or environments. In this paper, we propose a novel framework HYPO (HYPerspherical OOD generalization) that provably learns domain-invariant representations in a hyperspherical space. In particular, our hyperspherical learning algorithm is guided by intra-class variation and inter-class separation principles---ensuring that features from the same class (across different training domains) are closely aligned with their class prototypes, while different class prototypes are maximally separated. We further provide theoretical justifications on how our prototypical learning objective improves the OOD generalization bound. Through extensive experiments on challenging OOD benchmarks, we demonstrate that our approach outperforms competitive baselines and achieves superior performance.
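To illustrate the two principles above (this is a minimal NumPy sketch of intra-class variation and inter-class separation on the unit hypersphere, not the exact objective implemented in this repo; function names and the temperature value are illustrative):

```python
import numpy as np

def l2_normalize(x, axis=-1):
    """Project vectors onto the unit hypersphere."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

def hypo_losses(features, labels, prototypes, tau=0.1):
    """Sketch of the two guiding principles.

    features:   (N, D) unnormalized embeddings
    labels:     (N,) class indices
    prototypes: (C, D) class prototype vectors
    """
    z = l2_normalize(features)
    mu = l2_normalize(prototypes)
    logits = z @ mu.T / tau  # temperature-scaled cosine similarity to every prototype

    # Intra-class variation: pull each feature toward its own class prototype
    # via cross-entropy over prototype similarities.
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    loss_var = -log_prob[np.arange(len(labels)), labels].mean()

    # Inter-class separation: penalize high similarity between distinct prototypes.
    sim = mu @ mu.T / tau
    off_diag = sim[~np.eye(len(mu), dtype=bool)]
    loss_sep = np.log(np.exp(off_diag).mean())

    return loss_var, loss_sep
```

Minimizing `loss_var` aligns same-class features (across domains) with their prototype, while minimizing `loss_sep` pushes the class prototypes apart on the sphere.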

Quick Start

Data Preparation

In this work, we evaluate the OOD generalization performance over a range of environmental discrepancies such as domains, image corruptions, and perturbations.

OOD generalization across domains: The default root directory for ID and OOD datasets is datasets/. We consider PACS, Office-Home, VLCS, and Terra Incognita. You may use scripts/download.py (from DomainBed) to download and prepare the datasets for domain generalization.

OOD generalization across common corruptions: The default root directory for ID and OOD datasets is datasets/. We consider CIFAR-10 & CIFAR-10-C and ImageNet-100 & ImageNet-100-C. Following prior work on the ImageNet-100 subset, the script for generating the subset is provided here.

CIFAR-10 & CIFAR-10-C

ImageNet-100 & ImageNet-100-C
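For reference, CIFAR-10-C distributes each corruption as a single .npy array stacking five severity levels of 10,000 test images each (with a matching labels.npy). A small helper for slicing out one severity might look like this (the filename and path conventions here are assumptions, not part of this repo):

```python
import numpy as np

def load_cifar10c_severity(npy_path, severity):
    """Return the 10,000 images at one severity level (1-5).

    CIFAR-10-C stores each corruption as a (50000, 32, 32, 3) array
    with severities 1..5 concatenated along the first axis.
    """
    assert 1 <= severity <= 5, "severity must be in 1..5"
    arr = np.load(npy_path)
    n = 10000  # images per severity level
    return arr[(severity - 1) * n : severity * n]
```

The same slicing convention applies to ImageNet-100-C-style corruption arrays if they follow the CIFAR-10-C layout.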

Training and Evaluation

Model Checkpoints

Evaluate pre-trained checkpoints

Our checkpoints can be downloaded here. Create a directory named checkpoints/[ID_DATASET] in the root directory of the project and place the downloaded checkpoints there. For example, for CIFAR-10 and PACS:

checkpoints/
---CIFAR-10/        
------checkpoint_hypo_resnet18_cifar10.pth.tar
---PACS/        
------checkpoint_hypo_resnet50_td_photo.pth.tar
------checkpoint_hypo_resnet50_td_cartoon.pth.tar
------checkpoint_hypo_resnet50_td_sketch.pth.tar
------checkpoint_hypo_resnet50_td_art_painting.pth.tar
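The layout above can be created with standard commands, for example:

```shell
# Create the per-dataset checkpoint directories in the project root,
# then place the downloaded .pth.tar files in the matching subdirectory.
mkdir -p checkpoints/CIFAR-10 checkpoints/PACS
```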

The following scripts can be used to evaluate the OOD generalization performance:

sh scripts/eval_ckpt_cifar10.sh ckpt_c10  # for CIFAR-10
sh scripts/eval_ckpt_pacs.sh ckpt_pacs    # for PACS

Evaluate custom checkpoints

If your checkpoints are not saved under the default checkpoints directory, create a symbolic link named checkpoints pointing to the directory where they are actually saved. For example, checkpoints for CIFAR-100 (ID) are structured as follows:

checkpoints/
---CIFAR-100/
------name_of_ckpt/
---------checkpoint_500.pth.tar
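Such a symlink can be created from the project root (the source path below is a placeholder for wherever your checkpoints actually live):

```shell
# Replace /path/to/actual/checkpoints with your real checkpoint directory.
ln -s /path/to/actual/checkpoints checkpoints
```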

Train from scratch

We provide sample scripts to train from scratch. Feel free to modify the hyperparameters and training configurations.

sh scripts/train_hypo_cifar10.sh
sh scripts/train_hypo_dg.sh

Fine-tune from ImageNet pre-trained models

We also provide fine-tuning scripts on large-scale datasets such as ImageNet-100.

sh scripts/train_hypo_imgnet100.sh 

Citation

If you find our work useful, please consider citing our paper:

@inproceedings{hypo2024,
  title={Provable Out-of-Distribution Generalization in Hypersphere},
  author={Haoyue Bai and Yifei Ming and Julian Katz-Samuels and Yixuan Li},
  booktitle={The Twelfth International Conference on Learning Representations (ICLR)},
  year={2024}
}

Further discussions

For further discussion of the method and its extensions, feel free to send an email to hbai39@wisc.edu or ming5@wisc.edu.