facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Error in Loading State Dict #27

Closed brucejwittmann closed 3 years ago

brucejwittmann commented 3 years ago

I am running into an error when trying to reload a fine-tuned version of the models. After further training, the models were saved using the code below:

# Load model
model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t12_85M_UR50S")

# Training code

# Save model
torch.save(model.state_dict(), BEST_MODEL)

Upon trying to reload the model using the code below,

# Load model
model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t12_85M_UR50S")
model.load_state_dict(torch.load(BEST_MODEL))

I run into the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-4-b6ed0ebd023b> in <module>
----> 1 model.load_state_dict(torch.load("esm1_t12_85M_UR50S-Best.pt"))

~/anaconda3/envs/c10_stability/lib/python3.8/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1049
   1050         if len(error_msgs) > 0:
-> 1051             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1052                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1053         return _IncompatibleKeys(missing_keys, unexpected_keys)

This is a fairly standard PyTorch error raised when loading a mismatched state_dict into a model. Upon further inspection, it looks like every key in the saved state_dict has a "model." prefix, while the keys expected by the model do not (I attached a file of the full error, which shows this mismatch in key names).

What do you recommend for bypassing this problem? Is there a specific GitHub repo tag I should be working off of? It looks like there may have been a change to the torch.hub models between my original training script and now. Would it be reasonable to just remove the "model." prefix from the keys of the saved state_dict?

Thank you!

StateDictLoadError.txt

joshim5 commented 3 years ago

Hi @brucejwittmann, this can happen if you wrap ESM as a submodule (e.g. an attribute named model) inside a larger nn.Module during fine-tuning and then save the state dict using the default PyTorch tools. PyTorch builds state_dict keys from the attribute names of submodules, so the wrapper's attribute name ends up as a prefix on every key.
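As a minimal sketch (a hypothetical wrapper, not code from the ESM repo), storing a module in an attribute named model on a wrapper nn.Module prepends "model." to every state_dict key:

```python
import torch.nn as nn

# Hypothetical wrapper for illustration only:
# the attribute name "model" becomes a prefix on every state_dict key.
class Wrapper(nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.model = inner  # attribute name -> key prefix

inner = nn.Linear(4, 4)                   # keys: "weight", "bias"
wrapped = Wrapper(inner)
print(list(wrapped.state_dict().keys()))  # ['model.weight', 'model.bias']
```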

This is a fairly common issue, and you can resolve it by "upgrading" the state_dict of your models to match the current codebase. This is often done by implementing an upgrade_state_dict method on an nn.Module subclass, or by modifying the saved state dict directly. See here for an example of how we do this in ESM: https://github.com/facebookresearch/esm/blob/master/esm/pretrained.py#L49
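If adapting the linked pretrained.py code isn't convenient, one alternative is to strip the prefix from the saved keys before loading. This is a sketch only; strip_prefix is a hypothetical helper, not part of ESM:

```python
# Hypothetical helper (not part of ESM): drop a leading "model." prefix
# from every key so the dict matches the bare ESM module's key names.
def strip_prefix(state_dict, prefix="model."):
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Usage sketch, assuming BEST_MODEL points at the fine-tuned checkpoint:
# model.load_state_dict(strip_prefix(torch.load(BEST_MODEL)))
```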

Let me know if these pointers aren't helpful or if there is anything else we can do to help.

brucejwittmann commented 3 years ago

Thank you for the help! We now have everything working.

Xinxinatg commented 3 years ago

@brucejwittmann Hi, may I know which task you are fine-tuning on? I am trying to fine-tune the model on a binary classification problem, and I would really appreciate any hints on the details. Do I need to add an extra feed-forward network on top of the loaded pre-trained model?

brucejwittmann commented 3 years ago

Hi @Xinxinatg, this issue was in reference to fine-tuning on protein sequence data. We were further training the ESM models on a specific family of protein sequences using masked token prediction. For that case, no additional layers are needed; the "logits" output of the ESM model can be passed to an nn.CrossEntropyLoss instance (along with appropriate labels for the masked tokens).
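A hedged sketch of that loss computation, with illustrative shapes standing in for the real ESM "logits" output; the -100 label marks unmasked positions so CrossEntropyLoss ignores them:

```python
import torch
import torch.nn as nn

# Illustrative shapes only: `logits` stands in for the (batch, seq_len, vocab)
# "logits" output of ESM; `targets` holds true token ids at masked positions
# and -100 everywhere else so those positions are ignored by the loss.
batch, seq_len, vocab = 2, 8, 33
logits = torch.randn(batch, seq_len, vocab)
targets = torch.full((batch, seq_len), -100)
targets[:, 3] = 5  # pretend position 3 was masked; true token id is 5

loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(logits.view(-1, vocab), targets.view(-1))
```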

For binary classification, you would want to use the "representations" output rather than the logits. You will need to add an additional layer (or more, depending on your objective) on top of this.
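For example, a minimal classification head might mean-pool the per-residue "representations" and apply a linear layer. This is a sketch under assumptions: embed_dim=768 matches the 12-layer 85M checkpoint as far as I know, but check the dimension of the model you actually load.

```python
import torch
import torch.nn as nn

# Hypothetical binary-classification head on top of ESM "representations":
# mean-pool the per-residue embeddings, then project to two class logits.
class BinaryHead(nn.Module):
    def __init__(self, embed_dim=768, num_classes=2):
        super().__init__()
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, reps):          # reps: (batch, seq_len, embed_dim)
        pooled = reps.mean(dim=1)     # (batch, embed_dim)
        return self.fc(pooled)        # (batch, num_classes)

head = BinaryHead()
reps = torch.randn(4, 10, 768)        # stand-in for real representations
print(head(reps).shape)               # torch.Size([4, 2])
```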

Xinxinatg commented 3 years ago

@brucejwittmann Thanks a lot! Will try to customize the output layers :)