tbepler / protein-sequence-embedding-iclr2019

Source code for "Learning protein sequence embeddings using information from structure" - ICLR 2019

How can I load the pretrained models into pytorch? #13

Closed rainwala closed 4 years ago

rainwala commented 4 years ago

Hi Tristan,

what format are the pre-trained models stored in? How can I load them into pytorch?

Best, Ali

rainwala commented 4 years ago

import torch
model = torch.nn.Module()
model.load_state_dict(torch.load('bepler_models/pfam_lm_lstm2x1024_tied_mb64.sav'))

rainwala commented 4 years ago

ModuleNotFoundError                       Traceback (most recent call last)

<ipython-input> in <module>
      1 import torch
      2 model = torch.nn.Module()
----> 3 model.load_state_dict(torch.load('bepler_models/pfam_lm_lstm2x1024_tied_mb64.sav'))

~/.local/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    384         f = f.open('rb')
    385     try:
--> 386         return _load(f, map_location, pickle_module, **pickle_load_args)
    387     finally:
    388         if new_fd:

~/.local/lib/python3.6/site-packages/torch/serialization.py in _load(f, map_location, pickle_module, **pickle_load_args)
    571     unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    572     unpickler.persistent_load = persistent_load
--> 573     result = unpickler.load()
    574
    575     deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

ModuleNotFoundError: No module named 'src'

tbepler commented 4 years ago

It sounds like there are two issues here. The first is that the code cannot be found by python. For this, you need to either run python from within the project base directory or add the base directory to your python path. This stackoverflow thread might help.

For the second, you just need to use

model = torch.load(path)

to load the saved model.

rainwala commented 4 years ago

Thanks Tristan, but I can't seem to fix the first problem, even with the workarounds suggested in the stackoverflow link.

tbepler commented 4 years ago

Are you running your code from within the base directory of this project?

rainwala commented 4 years ago

I have tried running it in my home directory. I've also tried a virtual environment (using pip as the package manager), both in the same directory as the models and in a higher level directory. All attempts give me the same error message.

Is there a different definition for the home directory of the project that I am missing? Thank you for your help.


tbepler commented 4 years ago

If you put the script you are running in the base directory of the embeddings, e.g.

protein-sequence-embedding-iclr2019/
|-- your script.py
|-- src/
...

then your script will be able to find code in the src directory correctly. If you need to have your script somewhere else, then you need to add this directory to your python path for the code in the src directory to be found. In the stackoverflow discussion, the last part of this section describes one way to do this. This is a general python issue, so I recommend reading about how python finds packages if the above isn't enough to fix the import error.
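The path setup described above can be sketched with the standard library; the clone location below is an assumption, so substitute wherever you checked out the repository:

```python
import os
import sys

# Hypothetical clone location -- adjust to your actual checkout.
repo_dir = os.path.expanduser("~/protein-sequence-embedding-iclr2019")

# Prepend the base directory so `import src.*` (which unpickling
# the saved model performs) can be resolved from anywhere.
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)
```

Equivalently, setting the `PYTHONPATH` environment variable to the base directory before launching python has the same effect.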

rainwala commented 4 years ago

Thanks Tristan, I've managed to solve the original problem.

Now, instead, I get the following error:

import torch
model = torch.nn.Module()
model.load_state_dict(torch.load('src/models/pretrained_models/pfam_lm_lstm2x1024_tied_mb64.sav'))


AttributeError                            Traceback (most recent call last)

<ipython-input> in <module>
      2 import torch
      3 model = torch.nn.Module()
----> 4 model.load_state_dict(torch.load('src/models/pretrained_models/pfam_lm_lstm2x1024_tied_mb64.sav'))

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    816     # copy state_dict so _load_from_state_dict can modify it
    817     metadata = getattr(state_dict, '_metadata', None)
--> 818     state_dict = state_dict.copy()
    819     if metadata is not None:
    820         state_dict._metadata = metadata

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
    589         return modules[name]
    590     raise AttributeError("'{}' object has no attribute '{}'".format(
--> 591         type(self).__name__, name))
    592
    593 def __setattr__(self, name, value):

AttributeError: 'BiLM' object has no attribute 'copy'

tbepler commented 4 years ago

You don't need model.load_state_dict, etc.

model = torch.load(path)

is all that is needed to load the model.
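Because the `.sav` files pickle the whole module object rather than just a `state_dict`, `torch.load` returns a ready-to-use model, provided the class definition (the repo's `src` package) is importable. A minimal sketch of this save/load round trip, with a toy module standing in for the actual `BiLM` class; the `weights_only` guard is only relevant on newer PyTorch versions, where loading full pickled modules is no longer the default:

```python
import inspect
import os
import tempfile

import torch
import torch.nn as nn

# Toy stand-in for the pretrained model class.
net = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "model.sav")

# Saving the module object pickles it together with a reference
# to its class, which must be importable again at load time.
torch.save(net, path)

# PyTorch >= 2.6 defaults to weights_only=True, which rejects
# pickled modules; pass weights_only=False where supported.
kwargs = {}
if "weights_only" in inspect.signature(torch.load).parameters:
    kwargs["weights_only"] = False

model = torch.load(path, **kwargs)  # no load_state_dict needed
model.eval()
```

With the repository on the python path, the same single call loads the pretrained files directly, e.g. `model = torch.load('pfam_lm_lstm2x1024_tied_mb64.sav')`.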

rainwala commented 4 years ago

Thanks! It works now.