kohjingyu / fromage

🧀 Code and models for the ICML 2023 paper "Grounding Language Models to Images for Multimodal Inputs and Outputs".
https://jykoh.com/fromage
Apache License 2.0

Unexpected key(s) in state_dict #39

Closed: mlarrarte closed this issue 3 months ago.

mlarrarte commented 3 months ago

Hey! I'm trying to load the provided checkpoint to test the model a bit, but I get the following error:

```
Unexpected key(s) in state_dict: "model.logit_scale", "model.text_hidden_fcs.0.0.bias", "model.text_hidden_fcs.0.0.weight", "model.visual_embeddings.bias", "model.visual_embeddings.weight", "model.visual_fc.bias", "model.visual_fc.weight", "ret_input_embeddings.weight".
```

I know a solution was provided in issue #29, but I don't know which script I need to modify.

Can you help me please?
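For context, this message comes from PyTorch's `nn.Module.load_state_dict`, which by default (`strict=True`) rejects checkpoint keys that the model does not define. The sketch below uses a hypothetical toy module (`Toy`, not part of FROMAGe) and loads a state_dict whose keys carry an extra `model.` prefix. With `strict=False` the call returns the mismatched names in `missing_keys` and `unexpected_keys` instead of raising, which is a quick way to see which entries need renaming. Note that nothing is actually loaded for the mismatched keys.

```python
import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical stand-in model with a single linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)


model = Toy()

# Simulate a checkpoint whose keys use different names than the model expects
# (here: an extra "model." prefix on every key).
ckpt = {"model." + name: t.clone() for name, t in model.state_dict().items()}

# strict=False reports the mismatches instead of raising a RuntimeError.
result = model.load_state_dict(ckpt, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```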

NMesaC commented 1 month ago

For anyone who wants a faster fix to this issue, the following worked for me (using solutions from #29):

1) To get an unpruned checkpoint, run the following command in the fromage directory:

```bash
randport=$(shuf -i8000-9999 -n1)  # Generate a random port number
python -u main.py \
  --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  --epochs=1 --steps-per-epoch=10 \
  --dataset=cc3m --val-dataset=cc3m \
  --opt-version='facebook/opt-6.7b' --visual-model='openai/clip-vit-large-patch14' \
  --exp_name='fromage_exp' --image-dir='images/' --log-base-dir='runs/' \
  --batch-size=180 --val-batch-size=100 --learning-rate=0.0003 --precision='bf16' --print-freq=100
```

2) Copy the following code into a Python file and run it (you may need to adjust the paths):

```python
import json

import torch

# Pruned checkpoint provided with the repo (contains only the trained weights).
checkpoint1 = torch.load('./fromage_model/pretrained_ckpt.pth.tar', map_location='cuda:0')
# Full (unpruned) checkpoint written by the training run in step 1.
checkpoint2 = torch.load('./runs/fromage_exp/ckpt.pth.tar', map_location='cuda:0')

# Map each key in the pruned checkpoint to its name in the full checkpoint,
# whose keys carry a 'module.' prefix.
weights = {
    'model.logit_scale': 'module.model.logit_scale',
    'model.text_hidden_fcs.0.0.bias': 'module.model.text_hidden_fcs.0.0.bias',
    'model.text_hidden_fcs.0.0.weight': 'module.model.text_hidden_fcs.0.0.weight',
    'model.visual_embeddings.bias': 'module.model.visual_embeddings.bias',
    'model.visual_embeddings.weight': 'module.model.visual_embeddings.weight',
    'model.visual_fc.bias': 'module.model.visual_fc.bias',
    'model.visual_fc.weight': 'module.model.visual_fc.weight',
    'ret_input_embeddings.weight': 'module.model.input_embeddings.weight',
}

# Index of the retrieval ([RET]) token in the vocabulary.
with open('./runs/fromage_exp/model_args.json', 'r') as f:
    model_kwargs = json.load(f)
ret_token_idx = model_kwargs['retrieval_token_idx']

for k, v in weights.items():
    if k == 'ret_input_embeddings.weight':
        # The pruned checkpoint stores only the [RET] embedding row, so write it
        # into that single row of the full input embedding matrix.
        checkpoint2['state_dict'][v][ret_token_idx:ret_token_idx + 1, :] = checkpoint1['state_dict'][k]
    else:
        checkpoint2['state_dict'][v] = checkpoint1['state_dict'][k]
    print(k)

torch.save(checkpoint2, './fromage_model/ckpt.pth.tar')
```

Now, if you run main.py, it will work as expected. I did this in the context of running main.py for a finetuning task, if that helps.
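The merge in step 2 comes down to two operations: renaming each pruned key to its `module.`-prefixed counterpart, and writing the pruned `ret_input_embeddings.weight` (which, judging by the slice assignment, holds a single embedding row) into row `ret_token_idx` of the full input embedding matrix. The sketch below reproduces that logic on toy tensors; the helper `merge_pruned_into_full`, the tensor sizes, and the index are hypothetical, chosen only for illustration.

```python
from typing import Dict

import torch


def merge_pruned_into_full(
    pruned_sd: Dict[str, torch.Tensor],
    full_sd: Dict[str, torch.Tensor],
    key_map: Dict[str, str],
    ret_key: str,
    ret_token_idx: int,
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: copy pruned-checkpoint weights into a full state_dict.

    key_map maps pruned key -> full key. The entry named ret_key holds one
    embedding row, which is written into row ret_token_idx of the full matrix
    instead of replacing the whole tensor.
    """
    merged = dict(full_sd)  # shallow copy so full_sd itself is not modified
    for src, dst in key_map.items():
        if src == ret_key:
            emb = merged[dst].clone()  # clone so the input tensor stays untouched
            emb[ret_token_idx:ret_token_idx + 1, :] = pruned_sd[src]
            merged[dst] = emb
        else:
            merged[dst] = pruned_sd[src]
    return merged


# Toy example: a 3-token vocabulary with 2-dim embeddings, [RET] at index 2.
full = {
    "module.model.input_embeddings.weight": torch.zeros(3, 2),
    "module.model.visual_fc.weight": torch.zeros(2, 2),
}
pruned = {
    "ret_input_embeddings.weight": torch.ones(1, 2),
    "model.visual_fc.weight": torch.full((2, 2), 5.0),
}
key_map = {
    "ret_input_embeddings.weight": "module.model.input_embeddings.weight",
    "model.visual_fc.weight": "module.model.visual_fc.weight",
}
merged = merge_pruned_into_full(
    pruned, full, key_map, "ret_input_embeddings.weight", ret_token_idx=2
)
print(merged["module.model.input_embeddings.weight"])
```

Unlike the script above, this helper returns a new dict and clones the embedding matrix rather than editing the loaded checkpoint in place, so the inputs can still be compared against the result afterwards.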