Closed: @brucejwittmann closed this issue 3 years ago.
Hi @brucejwittmann, this can happen if you wrap the ESM model inside a module attribute called `model` during fine-tuning and then save the state dict using the default PyTorch tools. By default, PyTorch builds state-dict keys from the attribute names you used.
This is a fairly common issue, and you can resolve it by "upgrading" the state_dict of your model to match the current codebase. This is often done by implementing an upgrade_state_dict method on an nn.Module subclass or by modifying the state dict itself. See here for an example of how we do this in ESM: https://github.com/facebookresearch/esm/blob/master/esm/pretrained.py#L49
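As a minimal sketch of that key-renaming approach (the function name and the `"model."` prefix here are illustrative, matching the mismatch reported in this issue, not the exact ESM implementation):

```python
def upgrade_state_dict(state_dict, prefix="model."):
    """Strip a wrapper prefix from every key so the dict matches the bare module."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Usage (assuming a checkpoint saved from a wrapped model):
#   esm_model.load_state_dict(upgrade_state_dict(torch.load("checkpoint.pt")))
```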
Let me know if these pointers aren't helpful or if there is anything else we can do to help.
Thank you for the help! We now have everything working.
@brucejwittmann Hi, may I know which task you are fine-tuning on? I am trying to fine-tune the model on a binary classification problem and would really appreciate any hints on the details. Do I need to add an extra feed-forward network on top of the loaded pre-trained model?
Hi @Xinxinatg, this issue was in reference to fine-tuning on protein sequence data. We were further training the ESM models on a specific family of protein sequences using masked token prediction. For that case, no additional layers are needed; the "logits" output by the ESM model can be passed into an nn.CrossEntropyLoss instance (along with appropriate labels for the masked tokens).
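A minimal sketch of that loss setup, with random tensors standing in for the model output (the shapes and the choice of 33 for the vocab size are assumptions, roughly matching the ESM alphabet):

```python
import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 8, 33

logits = torch.randn(batch, seq_len, vocab)  # stand-in for the ESM "logits" output
labels = torch.full((batch, seq_len), -100, dtype=torch.long)  # -100 = ignored
labels[:, 2] = 5  # supervise only the masked position(s)

# CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten:
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(logits.view(-1, vocab), labels.view(-1))
```

Positions whose label is the `ignore_index` contribute nothing to the loss, which is how unmasked tokens are excluded.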
For binary classification, you would want to use the "representations" output rather than the logits. You will need to add one or more layers on top of this, depending on your objective.
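A minimal sketch of such a head, assuming the ESM-1b representation size of 1280 (the hidden size, pooling choice, and class name here are illustrative, not prescriptive):

```python
import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Maps ESM "representations" to a single binary-classification logit."""
    def __init__(self, embed_dim=1280, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, representations):
        # representations: (batch, seq_len, embed_dim)
        pooled = representations.mean(dim=1)  # mean-pool over sequence positions
        return self.net(pooled).squeeze(-1)   # (batch,) logits
```

The resulting per-sequence logits can be trained against nn.BCEWithLogitsLoss, with the ESM backbone either frozen or fine-tuned jointly.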
@brucejwittmann Thanks a lot! Will try to customize the output layers :)
I am running into an error when trying to reload a fine-tuned version of the models. After further training, the models were saved using the code below:
Upon trying to reload the model using the code below,
I run into the following error:
This is a fairly standard PyTorch error that you get when trying to load the wrong state_dict into a model. Upon further inspection, it looks like all keys in the saved state_dict have a "model." prefix on their names, while the keys in the model itself do not (I attached a file of the full error, which shows this mismatch in key names).
What do you recommend for bypassing this problem? Is there a specific GitHub repo tag that I should be working from? It looks like there may have been some change to the torch.hub models between when I ran my original training script and now. Would it be reasonable to just remove the "model." prefix in the saved state_dict?
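For context, here is a minimal sketch of how I believe the prefix appears when a wrapped module is saved (nn.Linear stands in for the ESM model; the wrapper class is an assumption about my training code, not the actual script):

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model  # this attribute name becomes the "model." key prefix

backbone = nn.Linear(4, 4)   # stand-in for the ESM model
wrapped = Wrapper(backbone)

print(list(wrapped.state_dict()))   # keys carry the "model." prefix
print(list(backbone.state_dict()))  # keys match the bare model

# Saving backbone.state_dict() rather than wrapped.state_dict() would keep
# the checkpoint loadable into the bare ESM model:
#   torch.save(backbone.state_dict(), "checkpoint.pt")
```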
Thank you!
StateDictLoadError.txt