linhaowei1 / TPL

✌[ICLR 2024] Class Incremental Learning via Likelihood Ratio Based Task Prediction
https://arxiv.org/abs/2309.15048

Unable to reproduce the results in the original paper #2

ZDYoung0519 closed this issue 2 hours ago

ZDYoung0519 commented 1 month ago

I am trying to reproduce CIFAR100-10T, but I get much lower accuracy. The results of the first two tasks: {0: {'tp_acc': 0.681, 'acc': 0.3}, 1: {'tp_acc': 0.651, 'acc': 0.259}, 'auroc': 0.57520375, 'fpr@95': 0.9095, 'aupr': 0.5692886335071314}

I found that vit_hat cannot learn each task correctly (its TIL accuracy is low). What's the problem?

linhaowei1 commented 1 month ago

Hello,

Could you tell me which script or commands you ran? I will re-run them and check whether something is wrong in my code.

ZDYoung0519 commented 1 month ago

In my reproduction, the performance is already poor on task 0.

To train task 0, I ran main.py with the following parameters:

--task 0 --idrandom 0 --visual_encoder deit_small_patch16_224_in661 --baseline deit_small_patch16_224_in661_C100_10T_hat --seed 0 --batch_size 64 --sequence_file C100_10T --learning_rate 0.001 --num_train_epochs 40 --base_dir ckpt --class_order 0 --latent 128 --replay_buffer_size 2000 --training

ZDYoung0519 commented 1 month ago

The results I reported were obtained by running eval.py after training tasks 0 and 1, following the provided scripts.

ZDYoung0519 commented 1 month ago

I also tried to reproduce the CIFAR10-5T baseline. For the first 2 tasks, I got:

task 0: til_acc = 0.779, cil_acc = 0.619, tp_acc = 0.7505
task 1: til_acc = 0.7265, cil_acc = 0.34, tp_acc = 0.4235

(These results were obtained after training the first two tasks, with the default OOD method.) After running eval.py, I got:

{0: {'tp_acc': 0.683, 'acc': 0.5445}, 1: {'tp_acc': 0.8565, 'acc': 0.6315}, 'auroc': 0.786702, 'fpr@95': 0.67475, 'aupr': 0.7746352374804507}

linhaowei1 commented 1 month ago

Thank you for pointing this out! The low accuracy is caused by not loading the correct checkpoint: we found a bug in the code that loads the pre-trained checkpoint. Because our codebase integrates many other pre-trained models, we overlooked that MORE's checkpoint format differs from that of the checkpoints downloaded from timm.

We load the checkpoint with the following line in utils/utils.py:

transfer = {k: v for k, v in checkpoint.items() if k in target and 'head' not in k}

However, the checkpoint downloaded from MORE contains not only the model parameters but also other information, such as optimizer states. Its top-level keys therefore match none of the model's parameter names, and transfer ends up as an empty dict. To load it correctly, we need to use checkpoint['model']. We fixed the bug by adding the following code before the definition of transfer:

# MORE checkpoints store the weights under 'model'; timm checkpoints are flat.
if 'model' in checkpoint.keys():
    checkpoint = checkpoint['model']
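A minimal, self-contained sketch can make this failure mode concrete. It uses plain Python dicts instead of real state dicts. The helper `build_transfer`, its `unwrap` flag, and the parameter names are hypothetical and exist only for illustration. The filter expression is the one quoted from utils/utils.py. The `'model'`/`'optimizer'` layout is an assumption based on the description of MORE's checkpoint format.

```python
# Hypothetical sketch (not the TPL source): reproduces the checkpoint-loading
# bug with plain dicts standing in for PyTorch state dicts.

def build_transfer(checkpoint, target, unwrap=True):
    """Select pre-trained weights whose names exist in the target model.

    `target` plays the role of the model's state dict; `unwrap` toggles the
    fix that reads the weights from checkpoint['model'] when that key exists.
    """
    if unwrap and 'model' in checkpoint.keys():
        # MORE-style checkpoint: weights sit under 'model', next to other
        # entries such as optimizer state.
        checkpoint = checkpoint['model']
    # Same filter as quoted from utils/utils.py: keep keys the target model
    # has, and skip anything belonging to the classification head.
    return {k: v for k, v in checkpoint.items() if k in target and 'head' not in k}


# Illustrative parameter names only; real ViT state dicts have many more keys.
target = {'patch_embed.proj.weight': None, 'blocks.0.attn.qkv.weight': None, 'head.weight': None}
weights = {'patch_embed.proj.weight': 1, 'blocks.0.attn.qkv.weight': 2, 'head.weight': 3}

timm_style = dict(weights)                            # flat state dict
more_style = {'model': dict(weights), 'optimizer': {}}  # wrapped (assumed layout)

# Without the fix, no top-level key matches, so nothing is transferred.
print(build_transfer(more_style, target, unwrap=False))  # {}
# With the fix, both formats yield the two backbone entries; 'head.weight' is dropped.
print(build_transfer(more_style, target))
print(build_transfer(timm_style, target))
```

An empty result is a cheap signal of this kind of format mismatch. If the dict is later passed to something like PyTorch's `model.load_state_dict(transfer, strict=False)`, an empty dict loads without raising an error. A check such as `assert len(transfer) > 0` before loading would surface the problem immediately instead of showing up only as low accuracy.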

The fix has been pushed, so you can try it out. If you have further questions, don't hesitate to contact me.

Thanks again for following our paper!

ZDYoung0519 commented 1 month ago

For CIFAR10-5T, the expected accuracy would be around 90%+, but I got below 60%.

ZDYoung0519 commented 1 month ago

I will try this later. Thanks for your timely responses!

linhaowei1 commented 1 month ago

You're welcome:) Looking forward to your good news.