lambert-x / medical_mae

The official implementation of "Delving into Masked Autoencoders for Multi-Label Thorax Disease Classification"
Apache License 2.0

Loading pre-trained models #5

Closed: mahbodez closed this issue 1 year ago

mahbodez commented 1 year ago

Hi, there seems to be an issue when I try to load the vit-b_CXR_0.5M_mae.pth file as a ViT-MAE base with a 16x16 patch size using the models_mae.py module. Here is my code:

import torch  # needed for torch.load
import models_mae as mm

vitmae = mm.mae_vit_base_patch16()
vitmae.load_state_dict(
    state_dict=torch.load(
        f="vit-b_CXR_0.5M_mae.pth")
)

I get the following error:

Error(s) in loading state_dict for MaskedAutoencoderViT:
    Missing key(s) in state_dict: "cls_token", "pos_embed", ...
    Unexpected key(s) in state_dict: "model", "optimizer", "epoch", "scaler", "args". 
lambert-x commented 1 year ago

Please refer to https://github.com/lambert-x/medical_mae/blob/64e502b030c5986b9e78925b389ee3d52cba882b/main_finetune_chestxray.py#L283-L306 for how to load the pre-trained checkpoint for fine-tuning.
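
In short, the unexpected keys ("model", "optimizer", "epoch", "scaler", "args") indicate the .pth file is a full training checkpoint, so the network weights sit under its "model" key rather than at the top level. A minimal sketch of the loading pattern, assuming the same file and model as above (untested here; strict=False tolerates keys that differ between the pre-training checkpoint and the instantiated model):

import torch
import models_mae as mm

vitmae = mm.mae_vit_base_patch16()

# Load the full training checkpoint on CPU and pull out the model weights.
checkpoint = torch.load("vit-b_CXR_0.5M_mae.pth", map_location="cpu")
state_dict = checkpoint["model"]

# strict=False allows keys that exist in only one of the two state dicts;
# inspect the returned message to see what was skipped.
msg = vitmae.load_state_dict(state_dict, strict=False)
print(msg.missing_keys)
print(msg.unexpected_keys)

The linked fine-tuning code follows the same idea, additionally dropping classifier head weights whose shapes mismatch and interpolating the position embedding before the strict=False load.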

mahbodez commented 1 year ago

Thank you very much!