Hao840 / OFAKD

PyTorch code and checkpoints release for OFA-KD: https://arxiv.org/abs/2310.19444

Reproducing CIFAR-100 - Training Student and Teacher Models from Scratch #21

Open anithselva opened 5 months ago

anithselva commented 5 months ago

Hello,

I'm having some challenges reproducing the values in Table 2 (CIFAR-100). Hoping to get some feedback on how to fix this:

(1) Train from Scratch: Could you please advise on the correct way to use the framework to train the student from scratch? I tried the following for ResNet18, but got very high accuracy (95%+), far from the results reported in the paper.

The config used is the default. Here is the command I'm using:

python -m torch.distributed.launch --nproc_per_node=1 train.py /home/me/kd/cifar100 --config configs/cifar/cnn.yaml --model resnet18 --teacher swin_tiny_patch4_window7_224 --teacher-pretrained models/swin_tiny_patch4_window7_224.pth --distiller Vanilla

Is this correct? Please correct me if I'm wrong, but my understanding is that with the Vanilla distiller the student is trained from scratch on the hard labels and the teacher is ignored.

(2) Fine-Tuned Teacher: Could you share the sources and configs you used to train the teacher models from scratch (Swin-T, ViT-S, Mixer-B/16)?

(3) Release Configs: If it is not possible to release the models, could you release the configs that were used to train them? It would be immensely helpful for debugging.

Looking forward to your reply.

Thanks!

Hao840 commented 5 months ago

For the 1st question, I think there may be a mistake in your dataset, as it is nearly impossible to achieve a top-1 test accuracy of 95% with resnet18. As for the training of teacher models, I hope the discussions in this issue will be helpful.

anithselva commented 5 months ago

Yes, I thought it was odd too, but I found the bug, and it was indeed in the dataset. I had processed the raw CIFAR100 dataset into a folder structure that was not recognized by the torchvision CIFAR class, so the model was being tested on the train set.
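For anyone hitting the same symptom, a quick way to catch this kind of split mix-up is to measure how many test images also appear in the train set. The following is only a minimal sketch, not part of the OFAKD code; the helper names `image_digest` and `split_overlap` are hypothetical and introduced here for illustration.

```python
import hashlib


def image_digest(img):
    """Return a SHA-1 hex digest of one image's raw bytes.

    Accepts anything with a .tobytes() method (e.g. a NumPy array)
    or any bytes-like object.
    """
    raw = img.tobytes() if hasattr(img, "tobytes") else bytes(img)
    return hashlib.sha1(raw).hexdigest()


def split_overlap(train_images, test_images):
    """Fraction of test images that also occur, byte for byte, in the train set."""
    train_digests = {image_digest(img) for img in train_images}
    test_digests = [image_digest(img) for img in test_images]
    if not test_digests:
        return 0.0
    hits = sum(d in train_digests for d in test_digests)
    return hits / len(test_digests)
```

With torchvision, the images of `CIFAR100(root, train=True)` and `CIFAR100(root, train=False)` are exposed through the `.data` attribute, so `split_overlap(train_set.data, test_set.data)` should be close to 0 for a correct setup. A value near 1, or a test set with 50,000 images instead of 10,000, suggests the evaluation is running on training data.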

RE: training of teacher models: Thanks for the link to that discussion. I am a bit confused about whether you trained the teachers on CIFAR using their official repo, or using this framework with the ImageNet-pretrained teacher as the student and the Vanilla distiller. Could you please clarify?

anithselva commented 5 months ago

I was able to get results that are closer to the reported ones.

I noticed in https://github.com/Hao840/OFAKD/issues/14#issuecomment-1847214068 that you corrected a bug in the CIFAR-100 code regarding the data distribution, and were able to get 80% accuracy on the distilled model. Could you share the performance of the newly trained teacher in that case?

anithselva commented 5 months ago

@Hao840 Your advice would be appreciated. I'm having a lot of challenges reproducing the results of the OFA method:

Swin-T (trained from scratch): 87%
ResNet (trained from scratch): 77%
ResNet (distilled): 77%

I went through the other thread you shared, but could not find anything that helped me show that OFA distillation improves student performance.

Could you share any additional information to help us reproduce the results (e.g. checkpoints, commands, arguments)? I used exactly the command you shared in the other thread for the updated CIFAR-100 results:

python -m torch.distributed.launch --nproc_per_node=8 train.py /cache/data/cifar/ --dataset cifar100 --config configs/cifar/cnn.yaml --model resnet18 --teacher swin_tiny_patch4_window7_224 --teacher-pretrained /cache/ckpt/swin_tiny_patch4_window7_224_cifar100.pth --num-classes 100 --lr 0.1 --min-lr 1e-3 --weight-decay 5e-3 --ofa-loss-weight 0.1