BrandonHanx / mmf

[ECCV 2022] FashionViL: Fashion-Focused V+L Representation Learning
https://mmf.sh/

Loading pretrained transformer #5

Closed vmuel closed 1 year ago

vmuel commented 2 years ago

❓ Questions and Help

Hi, thanks for your work and making your code available!

I am trying to load the transformer used for both TE & FE by importing the FashionViLBase class and using the inherited .from_pretrained() method, which requires a .bin and a config.json file.

I’ve tried doing so using the fashionvil_final.pth checkpoint and the e2e_pretrain_final.yaml config file but to no avail.

Do you know if there is any other working way to load the trained transformer?

Thanks a lot in advance!

BrandonHanx commented 2 years ago

.from_pretrained() is a method inherited from Hugging Face transformers.

You can set checkpoint.resume_file and config on the command line to load your model.

Please check the README carefully.

vmuel commented 1 year ago

Yes, indeed. I am trying to load the transformer used as a fusion encoder in order to generate text, image and joint embeddings. I have noticed that the FashionViLBase class from mmf/models/fashionvil/base.py has built-in methods to generate embeddings, so this is what I have focused on so far. I have instantiated the class this way:

from omegaconf import OmegaConf
from transformers import BertConfig
from mmf.models.fashionvil.base import FashionViLBase

# Build a BertConfig from the model_config.fashionvil section of the training YAML
config = OmegaConf.load('projects/fashionvil/configs/e2e_pretraining_final.yaml')
fashionvil_config = config.model_config.fashionvil
fashionvil_config = OmegaConf.to_container(fashionvil_config, resolve=True)
fashionvil_config = BertConfig.from_dict(fashionvil_config)
transformer = FashionViLBase(fashionvil_config)

Then I try to load the weights, and this is where I am stuck:

import torch

pretrained_path = 'save/fashionvil_e2e_pretrain_final/fashionvil_final.pth'
state_dict = torch.load(pretrained_path, map_location=torch.device("cpu"))
transformer.load_state_dict(state_dict)

Running the code results in a RuntimeError: Error(s) in loading state_dict for FashionViLBase: Missing key(s) in state_dict: "embeddings.position_ids"... (although the keys are present in the state_dict, so maybe it is due to the format of the file?).

My question is: do you know of a working way to load the weights of a FashionViLBase model? Or is there a simpler way to generate embeddings? This is a bit of a stretch compared to the tasks explained in your README, so I completely understand if you don't have the time to answer.

Hope this is clearer! Thanks in advance.

BrandonHanx commented 1 year ago

Hi,

I think this bug is caused by a mismatch between the state_dict keys rather than by the file format. Could you please check the keys of your loaded state_dict? I guess there might be some prefixes like "model.", which would cause the mismatch.
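
For example, something along these lines (a quick sketch, reusing the checkpoint path from your message):

import torch

state_dict = torch.load(
    'save/fashionvil_e2e_pretrain_final/fashionvil_final.pth',
    map_location='cpu',
)
# Inspect a few keys to spot prefixes such as "module." or "model."
for k in list(state_dict.keys())[:10]:
    print(k)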

In terms of generating embeddings, I think you can modify some forward passes, like https://github.com/BrandonHanx/mmf/issues/2

vmuel commented 1 year ago

Hi,

Thanks a lot! There were two issues with the state_dict, and one of them was indeed a mismatch due to prefixes. The other was that the state_dict contained the keys for the image encoder and image tokenizer, while I only needed the weights of the transformer.

The following code solved my problem:

from collections import OrderedDict

# Keep only the fusion transformer weights and strip the "module.model.bert." prefix
prefix = 'module.model.bert.'
new_state_dict = OrderedDict()
for k in state_dict.keys():
    if k.startswith(prefix):
        new_state_dict[k[len(prefix):]] = state_dict[k]
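
One caveat: if embeddings.position_ids still comes up as missing after the prefix strip (as far as I understand, it is a registered buffer that some transformers versions do not save into checkpoints), strict=False works as a fallback:

# Non-matching keys such as the embeddings.position_ids buffer are skipped with strict=False
transformer.load_state_dict(new_state_dict, strict=False)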

This allowed me to load the model's weights. The built-in methods to get embeddings seem to work so far, but I'll take a look at #2.
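
For reference, this is roughly how I generate a text embedding now. Note that the tokenizer choice (bert-base-uncased) and the exact argument order of get_text_embedding are assumptions on my end, so double-check them against mmf/models/fashionvil/base.py:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
batch = tokenizer('red floral midi dress', return_tensors='pt')
# Assumed argument order; verify against get_text_embedding in base.py
text_emb = transformer.get_text_embedding(
    batch['input_ids'],
    batch['token_type_ids'],
    batch['attention_mask'],
)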

Thank you for your help!