lambert-x / medical_mae

The official implementation of "Delving into Masked Autoencoders for Multi-Label Thorax Disease Classification"
Apache License 2.0

Incompatible Model Architecture and Saved Weights Issue #6

Closed SaumyaBhandari closed 1 year ago

SaumyaBhandari commented 1 year ago

I encountered an issue while trying to use the provided saved model weights with the architecture implemented in this code repository. The weights of the saved model 'vit-b_CXR_0.5M_mae.pth' (finetuned on NIH ChestX-ray) appear to be incompatible with the architecture initialization in the codebase, resulting in errors during loading and execution.

Steps to Reproduce:

  1. Clone the repository.
  2. Follow the provided instructions to run main_finetune_chestxray.py with the saved finetuned weights 'vit-b_CXR_0.5M_mae.pth', initializing the model as vit_base_patch16.
  3. Attempt to load the provided saved model weights. [lines 282-306 in main_finetune_chestxray.py]
  4. Observe errors indicating an architectural mismatch.

Expected Behavior: The saved model weights should seamlessly load and align with the architecture defined in the code.

Actual Behavior: The saved model weights are not compatible with the architecture initialization, leading to runtime errors.
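
For reference, a minimal sketch of the load that reproduces the error. The 'model' key and the 14-class head follow the repository's finetuning script; the checkpoint path is wherever the downloaded weights live:

import torch

import models_vit  # from this repository

# initialize the architecture the same way main_finetune_chestxray.py does
model = models_vit.vit_base_patch16(num_classes=14, global_pool=True)

# load the downloaded checkpoint
checkpoint = torch.load('vit-b_CXR_0.5M_mae.pth', map_location='cpu')
checkpoint_model = checkpoint.get('model', checkpoint)

# a strict load raises RuntimeError on any missing, unexpected, or shape-mismatched key
model.load_state_dict(checkpoint_model)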

Identified Issue: The error message suggests a mismatch between the loaded checkpoint and the initialized architecture, which prevents the weights from being used for the intended tasks.

Additional Information: The repository's documentation provides the saved model weights, and the architecture initialization in the code follows the provided guidelines. This issue significantly affects the ease of integrating and using the repository's pre-trained models. Guidance on properly aligning the model weights with the code's architecture would be greatly appreciated.
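
To narrow this down, a diagnostic sketch that compares the checkpoint's keys and tensor shapes against a freshly initialized model before attempting any load (same assumptions as the sketch above):

import torch

import models_vit  # from this repository

model = models_vit.vit_base_patch16(num_classes=14, global_pool=True)
state_dict = model.state_dict()

checkpoint = torch.load('vit-b_CXR_0.5M_mae.pth', map_location='cpu')
checkpoint_model = checkpoint.get('model', checkpoint)

# keys the model expects but the checkpoint lacks
missing = [k for k in state_dict if k not in checkpoint_model]
# keys the checkpoint carries but the model does not know
unexpected = [k for k in checkpoint_model if k not in state_dict]
# keys present in both but with different tensor shapes
mismatched = [k for k in checkpoint_model
              if k in state_dict and checkpoint_model[k].shape != state_dict[k].shape]

print('missing from checkpoint:', missing)
print('unexpected in checkpoint:', unexpected)
print('shape mismatches:', mismatched)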

SaumyaBhandari commented 1 year ago

The model weights are loading now. I've edited the code so the weights load more robustly, copying only the parameters whose names and shapes match the initialized model:

if 'vit' in args.model:
    model = models_vit.__dict__[args.model](
        img_size=args.input_size,
        num_classes=args.nb_classes,
        drop_rate=args.vit_dropout_rate,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )

if args.finetune:
    if 'vit' in args.model:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load pre-trained checkpoint from: %s" % args.finetune)
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()

        # copy only the parameters whose names and shapes match the initialized model
        for k in checkpoint_model.keys():
            if k in state_dict:
                if checkpoint_model[k].shape == state_dict[k].shape:
                    state_dict[k] = checkpoint_model[k]
                    print(f"Loaded Index: {k} from Saved Weights")
                else:
                    print(f"Shape of {k} doesn't match with {state_dict[k]}")
            else:
                print(f"{k} not found in Init Model")

        # interpolate position embedding
        interpolate_pos_embed(model, state_dict)

        # load pre-trained model
        model.load_state_dict(state_dict)
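
For comparison, the upstream MAE finetuning recipe handles the same situation by deleting only the shape-mismatched classification-head keys from the checkpoint and then loading non-strictly. A sketch of that alternative, reusing model, checkpoint_model, and state_dict from above:

# drop only the keys whose shapes disagree (typically the classification head)
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# a non-strict load reports, rather than fails on, missing or unexpected keys
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg.missing_keys)

Both approaches end up ignoring incompatible parameters; the key-by-key copy above just makes the skipped parameters explicit in the log.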