LTH14 / mage

A PyTorch implementation of MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis
MIT License

AssertionError happens when loading the model in the workflow of “finetuning-then-linprobing”. #29

Closed: uk9921 closed this issue 1 year ago

uk9921 commented 1 year ago

I used the script main_finetune.py to fine-tune the pretrained model, and the process went smoothly. However, when I tried to load the fine-tuned model and train a linear probe on it, I got this AssertionError:

    File "main_linprobe.py", line 203, in main
        assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

I printed msg.missing_keys and got:

    msg.missing_keys = []

So, is this assertion on the missing keys necessary when loading a fine-tuned model?
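To see why msg.missing_keys comes back empty, here is a minimal pure-Python sketch of the comparison behind it. It assumes, as PyTorch's load_state_dict(..., strict=False) does, that "missing keys" are the parameter names the model expects but the checkpoint does not provide. Apart from head.* and fc_norm.*, the key names are made-up placeholders, and missing_keys is a hypothetical helper, not a function from the repository.

```python
# Hypothetical parameter names; a real ViT-L state dict has many more entries.
MODEL_KEYS = {
    "patch_embed.proj.weight",
    "blocks.0.attn.qkv.weight",
    "head.weight",
    "head.bias",
    "fc_norm.weight",
    "fc_norm.bias",
}


def missing_keys(model_keys, checkpoint_keys):
    """Names the model expects but the checkpoint lacks (the same notion
    PyTorch reports as missing_keys when loading with strict=False)."""
    return set(model_keys) - set(checkpoint_keys)


# A pre-trained checkpoint carries only the encoder weights.
pretrained_ckpt = {"patch_embed.proj.weight", "blocks.0.attn.qkv.weight"}
# A fine-tuned checkpoint also carries the classification head and fc_norm.
finetuned_ckpt = set(MODEL_KEYS)

print(sorted(missing_keys(MODEL_KEYS, pretrained_ckpt)))
# ['fc_norm.bias', 'fc_norm.weight', 'head.bias', 'head.weight']
print(sorted(missing_keys(MODEL_KEYS, finetuned_ckpt)))
# []
```

Under these assumptions, only a pre-trained checkpoint produces the exact set the assertion checks for; a fine-tuned checkpoint supplies every key, so nothing is missing and the equality fails.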

uk9921 commented 1 year ago

Here are my training args:

    main_linprobe.py \
    --batch_size 128 \
    --model vit_large_patch16 \
    --global_pool \
    --finetune ${PRETRAIN_CHKPT} \
    --epochs 90 \
    --blr 0.05 \
    --weight_decay 0.0 \
    --output_dir ${OUTPUT_DIR} \
    --data_path ${IMAGENET_DIR} \
    --dist_eval
LTH14 commented 1 year ago

The linear probing script is designed to load a pre-trained model (without the linear classification head and the norm layer before it). Therefore, we have an assertion there to make sure the loaded model does not have those parameters. If you want to linear-probe a model that already includes the classification head, I think you can simply comment out the assertion.
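As an alternative to commenting the assertion out entirely, one option is to accept both checkpoint types while still rejecting anything else. The sketch below is hypothetical: check_loaded_keys and PRETRAINED_MISSING are illustrative names, not code from the repository, and the expected set is copied from the assertion quoted in the issue (the --global_pool case).

```python
# Keys the original assertion expects to be missing for a pre-trained checkpoint.
PRETRAINED_MISSING = {"head.weight", "head.bias", "fc_norm.weight", "fc_norm.bias"}


def check_loaded_keys(missing_keys):
    """Hypothetical relaxed version of the assertion in main_linprobe.py.

    Returns "pretrained" when exactly the head and fc_norm are missing,
    "finetuned" when nothing is missing, and raises otherwise.
    """
    missing = set(missing_keys)
    if missing == PRETRAINED_MISSING:
        return "pretrained"
    if not missing:
        return "finetuned"
    # Any other pattern suggests a mismatched checkpoint, so fail loudly.
    raise AssertionError(f"unexpected missing keys: {sorted(missing)}")


print(check_loaded_keys(["head.weight", "head.bias", "fc_norm.weight", "fc_norm.bias"]))
print(check_loaded_keys([]))
```

In the script, the call would replace the assert line, e.g. check_loaded_keys(msg.missing_keys). The point of keeping a check is that a checkpoint missing only part of the expected set (say, just head.bias) still raises, which simply deleting the assertion would not catch.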

uk9921 commented 1 year ago

Thank you for your reply. I followed your suggestion and achieved a higher top-1 accuracy than when directly loading the pre-trained model.