roymiles / Simple-Recipe-Distillation

[AAAI 2024] Understanding the Role of the Projector in Knowledge Distillation

Random projector? #1

Closed: yoshitomo-matsubara closed this issue 5 months ago

yoshitomo-matsubara commented 7 months ago

Hi @roymiles

Congratulations on the paper acceptance!

For the ImageNet experiment, self.embed in the OurDistillationLoss class looks like a random projector for ResNet-18's embeddings, and it does not seem to be updated since it is not included in the optimizer. Is this intentional? If so, why is it required? https://github.com/roymiles/Simple-Recipe-Distillation/blob/main/imagenet/torchdistill/losses/single.py#L140
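
To illustrate the concern (a schematic sketch with hypothetical names, not the repo's actual code): a projector constructed inside a loss module has trainable parameters, but unless they are passed to the optimizer it remains a frozen random projection for the whole run.

```python
import torch
from torch import nn, optim

class ToyDistillationLoss(nn.Module):
    """Toy stand-in for a distillation loss that owns a projector."""
    def __init__(self, student_dim=512, teacher_dim=512):
        super().__init__()
        # Analogous to `self.embed`: randomly initialised at construction time.
        self.embed = nn.Linear(student_dim, teacher_dim)

    def forward(self, student_feat, teacher_feat):
        return (self.embed(student_feat) - teacher_feat).pow(2).mean()

student = nn.Linear(32, 512)        # stand-in for the student backbone
criterion = ToyDistillationLoss()

# Only the student's parameters are registered here, so `criterion.embed`
# receives gradients but is never updated; it stays at its random init.
optimizer = optim.SGD(student.parameters(), lr=0.1)
```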

roymiles commented 7 months ago

Hi Yoshitomo!,

Thank you :) and yes, that must have been an oversight when rewriting the code and putting it in this repo. Unfortunately, I only had time to check the deit/ code and experiments before putting this up. I will look into fixing the ResNet-18 ImageNet code this week and re-run the experiments. Thanks for spotting this! I will reply back on this issue then.

roymiles commented 7 months ago

I have fixed the repo now. The line of interest is here: https://github.com/roymiles/Simple-Recipe-Distillation/blob/dc7667f12207683558d1eaf400205d1b313acde1/imagenet/torchdistill/core/distillation.py#L182

It is a bit slow training with just a single V100, so I figured I would give an update on the progress. Both runs are about 40% done and, as expected, the trainable projector leads to much better performance. I will upload the complete and final logs and also the checkpoints when they are both done.
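
In rough terms, the change referenced above amounts to making the projector's parameters visible to the optimizer so it is trained jointly with the student (a schematic sketch with stand-in names, not the actual diff):

```python
from torch import nn, optim

student = nn.Linear(32, 512)       # stand-in for the student backbone
projector = nn.Linear(512, 512)    # stand-in for the trainable projector

# Both parameter groups are handed to the optimizer, so the projector is
# updated alongside the student instead of staying a frozen random map.
optimizer = optim.SGD(
    [
        {"params": student.parameters()},
        {"params": projector.parameters()},
    ],
    lr=0.1,
)
```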

Frozen Projector

2024-02-08 09:28:26,407 INFO torchdistill.misc.log Epoch: [40] [4900/5005] eta: 0:00:36 lr: 0.010000000000000002 img/s: 749.388497121768 loss: 1.6882 (1.6121) time: 0.3448 data: 0.0004 max mem: 9402
2024-02-08 09:28:43,604 INFO torchdistill.misc.log Epoch: [40] [4950/5005] eta: 0:00:18 lr: 0.010000000000000002 img/s: 754.2297413578135 loss: 1.6186 (1.6125) time: 0.3424 data: 0.0004 max mem: 9402
2024-02-08 09:29:00,909 INFO torchdistill.misc.log Epoch: [40] [5000/5005] eta: 0:00:01 lr: 0.010000000000000002 img/s: 741.8443645307605 loss: 1.5608 (1.6128) time: 0.3438 data: 0.0001 max mem: 9402
2024-02-08 09:29:02,451 INFO torchdistill.misc.log Epoch: [40] Total time: 0:28:48
2024-02-08 09:29:05,346 INFO torchdistill.misc.log Validation: [ 0/391] eta: 0:18:50 acc1: 85.9375 (85.9375) acc5: 94.5312 (94.5312) time: 2.8921 data: 2.8425 max mem: 9402
2024-02-08 09:29:08,957 INFO torchdistill.misc.log Validation: [ 50/391] eta: 0:00:43 acc1: 76.5625 (72.0435) acc5: 92.1875 (90.3799) time: 0.0847 data: 0.0391 max mem: 9402
2024-02-08 09:29:12,496 INFO torchdistill.misc.log Validation: [100/391] eta: 0:00:28 acc1: 71.0938 (71.5424) acc5: 93.7500 (91.6538) time: 0.0664 data: 0.0211 max mem: 9402
2024-02-08 09:29:16,595 INFO torchdistill.misc.log Validation: [150/391] eta: 0:00:22 acc1: 67.1875 (71.2955) acc5: 91.4062 (91.7736) time: 0.0739 data: 0.0278 max mem: 9402
2024-02-08 09:29:20,473 INFO torchdistill.misc.log Validation: [200/391] eta: 0:00:17 acc1: 51.5625 (68.2369) acc5: 79.6875 (89.7466) time: 0.0681 data: 0.0226 max mem: 9402
2024-02-08 09:29:23,996 INFO torchdistill.misc.log Validation: [250/391] eta: 0:00:12 acc1: 63.2812 (66.6802) acc5: 84.3750 (88.4120) time: 0.0645 data: 0.0190 max mem: 9402
2024-02-08 09:29:27,650 INFO torchdistill.misc.log Validation: [300/391] eta: 0:00:07 acc1: 58.5938 (65.0722) acc5: 81.2500 (87.1366) time: 0.0624 data: 0.0163 max mem: 9402
2024-02-08 09:29:31,231 INFO torchdistill.misc.log Validation: [350/391] eta: 0:00:03 acc1: 57.0312 (63.9312) acc5: 83.5938 (86.2669) time: 0.0761 data: 0.0305 max mem: 9402
2024-02-08 09:29:34,561 INFO torchdistill.misc.log Validation: Total time: 0:00:32
2024-02-08 09:29:34,561 INFO main * Acc@1 63.9700 Acc@5 86.2680

Trainable Projector

2024-02-08 08:57:20,154 INFO torchdistill.misc.log Epoch: [39] [4750/5005] eta: 0:01:27 lr: 0.010000000000000002 img/s: 746.3973317873493 loss: -0.3715 (-0.4433) time: 0.3440 data: 0.0003 max mem: 9402
2024-02-08 08:57:37,386 INFO torchdistill.misc.log Epoch: [39] [4800/5005] eta: 0:01:10 lr: 0.010000000000000002 img/s: 746.8458860041928 loss: -0.4441 (-0.4437) time: 0.3449 data: 0.0004 max mem: 9402
2024-02-08 08:57:54,584 INFO torchdistill.misc.log Epoch: [39] [4850/5005] eta: 0:00:53 lr: 0.010000000000000002 img/s: 748.0040237523999 loss: -0.4193 (-0.4434) time: 0.3434 data: 0.0004 max mem: 9402
2024-02-08 08:58:11,780 INFO torchdistill.misc.log Epoch: [39] [4900/5005] eta: 0:00:36 lr: 0.010000000000000002 img/s: 747.6076887031876 loss: -0.4346 (-0.4434) time: 0.3439 data: 0.0003 max mem: 9402
2024-02-08 08:58:29,019 INFO torchdistill.misc.log Epoch: [39] [4950/5005] eta: 0:00:18 lr: 0.010000000000000002 img/s: 747.4021349935126 loss: -0.4108 (-0.4430) time: 0.3452 data: 0.0003 max mem: 9402
2024-02-08 08:58:46,330 INFO torchdistill.misc.log Epoch: [39] [5000/5005] eta: 0:00:01 lr: 0.010000000000000002 img/s: 747.1722569610195 loss: -0.4533 (-0.4429) time: 0.3436 data: 0.0001 max mem: 9402
2024-02-08 08:58:47,852 INFO torchdistill.misc.log Epoch: [39] Total time: 0:28:46
2024-02-08 08:58:50,845 INFO torchdistill.misc.log Validation: [ 0/391] eta: 0:19:28 acc1: 79.6875 (79.6875) acc5: 92.9688 (92.9688) time: 2.9890 data: 2.9408 max mem: 9402
2024-02-08 08:58:54,453 INFO torchdistill.misc.log Validation: [ 50/391] eta: 0:00:44 acc1: 77.3438 (75.5974) acc5: 93.7500 (92.4173) time: 0.0866 data: 0.0409 max mem: 9402
2024-02-08 08:58:57,784 INFO torchdistill.misc.log Validation: [100/391] eta: 0:00:28 acc1: 75.7812 (74.8762) acc5: 94.5312 (93.1235) time: 0.0662 data: 0.0199 max mem: 9402
2024-02-08 08:59:01,960 INFO torchdistill.misc.log Validation: [150/391] eta: 0:00:22 acc1: 71.8750 (74.9638) acc5: 92.9688 (93.3775) time: 0.0862 data: 0.0403 max mem: 9402
2024-02-08 08:59:05,740 INFO torchdistill.misc.log Validation: [200/391] eta: 0:00:16 acc1: 56.2500 (72.1782) acc5: 82.0312 (91.5384) time: 0.0655 data: 0.0195 max mem: 9402
2024-02-08 08:59:09,170 INFO torchdistill.misc.log Validation: [250/391] eta: 0:00:11 acc1: 67.1875 (70.7483) acc5: 85.9375 (90.2920) time: 0.0607 data: 0.0150 max mem: 9402
2024-02-08 08:59:12,765 INFO torchdistill.misc.log Validation: [300/391] eta: 0:00:07 acc1: 63.2812 (69.5235) acc5: 82.8125 (89.3766) time: 0.0613 data: 0.0156 max mem: 9402
2024-02-08 08:59:16,248 INFO torchdistill.misc.log Validation: [350/391] eta: 0:00:03 acc1: 66.4062 (68.4606) acc5: 87.5000 (88.6040) time: 0.0699 data: 0.0237 max mem: 9402
2024-02-08 08:59:19,536 INFO torchdistill.misc.log Validation: Total time: 0:00:31
2024-02-08 08:59:19,537 INFO main * Acc@1 68.4660 Acc@5 88.6820

yoshitomo-matsubara commented 7 months ago

Hi @roymiles

Thanks for the update! Let me know once you finalize the code and config. I'm reimplementing your method in the unified style used for torchdistill, and if you're interested in contributing to torchdistill like this, I can help you do that and advertise your work.

roymiles commented 6 months ago

That sounds great, thanks! I'll just do a small sweep over some hyperparameters first to get the best results. A proper implementation in torchdistill would be really great, thank you! I'll let you know when I have the results and code to share.

roymiles commented 6 months ago

@yoshitomo-matsubara I have just pushed now and it should all be good. I have also put the logs and model checkpoints in the README.md.

yoshitomo-matsubara commented 6 months ago

@roymiles Great! Do you want to make a PR for the torchdistill repo, or should I do it for you?

roymiles commented 6 months ago

@yoshitomo-matsubara Yea, I think it would be really great to have this as part of the torchdistill repo. I do realise my current implementation doesn't quite fit the torchdistill template, and I would have to look through all the other implementations you have to see how it should be done properly.

If it is not too much work, it would be really nice if you could implement this in torchdistill for me. That would be really helpful, thanks! 🙂

yoshitomo-matsubara commented 6 months ago

No problem, I can do it for you.

What name would you pick for your method? If you don't have any preference, I would use MilesMikolajczyk2024 as part of the module name, e.g., OurDistillationLoss -> MilesMikolajczyk2024Loss.

roymiles commented 6 months ago

Thanks! I think I'd prefer BNLogSumLoss as that summarises what it does a bit more clearly.

yoshitomo-matsubara commented 6 months ago

BNLogSum may be good for the loss module, but I need to add a wrapper to include an auxiliary trainable module for your method, so I want to use a unique method name as part of the wrapper class name: https://github.com/yoshitomo-matsubara/torchdistill/blob/main/torchdistill/models/wrapper.py#L210

Similarly, I need the name in other places as well, e.g., https://github.com/yoshitomo-matsubara/torchdistill/tree/main/configs/sample/ilsvrc2012

roymiles commented 6 months ago

For the wrapper class, perhaps Linear4BNLogSum. For the folder, I'm not too sure, ha. I think log_sum is unique, but you might have a better idea than me. It is quite hard to think of an acronym that describes all the components (linear + BN + log sum). I don't mind too much :) I'm cool with whatever works.
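
For readers following along, the pieces being named here fit together roughly as below. This is only a schematic sketch: the class names are hypothetical and the log-based distance is a placeholder; the actual loss lives in imagenet/torchdistill/losses/single.py.

```python
import torch
from torch import nn

class LinearBNHead(nn.Module):
    """Trainable projector: a linear map followed by batch norm."""
    def __init__(self, student_dim=512, teacher_dim=512):
        super().__init__()
        self.proj = nn.Linear(student_dim, teacher_dim)
        self.bn = nn.BatchNorm1d(teacher_dim)

    def forward(self, x):
        return self.bn(self.proj(x))

class ToyBNLogSumLoss(nn.Module):
    """Compares projected student features against teacher features with a
    smooth, log-based distance; the log-cosh term below is only a stand-in
    for the repo's actual log-sum formulation."""
    def forward(self, student_proj, teacher_feat):
        diff = student_proj - teacher_feat
        return torch.log(torch.cosh(diff)).mean()

# Toy usage with random features (batch of 8, feature dim 512).
head = LinearBNHead()
loss = ToyBNLogSumLoss()(head(torch.randn(8, 512)), torch.randn(8, 512))
```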

yoshitomo-matsubara commented 6 months ago

What about srd, from this repository name (Simple-Recipe-Distillation)?

roymiles commented 6 months ago

ah yea that sounds good 👍

yoshitomo-matsubara commented 6 months ago

Hi @roymiles

I added your method as SRD to the torchdistill repo.

Can you fork the current torchdistill repo and use this config to reproduce the number? resnet18_from_resnet34.txt (tentatively using .txt, as a .yaml file cannot be uploaded here)

Once you confirm the reproducibility, keep the log file and checkpoint file and submit a PR with the yaml file + README.md at configs/official/ilsvrc2012/roymiles/aaai2024/

roymiles commented 6 months ago

Hi @yoshitomo-matsubara

Thanks so much for doing this! I'll give this a go sometime this/next week once I have a few GPUs free.

yoshitomo-matsubara commented 6 months ago

@roymiles no problem! Let me know then. I hope to include the official config, log, and checkpoint files when releasing the next version (soon).

EDIT: resnet18_from_resnet34.txt

yoshitomo-matsubara commented 6 months ago

Hi @roymiles

How's the experiment going?

roymiles commented 5 months ago

Hi @yoshitomo-matsubara

I am really sorry for the late reply. I had some issues with training before, though it seems that switching to DataParallel (as you suggested) has fixed them much more cleanly :D

I then got a bit bogged down with other work and personal events, but I have since started the run and it seems to be training well. I'll have the results in the next few days.
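
For context, the DataParallel route mentioned above is, in plain PyTorch terms (not the exact training script), roughly the following:

```python
import torch
from torch import nn
from torchvision import models

# Replicate the model across all visible GPUs; each batch is split among them.
model = models.resnet18(num_classes=1000)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
```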

roymiles commented 5 months ago

I finished the run, but the results were a bit lower than I expected. I have just realised that this may be due to having a projector on the teacher side, i.e. teacher: auxiliary_model_wrapper:kwargs:linear_kwargs. I am re-running the experiment now, sorry about that! ha

yoshitomo-matsubara commented 5 months ago

Hi @roymiles, good catch! I forgot to remove linear_kwargs from the teacher wrapper. Thanks for pointing it out.

roymiles commented 5 months ago

Hi @yoshitomo-matsubara

Hopefully this is the final update before I push the log, checkpoint file, and yaml. I was getting poor results because my runs were automatically loading the optimiser and model checkpoints from my previous run with the same dst_ckpt. This had completely slipped my mind, but the loss in the logs is now on par with this repo's implementation, and the run will likely be finished in a day or so.

yoshitomo-matsubara commented 5 months ago

Hi @roymiles

Does dst_ckpt work as a file path to load the checkpoint?

dst_ckpt is literally a file path to store the checkpoint, not to load it (src_ckpt is the file path for loading a checkpoint).

roymiles commented 5 months ago

This is what I found when trying to debug. The optimiser ended up starting at a much lower lr (in fact, the final lr) than the config specified. It turned out to come from this line: https://github.com/yoshitomo-matsubara/torchdistill/blob/3799847d0e24b89d22801f75f36f0d075906f928/examples/torchvision/image_classification.py#L132

This was with an empty src_ckpt in the yaml.

yoshitomo-matsubara commented 5 months ago

Ah, src_ckpt should be used there. I will update the scripts soon. Thanks for pointing it out!
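
In generic terms, the intended split is: resume training state only from src_ckpt and treat dst_ckpt purely as the output path. A minimal sketch of that behaviour, with a hypothetical helper and checkpoint keys (not torchdistill's actual code):

```python
import os
import torch

def maybe_resume(model, optimizer, lr_scheduler, src_ckpt, dst_ckpt):
    """Load training state from src_ckpt if given; dst_ckpt is only ever
    written to, so a stale checkpoint with a decayed learning rate cannot
    silently leak into a fresh run."""
    if src_ckpt and os.path.isfile(src_ckpt):
        state = torch.load(src_ckpt, map_location="cpu")
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        lr_scheduler.load_state_dict(state["lr_scheduler"])
    return model, optimizer, lr_scheduler
```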

yoshitomo-matsubara commented 5 months ago

By the way, you're also welcome to advertise your BMVC'22 and AAAI'24 papers at https://yoshitomo-matsubara.net/torchdistill/projects.html#papers, as those papers use torchdistill :)

yoshitomo-matsubara commented 3 months ago

Hi @roymiles

It looks like the previous config did not use the normalized representations for computing the loss. Can you run it again with srd-resnet18_from_resnet34.txt and confirm the reproducibility?

yoshitomo-matsubara commented 3 months ago

Never mind, I reran the experiment with the fixed official config and achieved 71.93%, which is even better than the previous numbers (71.63% and 71.65% in your paper and the previous torchdistill run, respectively).

https://github.com/yoshitomo-matsubara/torchdistill/pull/473

roymiles commented 3 months ago

Hi @yoshitomo-matsubara

Sorry for the late reply, and ah, that's a complete oversight on my part. It is really great that you spotted this, and even better that you got improved results 😂!

I have only just seen your previous post now, but I would definitely like to add a link/description for these papers on the project page. Thanks so much for this :D I will add a "show and tell" discussion this week.