muzairkhattak / multimodal-prompt-learning

[CVPR 2023] Official repository of paper titled "MaPLe: Multi-modal Prompt Learning".
https://muzairkhattak.github.io/multimodal-prompt-learning/
MIT License

Unable to load trained state_dict #50

Closed adsbansal closed 4 months ago

adsbansal commented 4 months ago

Thank you for the clear documentation and instructions!

I am unable to load the trained weights for the MaPLe ImageNet model.
I downloaded the following weights: maple_imagenet

I am using build_model from multimodal-prompt-learning/clip/model.py in conjunction with the following code to load the model:

import torch
from model import build_model  # build_model from multimodal-prompt-learning/clip/model.py

# Load the downloaded MaPLe checkpoint onto the CPU
checkpoint_path = '/home/cse/visitor/abansal.visitor/scratch/baselines/maple/maple_tune/model_files/deep-lang-prompting/model.pth.tar-5'
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))

design_details = {"trainer": 'CoCoOp',
                  "vision_depth": 0,
                  "language_depth": 0,
                  "vision_ctx": 0,
                  "language_ctx": 0}

model = build_model(state_dict['state_dict'], design_details)

I am getting a state_dict mismatch. This may not be the right way to load the model; if so, could you point me in the right direction? This is similar to #45. I am trying to use the trained model to further fine-tune it on a downstream task.

Thanks!

muzairkhattak commented 4 months ago

Hi @adsbansal,

Thank you for showing interest in MaPLe!

Regarding your question, I believe the problem is due to "CoCoOp" being set as the trainer value in the design_details dictionary. Can you try replacing design_details with the following:

design_details = {"trainer": 'MaPLe',
                      "vision_depth": 0,
                      "language_depth": 0, "vision_ctx": 0,
                      "language_ctx": 0,
                      "maple_length": 2}

Hopefully this resolves the issue. (Here maple_length corresponds to the number of learnable prompt tokens, cfg.TRAINER.MAPLE.N_CTX, which is 2 in the released configs.)

Kind regards, Muhammad Uzair

adsbansal commented 4 months ago

Thank you for the prompt reply! Unfortunately, I am encountering the following error:

Traceback (most recent call last):
  File "./load_model.py", line 14, in <module>
    model = build_model(state_dict['state_dict'], design_details)
  File "./model.py", line 668, in build_model
    vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
KeyError: 'visual.layer1.0.conv1.weight'

The error occurs before the CLIP model is loaded at line 681 of model.py.
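For context, this seems to be the architecture check at the top of build_model() (paraphrasing, based on OpenAI's CLIP model.py that this file extends, so the exact code may differ):

# Paraphrased sketch of the check in build_model(); exact code may differ.
vit = "visual.proj" in state_dict   # this key is absent from the MaPLe checkpoint
if vit:
    vision_width = state_dict["visual.conv1.weight"].shape[0]
else:
    # the ResNet branch is taken by mistake, and this lookup raises the KeyError
    vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]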

Thanks!

muzairkhattak commented 4 months ago

Hi @adsbansal,

Got it. I have looked into the main code and found that the state_dict passed to build_model() should be the pretrained CLIP weights, not the MaPLe checkpoint, as we build CLIP as the first step.

Once we have obtained the CLIP model, we only need to load the MaPLe weights on top of it (as you tried above).

You can do the following: starting from this line in the build_model() function of the MaPLe trainer, execute the subsequent lines to obtain the MaPLe model.

Lastly, you can load the pretrained MaPLe weights and apply them to the MaPLe model as shown in these lines.
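
Putting both steps together, here is a rough sketch (assuming you have a Dassl cfg configured for the MaPLe trainer and a classnames list for your dataset; checkpoint_path is your downloaded checkpoint):

import torch

# load_clip_to_cpu and CustomCLIP live in trainers/maple.py of this repo
from trainers.maple import load_clip_to_cpu, CustomCLIP

# Step 1: build CLIP from the pretrained CLIP weights; load_clip_to_cpu(cfg)
# downloads them and applies the MaPLe design_details internally.
clip_model = load_clip_to_cpu(cfg)
clip_model.float()

# Step 2: wrap CLIP in CustomCLIP, then load the MaPLe checkpoint on top.
model = CustomCLIP(cfg, classnames, clip_model)

checkpoint = torch.load(checkpoint_path, map_location="cpu")
maple_state_dict = checkpoint["state_dict"]

# The tokenized prompt buffers depend on the class names, so drop them and
# load with strict=False (this mirrors the trainer's load_model()).
maple_state_dict.pop("prompt_learner.token_prefix", None)
maple_state_dict.pop("prompt_learner.token_suffix", None)
model.load_state_dict(maple_state_dict, strict=False)

After this, model can be fine-tuned on your downstream task as usual.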

adsbansal commented 4 months ago

Thanks @muzairkhattak for the prompt help! The code has been impeccably maintained and written.