pengzhiliang / G2SD


[NEED HELP] Can't reproduce Generic Distillation experiment in ImageNet1k #7

Closed terry-for-github closed 1 year ago

terry-for-github commented 1 year ago

Hi! Thanks for the great job!

I ran into difficulties reproducing the Generic Distillation (GD) experiment. I used the official mae_pretrain_base.pth as the teacher model and mae_vit_small_patch16_dec256d4b as the student network. I kept the original code unchanged and only modified the GD.sh script to run the Generic Distillation experiment. However, the loss only decreased during the first few epochs and then stayed at about 0.2 for the remaining epochs, far from the official result in the Google Drive log (about 0.02).

I doubted my training process, so I checked it on the image classification task with Specific Distillation (SD). After the first SD epoch I got only about 36% accuracy on ImageNet1k (the official log shows 56.34%). I also changed the seed to 1 and to 42 and reran the GD experiment on ImageNet1k, and still got a similar loss (0.19/0.18). I've checked the implementation details in the paper and found no conflict with my setup.

So I can't reproduce the Generic Distillation experiment. Is there something I need to modify in the code, or did I do something wrong? How can I reproduce the GD experiment? I'd really appreciate it if someone could help me. Thank you very much!

Here is my GD.sh script: https://drive.google.com/file/d/121dMog7rW5I7gGz8U2c-DzDM4s_x3JOW/view?usp=drive_link
The program's log output before the training epochs: https://drive.google.com/file/d/15Z9yHOuR1Tzp8JdOwSYvEN7SGJZtfUfy/view?usp=drive_link
My Python environment (pip list): https://drive.google.com/file/d/1dWQPJ87wc6anvwmsNkVC0aXh9GmQxL_S/view?usp=drive_link
The output log file: https://drive.google.com/file/d/1uv53nyKq_-msFIoZA5H76brSdwHyBvzc/view?usp=drive_link
TensorBoard result images:
lr: https://drive.google.com/file/d/1KCGBFOFMM3zPA6QVUu8PK3PbD8Jf_0O4/view?usp=drive_link
loss: https://drive.google.com/file/d/1pSQLhKV4K3eErnm9RA5yeLD4GCluJhnx/view?usp=drive_link

And finally, my SD log for the first 4 epochs:
{"train_lr": 0.0003996802557953637, "train_loss": 5.64891408925815, "train_loss_gt": 5.9685534642373534, "train_loss_dis": 5.3292747152318585, "test_loss": 2.9740411103988182, "test_acc1": 36.42200002670288, "test_acc5": 64.59000001724243, "epoch": 0, "n_parameters": 22436048}
{"train_lr": 0.0011996802557953635, "train_loss": 4.463338032424402, "train_loss_gt": 5.090619727695207, "train_loss_dis": 3.8360563386544335, "test_loss": 2.264968312212399, "test_acc1": 48.64200002410889, "test_acc5": 74.98000002075196, "epoch": 1, "n_parameters": 22436048}
{"train_lr": 0.001999680255795364, "train_loss": 4.1618424449369105, "train_loss_gt": 4.8575777951285515, "train_loss_dis": 3.4661070930181173, "test_loss": 2.061447801638623, "test_acc1": 52.56200001815796, "test_acc5": 77.82600002288818, "epoch": 2, "n_parameters": 22436048}
{"train_lr": 0.002799680255795364, "train_loss": 4.027493118024845, "train_loss_gt": 4.751268744706917, "train_loss_dis": 3.303717490866316, "test_loss": 1.9335608093106016, "test_acc1": 54.88400001312256, "test_acc5": 79.80000002593994, "epoch": 3, "n_parameters": 22436048}

Vickeyhw commented 1 year ago

@terry-for-github The teacher checkpoint you used doesn't include the decoder weights, as can be seen from the 'missing keys: .......' message in 'my_program_out.txt'. You can download the checkpoint from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth
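A quick way to catch this before training is to check whether the checkpoint's key names contain any decoder-side entries. The snippet below is a minimal sketch, not code from this repository: the helper names and the sample key lists are hypothetical, and the prefixes (`decoder_`, `mask_token`) are taken from the key names that appear in the logs in this thread.

```python
from typing import Iterable, List

# Prefixes of decoder-side parameters, as seen in the key names in this thread.
DECODER_PREFIXES = ("decoder_", "mask_token")


def decoder_keys(state_dict_keys: Iterable[str]) -> List[str]:
    """Return the keys that belong to the MAE decoder side of a checkpoint."""
    return [k for k in state_dict_keys if k.startswith(DECODER_PREFIXES)]


def has_decoder_weights(state_dict_keys: Iterable[str]) -> bool:
    """True if at least one decoder-side key is present."""
    return len(decoder_keys(state_dict_keys)) > 0


# Hypothetical key lists for illustration only.
encoder_only = ["cls_token", "pos_embed", "patch_embed.proj.weight",
                "blocks.0.norm1.weight"]
full = encoder_only + ["mask_token", "decoder_pos_embed",
                       "decoder_embed.weight", "decoder_blocks.0.norm1.weight"]

print(has_decoder_weights(encoder_only))
print(has_decoder_weights(full))
```

With PyTorch, the key names would typically come from something like `torch.load(path, map_location="cpu")`, assuming the weights sit under a `"model"` entry of the loaded dictionary; that layout is an assumption and should be checked against the actual file.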

terry-for-github commented 1 year ago

Thank you for your timely reply. I think I should download the base version instead. Anyway, that helps me a lot.

terry-for-github commented 1 year ago

@Vickeyhw Just in case I'm missing some weights again, is this output correct for Specific Distillation (image classification task)? I downloaded the official mae_finetuned_vit_base.pth from https://github.com/facebookresearch/mae/blob/main/FINETUNE.md

[18:37:23.661471] Load pre-trained checkpoint from: ViT_Sm_GD.pth
[18:37:23.697086] _IncompatibleKeys(missing_keys=['dist_token', 'head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias'], unexpected_keys=['mask_token', 'decoder_pos_embed', 'decoder_embed.weight', 'decoder_embed.bias', 'decoder_blocks.0.norm1.weight', 'decoder_blocks.0.norm1.bias', 'decoder_blocks.0.attn.qkv.weight', 'decoder_blocks.0.attn.qkv.bias', 'decoder_blocks.0.attn.proj.weight', 'decoder_blocks.0.attn.proj.bias', 'decoder_blocks.0.norm2.weight', 'decoder_blocks.0.norm2.bias', 'decoder_blocks.0.mlp.fc1.weight', 'decoder_blocks.0.mlp.fc1.bias', 'decoder_blocks.0.mlp.fc2.weight', 'decoder_blocks.0.mlp.fc2.bias', 'decoder_blocks.1.norm1.weight', 'decoder_blocks.1.norm1.bias', 'decoder_blocks.1.attn.qkv.weight', 'decoder_blocks.1.attn.qkv.bias', 'decoder_blocks.1.attn.proj.weight', 'decoder_blocks.1.attn.proj.bias', 'decoder_blocks.1.norm2.weight', 'decoder_blocks.1.norm2.bias', 'decoder_blocks.1.mlp.fc1.weight', 'decoder_blocks.1.mlp.fc1.bias', 'decoder_blocks.1.mlp.fc2.weight', 'decoder_blocks.1.mlp.fc2.bias', 'decoder_blocks.2.norm1.weight', 'decoder_blocks.2.norm1.bias', 'decoder_blocks.2.attn.qkv.weight', 'decoder_blocks.2.attn.qkv.bias', 'decoder_blocks.2.attn.proj.weight', 'decoder_blocks.2.attn.proj.bias', 'decoder_blocks.2.norm2.weight', 'decoder_blocks.2.norm2.bias', 'decoder_blocks.2.mlp.fc1.weight', 'decoder_blocks.2.mlp.fc1.bias', 'decoder_blocks.2.mlp.fc2.weight', 'decoder_blocks.2.mlp.fc2.bias', 'decoder_blocks.3.norm1.weight', 'decoder_blocks.3.norm1.bias', 'decoder_blocks.3.attn.qkv.weight', 'decoder_blocks.3.attn.qkv.bias', 'decoder_blocks.3.attn.proj.weight', 'decoder_blocks.3.attn.proj.bias', 'decoder_blocks.3.norm2.weight', 'decoder_blocks.3.norm2.bias', 'decoder_blocks.3.mlp.fc1.weight', 'decoder_blocks.3.mlp.fc1.bias', 'decoder_blocks.3.mlp.fc2.weight', 'decoder_blocks.3.mlp.fc2.bias', 'decoder_norm.weight', 'decoder_norm.bias', 'decoder_pred.weight', 'decoder_pred.bias', 'encoderfeature_pred.fc1.weight', 'encoderfeature_pred.fc1.bias', 'encoderfeature_pred.fc2.weight', 'encoderfeature_pred.fc2.bias', 'decoderfeature_pred.fc1.weight', 'decoderfeature_pred.fc1.bias', 'decoderfeature_pred.fc2.weight', 'decoderfeature_pred.fc2.bias'])
[18:37:24.036160] Load pre-trained teacher checkpoint from: /userhome/download/mae_finetuned_vit_base.pth
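One way to sanity-check a report like this is to filter out the keys one would expect to differ and see whether anything is left. The sketch below rests on assumptions that are not confirmed in this thread: that the classification heads and the distillation token are newly initialized for fine-tuning (so they are expected to be missing), and that the decoder, mask token, and feature-prediction layers exist only for the GD stage (so they are expected to be unexpected). The function name and prefix lists are hypothetical.

```python
from typing import Dict, Iterable, List

# Assumption: these are freshly initialized for classification fine-tuning.
EXPECTED_MISSING = ("head.", "head_dist.", "dist_token")
# Assumption: these are only used during Generic Distillation.
EXPECTED_UNEXPECTED = ("mask_token", "decoder_",
                       "encoderfeature_pred.", "decoderfeature_pred.")


def surprising_keys(missing: Iterable[str],
                    unexpected: Iterable[str]) -> Dict[str, List[str]]:
    """Return only the keys that fall outside the expected prefixes."""
    return {
        "missing": [k for k in missing
                    if not k.startswith(EXPECTED_MISSING)],
        "unexpected": [k for k in unexpected
                       if not k.startswith(EXPECTED_UNEXPECTED)],
    }


# Missing keys copied from the log above; unexpected keys abbreviated.
missing = ["dist_token", "head.weight", "head.bias",
           "head_dist.weight", "head_dist.bias"]
unexpected = ["mask_token", "decoder_pos_embed", "decoder_embed.weight",
              "decoder_blocks.0.norm1.weight", "decoder_norm.bias",
              "decoder_pred.bias", "encoderfeature_pred.fc1.weight",
              "decoderfeature_pred.fc2.bias"]

report = surprising_keys(missing, unexpected)
print(report)
```

Under those assumptions, an empty result on both sides suggests that no encoder weights were dropped. An entry such as `blocks.0.attn.qkv.weight` among the missing keys would be worth investigating.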

terry-for-github commented 1 year ago

Thanks a lot. I've reproduced the GD and SD experiment results with ViT-small and ViT-tiny on the image classification task.