ynanwu / MetaGCD


Learning to Continually Learn in Generalized Category Discovery

This repository is the official implementation of MetaGCD: Learning to Continually Learn in Generalized Category Discovery.

Method Overview

Requirements

The code was tested with Python 3.6, PyTorch 1.4.0, and CUDA 9.2.
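To check that your setup is close to the tested one, a small helper along these lines can compare major.minor versions. This is only a sketch: the helper names are hypothetical and not part of the repository, and treating a matching major.minor as "compatible" is an assumption.

```python
import sys

# Versions reported as tested in this README.
TESTED = {"python": "3.6", "torch": "1.4.0", "cuda": "9.2"}


def major_minor(version):
    """Return the (major, minor) integer pair of a version string such as '1.4.0+cu92'."""
    # Drop local build tags like '+cu92' before splitting on dots.
    core = version.split("+")[0]
    parts = core.split(".")
    return int(parts[0]), int(parts[1])


def same_release(installed, tested):
    """True when the installed and tested versions share the same major.minor."""
    return major_minor(installed) == major_minor(tested)


if __name__ == "__main__":
    py = "{}.{}.{}".format(*sys.version_info[:3])
    print("python", py, "matches tested:", same_release(py, TESTED["python"]))
    try:
        import torch  # only importable inside an environment with PyTorch installed

        print("torch", torch.__version__, "matches tested:",
              same_release(torch.__version__, TESTED["torch"]))
        cuda = torch.version.cuda  # None for CPU-only builds
        if cuda is None:
            print("cuda: not available in this PyTorch build")
        else:
            print("cuda", cuda, "matches tested:", same_release(cuda, TESTED["cuda"]))
    except ImportError:
        print("torch is not installed in this environment")
```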

We recommend using a conda environment to set up all required dependencies:

conda env create -f environment.yml
conda activate MetaGCD

If you run into problems with the commands above, you can instead install the dependencies with pip:

pip install -r requirements.txt
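If you used the conda route, a quick check like the one below can catch scripts accidentally launched outside the environment. It is a minimal sketch that relies on conda exporting the active environment's name as CONDA_DEFAULT_ENV after activation; the function names are hypothetical.

```python
import os

# Environment name used by "conda activate MetaGCD" in this README.
EXPECTED_ENV = "MetaGCD"


def active_conda_env(environ=None):
    """Return the active conda environment's name, or None when none is active."""
    environ = os.environ if environ is None else environ
    # Activating a conda environment exports its name as CONDA_DEFAULT_ENV.
    return environ.get("CONDA_DEFAULT_ENV")


def check_env(environ=None):
    """Raise if MetaGCD is not the active environment; otherwise return its name."""
    name = active_conda_env(environ)
    if name != EXPECTED_ENV:
        raise RuntimeError(
            "expected conda env {!r} but found {!r}; run 'conda activate {}' first".format(
                EXPECTED_ENV, name, EXPECTED_ENV
            )
        )
    return name


if __name__ == "__main__":
    print("active conda env:", active_conda_env())
```

Calling check_env() at the top of a custom launcher turns a confusing import error later on into an explicit message.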

Offline Training

We provide a training script for the following three datasets from the NCD benchmark: CIFAR10, CIFAR100, and Tiny-ImageNet. To train the models in the paper, run the following command:

python methods/contrastive_training/contrastive_learning_based_MAML.py --run_mode 'MetaTrain' --dataset_name <dataset>
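To meta-train on all three datasets in one go, a small driver can build the same command for each dataset. This is a sketch under assumptions: the dataset identifiers (cifar10, cifar100, tinyimagenet) are guesses, so check the script's argument parser for the exact values --dataset_name accepts, and the helper names are hypothetical. By default it only prints the commands.

```python
import shlex
import subprocess
import sys

SCRIPT = "methods/contrastive_training/contrastive_learning_based_MAML.py"
# Hypothetical dataset identifiers: verify them against the script's argument parser.
DATASETS = ["cifar10", "cifar100", "tinyimagenet"]


def build_command(run_mode, dataset, python=sys.executable):
    """Return the argv list for one invocation of the MAML training script."""
    # The shell strips the quotes in --run_mode 'MetaTrain', so the plain string is passed.
    return [python, SCRIPT, "--run_mode", run_mode, "--dataset_name", dataset]


def meta_train_all(datasets=DATASETS, dry_run=True):
    """Meta-train on each dataset in turn; with dry_run=True, only print the commands."""
    for dataset in datasets:
        cmd = build_command("MetaTrain", dataset)
        print(" ".join(shlex.quote(part) for part in cmd))
        if not dry_run:
            # check=True raises CalledProcessError and stops the loop on a failed run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    meta_train_all(dry_run=True)
```

Calling meta_train_all(dry_run=False) from the repository root launches the runs one after another, stopping at the first one that exits with an error.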

Set the paths to the datasets, pre-trained models, and desired log directories in config.py.
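A config file of this kind is typically a flat module of path variables. The sketch below is hypothetical: the variable names and placeholder paths are illustrative and may not match what the repository's config.py defines, and the missing_paths helper is an addition for catching typos before a long run.

```python
# config.py (hypothetical sketch): variable names are placeholders, not the repository's.
import os

cifar_10_root = "/path/to/datasets/cifar10"
cifar_100_root = "/path/to/datasets/cifar100"
tiny_imagenet_root = "/path/to/datasets/tiny-imagenet-200"
pretrain_path = "/path/to/pretrained/model.pth"
exp_root = "/path/to/logs"


def missing_paths(paths=None):
    """Return the configured paths that do not exist on this machine."""
    if paths is None:
        paths = [cifar_10_root, cifar_100_root, tiny_imagenet_root, pretrain_path, exp_root]
    return [p for p in paths if not os.path.exists(p)]


if __name__ == "__main__":
    for path in missing_paths():
        print("missing:", path)
```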

Online Incremental Learning

To evaluate meta-trained models, run:

python methods/contrastive_training/contrastive_learning_based_MAML.py --run_mode 'MetaTest' --dataset_name <dataset>
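Because evaluation relies on a meta-trained model, it can help to chain the two modes and skip MetaTest when MetaTrain fails. This is a minimal sketch, assuming MetaTest finds the MetaTrain checkpoint through the paths in config.py; the helper names are hypothetical, and the runner parameter exists only so the logic can be exercised without launching real jobs.

```python
import subprocess
import sys

SCRIPT = "methods/contrastive_training/contrastive_learning_based_MAML.py"


def run_mode(mode, dataset, runner=subprocess.run):
    """Invoke the script in the given run mode and return its exit code."""
    cmd = [sys.executable, SCRIPT, "--run_mode", mode, "--dataset_name", dataset]
    return runner(cmd).returncode


def train_then_evaluate(dataset, runner=subprocess.run):
    """Meta-train, then meta-test only if training exited cleanly.

    Returns a (stage, exit_code) pair naming the last stage that ran.
    """
    code = run_mode("MetaTrain", dataset, runner)
    if code != 0:
        return ("MetaTrain", code)
    return ("MetaTest", run_mode("MetaTest", dataset, runner))
```

For example, train_then_evaluate("cifar10") returns ("MetaTest", 0) when both stages succeed; the dataset identifier there is a guess and should match whatever --dataset_name expects.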

Citation

If you find this codebase useful in your research, consider citing:

@inproceedings{wu2023metagcd,
    title={MetaGCD: Learning to Continually Learn in Generalized Category Discovery},
    author={Yanan Wu and Zhixiang Chi and Yang Wang and Songhe Feng},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    year={2023}
}