DonkeyShot21 / cassle

Official repository for the paper "Self-Supervised Models are Continual Learners" (CVPR 2022)
MIT License

Difficulty In Reproducing The Paper Results (e.g., Table 4., BYOL reported 66%, measured <<60%) #12

Open · kilickaya opened this issue 1 year ago

kilickaya commented 1 year ago

Hi,

Thanks a lot for your amazing work and for releasing the code. I have been trying to reproduce your Table 4 for some time. I use the code and the scripts directly, with NO modification.

For example, in this table, BYOL fine-tuning on ImageNet-100 in the 5-task class-incremental setting reaches 66.0. I measured well below 60.0, at least 6 points lower. The full results table (5 × 5) is attached below if you are interested.

results.pdf

Any idea what may be causing the gap? Are there any nuances in the evaluation method? For example, for average accuracy, I simply take the mean of the table across all rows and columns (as also suggested by GEM, which you reference).
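The two readings of "average accuracy" can give quite different numbers. As a minimal sketch, assume a hypothetical 5 × 5 matrix `R` where `R[t][i]` is the linear-eval accuracy (%) on the classes of task i using the checkpoint obtained after task t; the values are invented and are not the ones in results.pdf. For reference, GEM defines ACC as the mean of the last row only, i.e. the final model's accuracy on every task. The snippet also computes a commonly used forgetting measure (best earlier accuracy minus final accuracy, averaged over all but the last task); whether it matches the paper's exact definition is an assumption.

```python
# Hypothetical accuracy matrix: R[t][i] = accuracy (%) on the classes of task i,
# measured with the checkpoint obtained after training on task t.
# The numbers are invented for illustration only.
R = [
    [50.0, 20.0, 20.0, 20.0, 20.0],
    [48.0, 52.0, 25.0, 25.0, 25.0],
    [47.0, 50.0, 54.0, 30.0, 30.0],
    [46.0, 49.0, 52.0, 56.0, 35.0],
    [45.0, 48.0, 51.0, 54.0, 58.0],
]


def mean_all_entries(acc):
    """Mean over every row and column (the reading used in the question)."""
    values = [a for row in acc for a in row]
    return sum(values) / len(values)


def final_model_accuracy(acc):
    """Mean of the last row: the final model evaluated on all tasks (GEM's ACC)."""
    last = acc[-1]
    return sum(last) / len(last)


def average_forgetting(acc):
    """Best accuracy before the last task minus final accuracy, averaged over tasks 0..T-2."""
    num_tasks = len(acc)
    drops = []
    for i in range(num_tasks - 1):
        best_before_end = max(acc[t][i] for t in range(num_tasks - 1))
        drops.append(best_before_end - acc[-1][i])
    return sum(drops) / len(drops)


print(f"mean of all entries:  {mean_all_entries(R):.1f}")      # 40.4
print(f"final-model accuracy: {final_model_accuracy(R):.1f}")  # 51.2
print(f"average forgetting:   {average_forgetting(R):.1f}")    # 3.5
```

With these made-up numbers, averaging the whole matrix is pulled down by the upper triangle, which holds accuracies measured before the corresponding tasks were trained.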

Thanks a lot again for your response and your eye-opening work.

DonkeyShot21 commented 1 year ago

Hi, I have just checked my logs and the results seem consistent with the ones we published. See the screenshot below:

[Screenshot of the author's run logs]

The second run (65.4%) used slightly different hyperparams. We got around 59% with online linear eval and ~66% after offline linear eval. It might be that you are having some issues with the offline linear eval parameters. How much did you get with online eval? Maybe I can look for the checkpoint and you can try running just the offline linear eval to debug?

I don't really understand the results you are reporting; you need to look for val_acc1 in the wandb run. Maybe you are not looking at the right metric?

kilickaya commented 1 year ago

Hi Enrico,

Many thanks for the swift response!

Please see the wandb output for val_acc1 on ImageNet-100 for all 5 checkpoints:

[W&B chart: BYOL_Finetune_ImageNet100]

As is evident, the last model (task 4, the highest curve) reaches 62% accuracy at the very end of linear probing.

Below are my offline linear-probing parameters, equivalent to yours:


python main_linear.py \
    --dataset imagenet100 \
    --encoder resnet18 \
    --data_dir $DATA_DIR \
    --train_dir imagenet-100/train \
    --val_dir imagenet-100/val \
    --split_strategy class \
    --num_tasks 5 \
    --max_epochs 100 \
    --gpus 0 \
    --precision 16 \
    --optimizer sgd \
    --scheduler step \
    --lr 3.0 \
    --lr_decay_steps 60 80 \
    --weight_decay 0 \
    --batch_size 256 \
    --num_workers 8 \
    --dali \
    --name byol-imagenet100-5T-linear-eval \
    --pretrained_feature_extractor $PRETRAINED_PATH \
    --project benchmark \
    --entity swordrock \
    --wandb \
    --save_checkpoint

Is the accuracy in the paper just the accuracy of the final model (which I found as 62%)?

It would indeed be great if you could share the checkpoint. Then I can debug my evaluation code.

It would also be great if you could share the evaluation script. It does not have to be clean; it just needs to give the clearest possible idea.

Thank you.

DonkeyShot21 commented 1 year ago

Yes, it is the accuracy of the final model, as reported in the paper. Intermediate checkpoints are only used to compute forgetting. Please see this screenshot from the paper below:

[Screenshot of the relevant passage from the paper]

In this case, since the number of samples per task is roughly constant, the average is the same as the simple linear eval accuracy.
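The equivalence is easy to check numerically. A minimal sketch, assuming hypothetical per-task correct counts on a validation set with 1,000 images per task (ImageNet-100 has 50 validation images per class, and 100 classes split into 5 tasks gives 20 classes per task): with equal task sizes, the unweighted mean of per-task accuracies equals the pooled accuracy, and the two only diverge when task sizes differ.

```python
def mean_task_accuracy(correct, sizes):
    """Unweighted mean of per-task accuracies."""
    per_task = [c / n for c, n in zip(correct, sizes)]
    return sum(per_task) / len(per_task)


def overall_accuracy(correct, sizes):
    """Accuracy over the pooled validation set (what plain linear eval reports)."""
    return sum(correct) / sum(sizes)


# Equal task sizes (hypothetical counts): both numbers coincide.
sizes_equal = [1000] * 5
correct_equal = [610, 640, 600, 620, 630]
print(f"{mean_task_accuracy(correct_equal, sizes_equal):.3f}")  # 0.620
print(f"{overall_accuracy(correct_equal, sizes_equal):.3f}")    # 0.620

# Unequal task sizes (hypothetical counts): the two averages differ.
sizes_unequal = [500, 1500]
correct_unequal = [400, 600]
print(f"{mean_task_accuracy(correct_unequal, sizes_unequal):.3f}")  # 0.600
print(f"{overall_accuracy(correct_unequal, sizes_unequal):.3f}")    # 0.500
```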

I will look for the checkpoint and post it here asap.

I have some questions:

kilickaya commented 1 year ago

Thanks for your response.

Training

I will look into tuning the parameters further. Thank you.

DonkeyShot21 commented 1 year ago

I found some checkpoints that might be relevant: https://drive.google.com/drive/folders/1gOejzl4Q0cqAcmEjUhyStYPDbXPn1o9R?usp=share_link You can find the pre-train args there as well.

I am not 100% sure that this is the correct checkpoint, so use it at your own risk.

EDIT: this checkpoint was probably obtained with a different version of the code, you might have issues resuming it

DonkeyShot21 commented 1 year ago

Yes, your curves look similar to mine. I think it is likely to be due to hyperparam tuning of the offline linear eval. Also, always remember that there might be some randomness involved, so a small decrease in performance might be due to that.

kilickaya commented 1 year ago

Thanks for the model, args and the info. I will have a look at these. Thanks!

DonkeyShot21 commented 1 year ago

One last thing that just came to my mind. We recently found that Pillow-SIMD can have a detrimental effect on some models (see the issue here https://github.com/vturrisi/solo-learn/issues/313). I am not sure if we used it or not in our experiments. Might be another thing to check on.

EDIT: also make sure you use DALI for pre-training.
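One way to check which imaging library an environment actually has is to query the installed distributions. A minimal sketch, assuming the packages were installed with pip under their PyPI names `Pillow-SIMD` and `Pillow`; both install the same `PIL` import package, so the distribution metadata is the more direct thing to inspect:

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def using_pillow_simd():
    """True if the Pillow-SIMD distribution is installed in this environment."""
    return installed_version("Pillow-SIMD") is not None


print("Pillow-SIMD:", installed_version("Pillow-SIMD"))
print("Pillow:", installed_version("Pillow"))
print("using Pillow-SIMD:", using_pillow_simd())
```

If this reports Pillow-SIMD, uninstalling it and installing regular Pillow before re-running is one way to rule it out.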

kilickaya commented 1 year ago

Cool. I was using it, actually. Will try without it and report any difference.

kilickaya commented 1 year ago

Update 1: I spent the day tuning the hyperparameters of the offline linear eval. I am posting the results here in case someone else wants to see them as well.

Tl;dr: I could not get above 62% despite a brute-force search, which is still well below 66%. So my conclusion is that the gap is not due to the linear-probing stage but to the pre-training itself.

Setting (the author's recommended values from this code base are in bold; they yield the best accuracy I can get, 62%):

BYOL, fine-tuned for 400 epochs per task
5-task class-incremental
ImageNet-100
Offline-eval via linear-probe
lr: {0.01,0.1,1,**3**,10, 100}
batch_size: {64, 128, **256**}
weight_decay: {**0**, 1e-5}
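A brute-force search over this grid amounts to 6 × 3 × 2 = 36 offline linear-eval runs. A minimal sketch that expands the grid into `main_linear.py` command lines, reusing only flags from the command posted earlier in the thread; the per-run `--name` pattern is hypothetical, `--save_checkpoint` is dropped to avoid storing 36 checkpoints, `$DATA_DIR` and `$PRETRAINED_PATH` are left for the shell to expand, and the commands are printed rather than executed:

```python
import itertools
import shlex

# Flags copied from the offline linear-eval command posted in this thread.
BASE_ARGS = [
    "--dataset", "imagenet100",
    "--encoder", "resnet18",
    "--data_dir", "$DATA_DIR",
    "--train_dir", "imagenet-100/train",
    "--val_dir", "imagenet-100/val",
    "--split_strategy", "class",
    "--num_tasks", "5",
    "--max_epochs", "100",
    "--gpus", "0",
    "--precision", "16",
    "--optimizer", "sgd",
    "--scheduler", "step",
    "--lr_decay_steps", "60", "80",
    "--num_workers", "8",
    "--dali",
    "--pretrained_feature_extractor", "$PRETRAINED_PATH",
    "--project", "benchmark",
    "--entity", "swordrock",
    "--wandb",
]

# Search space from the update above.
LRS = [0.01, 0.1, 1, 3, 10, 100]
BATCH_SIZES = [64, 128, 256]
WEIGHT_DECAYS = [0, 1e-5]


def grid_commands():
    """Yield one command line per (lr, batch_size, weight_decay) combination."""
    for lr, bs, wd in itertools.product(LRS, BATCH_SIZES, WEIGHT_DECAYS):
        # Hypothetical run-naming scheme so each run is identifiable in wandb.
        name = f"byol-imagenet100-5T-linear-eval-lr{lr}-bs{bs}-wd{wd}"
        args = BASE_ARGS + [
            "--lr", str(lr),
            "--batch_size", str(bs),
            "--weight_decay", str(wd),
            "--name", name,
        ]
        # Leave $VARS unquoted so the shell expands them; quote everything else.
        quoted = [a if a.startswith("$") else shlex.quote(a) for a in args]
        yield " ".join(["python", "main_linear.py"] + quoted)


commands = list(grid_commands())
print(len(commands))  # 36
print(commands[0])
```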

Results (some curves are shorter because a different batch size means a different number of iterations):
[W&B chart, 2/8/2023, 9:45:44 PM]

To-do: I'll try without Pillow-SIMD. Then I'll focus on improving the pre-training part.

DonkeyShot21 commented 1 year ago

How much did you get with online linear eval?

kilickaya commented 1 year ago

Generally 4-5 points below the offline counterpart.

DonkeyShot21 commented 1 year ago

Ok, so around 57. The checkpoint that I shared should have online eval accuracy 58.8%.
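The figures exchanged in the thread can be put side by side with a little arithmetic. A rough sketch using only numbers quoted above (accuracies in %; the author's online figure is the approximate "around 59%", and the user's online accuracy is inferred from the 4-5 point gap rather than measured directly):

```python
# Figures quoted in this thread (accuracy in %).
author_online, author_offline = 59.0, 66.0  # author's logs ("around 59%", "~66%")
shared_ckpt_online = 58.8                   # checkpoint shared on Google Drive
user_offline = 62.0                         # best offline result after tuning
user_gap_range = (4.0, 5.0)                 # "generally 4-5 points below" offline

# Offline minus online accuracy in the author's run.
author_gap = author_offline - author_online

# Implied user online accuracy: offline minus the largest and smallest gap.
user_online_range = (user_offline - user_gap_range[1], user_offline - user_gap_range[0])

print(f"author's online/offline gap: {author_gap:.1f} points")
print(f"implied user online accuracy: {user_online_range[0]:.1f}-{user_online_range[1]:.1f}")
print(f"shared checkpoint minus implied user online: "
      f"{shared_ckpt_online - user_online_range[1]:.1f} to "
      f"{shared_ckpt_online - user_online_range[0]:.1f} points")
```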