DonkeyShot21 / cassle

Official repository for the paper "Self-Supervised Models are Continual Learners" (CVPR 2022)
MIT License

KNN Classifier Issue #6

Open pp1016 opened 2 years ago

pp1016 commented 2 years ago

Hi, I found this work very interesting and plan to work on similar topics. However, I have run into some issues: (1) For the fine-tuning example with Barlow Twins on CIFAR-100, shouldn't it be barlow.sh instead of barlow_distill.sh? Otherwise, we need to provide a pretrained model to run the code successfully. (2) If I enable the KNN online evaluation by setting disable_knn_eval = False, I get an error complaining about empty test features and an expected argument in base.py line 432. I saw a previously closed issue describing something similar, but the error still appears even when I set a meaningful online_eval_batch_size = 256. Thanks for your help!

pp1016 commented 2 years ago

Btw, is linear evaluation supported for the CIFAR datasets?

DonkeyShot21 commented 2 years ago

Hi,

1) Yes, it was a typo; I have just fixed it. 2) OK, can you provide the script you used and the full error output?

Offline linear evaluation does not improve results on CIFAR, so we just use online linear eval.

pp1016 commented 2 years ago

Thanks for your reply!

Basically, I run the example script for fine-tuning Barlow Twins on CIFAR:

DATA_DIR=/path/to/data/dir/ CUDA_VISIBLE_DEVICES=0 python job_launcher.py --script bash_files/continual/cifar/barlow.sh

and it fails with this error:

File "cassle/methods/base.py", line 432, in training_step
    self.knn(
File "cassle/utils/knn.py", line 89, in compute
    test_features = torch.cat(self.test_features)
RuntimeError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat. This usually means that this function requires a non-empty list of Tensors.

It seems that in base.py line 432, when self.knn() is called, only train_features/targets are passed and test_features/targets are missing, which causes the error. When I manually set self.disable_knn_eval = True, the code runs without error.

Thanks for your help!

vanity1129 commented 1 year ago

I encountered the same error (File "cassle/methods/base.py", line 432, in training_step self.knn( and File "cassle/utils/knn.py", line 89, in compute test_features = torch.cat(self.test_features)...). I did the same thing as @pp1016 and set --disable_knn_eval in the bash file barlow.sh, but then I encountered another error:

Traceback (most recent call last):
  File "main_pretrain.py", line 224, in <module>
    main()
  File "main_pretrain.py", line 220, in main
    trainer.fit(model, train_loaders, val_loader)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
    self._call_and_handle_interrupt(
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
    return self._run_train()
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 215, in advance
    result = self._run_optimization(
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 378, in _optimizer_step
    lightning_module.optimizer_step(
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1651, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
    trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in optimizer_step
    self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 85, in optimizer_step
    closure_result = closure()
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
    step_output = self._step_fn()
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 449, in _training_step
    check_finite_loss(result.closure_loss)
  File "/home/zw/anaconda3/envs/cassle/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py", line 39, in check_finite_loss
    raise ValueError(f"The loss returned in training_step is {loss}.")
ValueError: The loss returned in training_step is nan.

Is there any way to solve it? Thanks for your help! @DonkeyShot21

danielm1405 commented 1 year ago

I was able to solve the issue. In https://github.com/DonkeyShot21/cassle/blob/main/cassle/methods/base.py#L432 and https://github.com/DonkeyShot21/cassle/blob/main/cassle/methods/base.py#L464 it should be self.knn.update instead of self.knn. Calling metric() runs metric.update() and then metric.compute(), but in training_step and validation_step we only want to call update(), because not all the data has been passed to the metric yet.
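The update()/compute() split described above can be illustrated with a small pure-Python sketch. ToyKNN is a hypothetical stand-in, not the cassle or torchmetrics class; it only mimics the pattern discussed here: update() just accumulates data, compute() evaluates on everything accumulated so far, and calling the object does both at once. Calling it from a training step, before any test features exist, reproduces the same kind of "empty list" failure.

```python
import math
from collections import Counter


class ToyKNN:
    """Hypothetical, simplified accumulating k-NN metric (not the cassle class)."""

    def __init__(self, k=3):
        self.k = k
        self.train_features, self.train_targets = [], []
        self.test_features, self.test_targets = [], []

    def update(self, train_features=None, train_targets=None,
               test_features=None, test_targets=None):
        # Only store data; no evaluation happens here.
        if train_features is not None:
            self.train_features.extend(train_features)
            self.train_targets.extend(train_targets)
        if test_features is not None:
            self.test_features.extend(test_features)
            self.test_targets.extend(test_targets)

    def compute(self):
        # Analogous to torch.cat on an empty list: nothing to evaluate yet.
        if not self.test_features:
            raise RuntimeError("no test features accumulated yet")
        correct = 0
        for x, y in zip(self.test_features, self.test_targets):
            # Keep the k training points closest in Euclidean distance.
            nearest = sorted(
                zip(self.train_features, self.train_targets),
                key=lambda pair: math.dist(x, pair[0]),
            )[: self.k]
            pred = Counter(label for _, label in nearest).most_common(1)[0][0]
            correct += int(pred == y)
        return correct / len(self.test_features)

    def __call__(self, **kwargs):
        # Like calling a metric object directly: update, then compute immediately.
        self.update(**kwargs)
        return self.compute()


if __name__ == "__main__":
    knn = ToyKNN(k=1)
    # training_step: accumulate only
    knn.update(train_features=[(0.0, 0.0), (1.0, 1.0)], train_targets=[0, 1])
    # validation_step: accumulate only
    knn.update(test_features=[(0.1, 0.0), (0.9, 1.0)], test_targets=[0, 1])
    # end of the validation epoch: evaluate once on all accumulated data
    print(knn.compute())
```

Calling `ToyKNN()(train_features=..., train_targets=...)` inside a training step would raise immediately, which is why only `update()` belongs in the per-step hooks.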

However, when I run the code with this fix, the results are not as expected: for fine-tuning Barlow Twins I get ~55% accuracy with the online linear probe but only ~35% k-nn accuracy. According to the paper, the difference should not be this large (ImageNet100 in Tab. 4 vs backbone eval in Tab. C).

Does anyone have a clue why the difference is so large? @DonkeyShot21 @vturrisi ?

DonkeyShot21 commented 1 year ago

Hi! I am not sure I understand the issue. Table C reports results using offline k-nn, while I think you are running online k-nn, am I right? Table 4 is linear evaluation, which is well known to outperform k-nn, especially with CNNs.

danielm1405 commented 1 year ago
  1. What is the difference between online and offline k-nn? k-nn is non-parametric, so the results at the last epoch of online k-nn evaluation should be almost the same as the results of offline k-nn, am I right? The only difference I can see is that in the online scenario the data is strongly augmented (SSL augmentations), while offline k-nn should probably use weak augmentations (classification augmentations).
  2. I am aware that k-nn performs worse than the linear probe. It is the size of the gap that surprises me. For BarlowTwins fine-tuning on the 5-task split of CIFAR-100, the gap between the two metrics is 20 percentage points (55% lin eval vs 35% knn). In the paper the gap is much smaller, e.g. BarlowTwins fine-tuning on ImageNet100 achieves 63.1% in lin eval and 59.1% in k-nn, so the gap is only 4 points. I wonder why I get such a big difference between these metrics. Maybe it is because of the k-nn hyperparameters? Did you use the default k-nn parameters to obtain the results in the paper?
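Since the k-nn hyperparameters came up, here is a hedged sketch of the temperature-weighted cosine k-NN that is commonly used to evaluate self-supervised features, showing how k and the temperature change a prediction. This is an illustration, not cassle's implementation; the defaults `k=20` and `temperature=0.07` are assumptions chosen for illustration, not values confirmed from the repository, and features are assumed to be non-zero vectors.

```python
import math
from collections import defaultdict


def cosine(a, b):
    # Cosine similarity; assumes neither vector is all zeros.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def weighted_knn_predict(query, train_feats, train_labels, k=20, temperature=0.07):
    # Similarity of the query to every stored training feature.
    sims = [cosine(query, f) for f in train_feats]
    # Indices of the k most similar training samples.
    top = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
    scores = defaultdict(float)
    for i in top:
        # Each neighbour votes with weight exp(sim / T); a small T lets
        # the closest neighbours dominate, a large T approaches a plain majority vote.
        scores[train_labels[i]] += math.exp(sims[i] / temperature)
    return max(scores, key=scores.get)


if __name__ == "__main__":
    feats = [(1.0, 0.0), (0.9, 0.1), (0.0, 1.0)]
    labels = [0, 0, 1]
    query = (0.2, 1.0)
    # Sharp weighting: the single very close class-1 neighbour wins.
    print(weighted_knn_predict(query, feats, labels, k=3, temperature=0.07))
    # Nearly uniform weighting: the two class-0 neighbours outvote it.
    print(weighted_knn_predict(query, feats, labels, k=3, temperature=1000.0))
```

With the same three neighbours, the prediction flips between the two temperatures, which is one way hyperparameters alone can move k-nn accuracy.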
DonkeyShot21 commented 1 year ago

There are two main differences: (a) in the online setting, the features are collected with different parameter configurations, since the weights keep changing during the epoch; (b) as you said, augmentations. While (a) should not matter much at the end of training, when the lr is very low, (b) is still a big issue. Normally people use no augmentations for knn, afaik. Also remember that, especially on CIFAR, these augmentations are very strong and can corrupt images completely; in some cases the crops are as small as 3px x 3px. We never changed the hyperparams, so the ones you linked should be correct.
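The offline protocol described above can be sketched in plain Python: all features are extracted once by a single frozen encoder from unaugmented inputs, and only then classified with k-NN. Here `encoder` is a hypothetical callable and the "images" are toy lists of pixel values; in a real pipeline this role would be played by the trained backbone run in eval mode without gradients.

```python
import math
from collections import Counter


def extract_features(encoder, images):
    # Offline protocol: one fixed set of weights, no augmentation, single pass.
    return [encoder(img) for img in images]


def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=5):
    correct = 0
    for q, y in zip(test_feats, test_labels):
        # k nearest training features by Euclidean distance, then majority vote.
        nearest = sorted(range(len(train_feats)),
                         key=lambda i: math.dist(q, train_feats[i]))[:k]
        pred = Counter(train_labels[i] for i in nearest).most_common(1)[0][0]
        correct += int(pred == y)
    return correct / len(test_feats)


def offline_knn_eval(encoder, train_images, train_labels,
                     test_images, test_labels, k=5):
    # Both splits go through the same frozen encoder, so (a) and (b) from the
    # comment above do not apply: weights are fixed and inputs are unaugmented.
    train_feats = extract_features(encoder, train_images)
    test_feats = extract_features(encoder, test_images)
    return knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=k)


if __name__ == "__main__":
    # Hypothetical toy encoder: mean pixel intensity as a 1-D feature.
    encoder = lambda img: (sum(img) / len(img),)
    train_images = [[0.1, 0.2], [0.0, 0.1], [0.9, 0.8], [1.0, 0.9]]
    test_images = [[0.15, 0.05], [0.95, 0.85]]
    print(offline_knn_eval(encoder, train_images, [0, 0, 1, 1],
                           test_images, [0, 1], k=3))
```

By contrast, online k-nn gathers training features batch by batch, from augmented views, while the encoder is still being updated.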