Closed vmuel closed 1 year ago
.from_pretrained() is the method inherited from Hugging Face transformers. You can set checkpoint.resume_file and config on the command line to load your model. Please check the README carefully.
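For example, something along these lines should pick the checkpoint up (the dataset and run_type values below are placeholders, not the exact keys):

mmf_run config=projects/fashionvil/configs/e2e_pretraining_final.yaml model=fashionvil dataset=<your_dataset> run_type=val checkpoint.resume_file=save/fashionvil_e2e_pretrain_final/fashionvil_final.pth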
Yes, indeed. I am trying to load the transformer used as a fusion encoder in order to generate text, image, and joint embeddings. I have noticed that the class FashionViLBase from mmf/models/fashionvil/base.py has built-in methods to generate embeddings, so this is what I have focused on so far. I have initialized an instance of the class this way:
from omegaconf import OmegaConf
from transformers import BertConfig
from mmf.models.fashionvil.base import FashionViLBase

# Extract the fashionvil model config from the project YAML and build a BertConfig from it
config = OmegaConf.load('projects/fashionvil/configs/e2e_pretraining_final.yaml')
fashionvil_config = config.model_config.fashionvil
fashionvil_config = OmegaConf.to_container(fashionvil_config, resolve=True)
fashionvil_config = BertConfig.from_dict(fashionvil_config)
transformer = FashionViLBase(fashionvil_config)
Then, I try to load the weights and this is where I am unsuccessful:
import torch

pretrained_path = 'save/fashionvil_e2e_pretrain_final/fashionvil_final.pth'
state_dict = torch.load(pretrained_path, map_location=torch.device("cpu"))
transformer.load_state_dict(state_dict)
Running the code results in a RuntimeError: Error(s) in loading state_dict for FashionViLBase: Missing key(s) in state_dict: "embeddings.position_ids"...
(although the keys are present in the state_dict; maybe it is due to the format of the file?).
My question is: do you know of a working way to load the weights of a FashionViLBase model? Or is there a simpler way to generate embeddings? This is sort of a stretch compared to the tasks you explain in your README, so I completely understand if you don't have the time to answer.
Hope I am being more clear! Thanks in advance.
Hi,
I think this bug is caused by a mismatch in the state_dict keys rather than by the file format. Could you please check the keys of your loaded state_dict? I guess there might be some prefixes like "model.", which would cause the mismatch.
In terms of generating embeddings, I think you can modify some forward passes, like https://github.com/BrandonHanx/mmf/issues/2
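For example, printing a few keys (using the checkpoint path from your snippet above) should make any such prefix obvious:
import torch

state_dict = torch.load('save/fashionvil_e2e_pretrain_final/fashionvil_final.pth', map_location='cpu')
# Look for prefixes such as "module." or "model." in front of the expected parameter names
print(list(state_dict.keys())[:10])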
Hi,
Thanks a lot! There were two issues with the state_dict, and one of them was indeed a mismatch due to prefixes. The other problem was that the state_dict contained the keys for the image encoder and image tokenizer, while I only needed the weights of the transformer.
The following code solved my problem:
from collections import OrderedDict

# Keep only the transformer weights and strip the "module.model.bert." prefix from their keys
new_state_dict = OrderedDict()
for k in state_dict.keys():
    if k.startswith('module.model.bert.'):
        new_key = k[len('module.model.bert.'):]
        new_state_dict[new_key] = state_dict[k]
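With the prefixes stripped, the load call then goes through:
# The remapped keys now match the FashionViLBase parameter names
transformer.load_state_dict(new_state_dict)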
This allowed me to load the model's weights. The built-in methods to get embeddings seem to work so far, but I'll take a look at #2.
Thank you for your help!
❓ Questions and Help
Hi, thanks for your work and making your code available!
I am trying to load the transformer used for both TE & FE by importing the FashionViLBase class and using the inherited .from_pretrained() method, which requires a .bin and a config.json file. I've tried doing so using the fashionvil_final.pth checkpoint and the e2e_pretrain_final.yaml config file, but to no avail.
Do you know if there is any other working way to load the trained transformer?
Thanks a lot in advance!