tbepler / protein-sequence-embedding-iclr2019

Source code for "Learning protein sequence embeddings using information from structure" - ICLR 2019

embed_sequences.py fails: AttributeError: 'LSTM' object has no attribute '_flat_weights' #21

Closed by konstin 3 years ago

konstin commented 3 years ago

I'm trying to run embed_sequences.py, but I get an exception from torch code:

$ python embed_sequences.py seqwence-protein.fasta -m ../pretrained_models/ssa_L1_100d_lstm3x512_lm_i512_mb64_tau0.5_lambda0.1_p0.05_epoch100.sav -o output.h5
/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'torch.nn.modules.sparse.Embedding' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'torch.nn.modules.container.ModuleList' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'torch.nn.modules.rnn.LSTM' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
# writing: a.h5
# embedding with lm_only=False, no_lm=False, proj_only=False
# pooling: none
Traceback (most recent call last):
  File "embed_sequences.py", line 184, in <module>
    main()
  File "embed_sequences.py", line 171, in main
    z = embed_sequence(sequence, lm_embed, lstm_stack, proj
  File "embed_sequences.py", line 82, in embed_sequence
    z = embed_stack(x, lm_embed, lstm_stack, proj
  File "embed_sequences.py", line 49, in embed_stack
    h = lm_embed(x)
  File "/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/konsti/bepler-berger/src/models/embedding.py", line 25, in forward
    h_lm = self.lm.encode(x)
  File "/home/konsti/bepler-berger/src/models/sequence.py", line 168, in encode
    h_fwd_layers,h_rvs_layers = self.transform(z_fwd, z_rvs)
  File "/home/konsti/bepler-berger/src/models/sequence.py", line 92, in transform
    h,_ = rnn(h)
  File "/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 569, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  File "/home/konsti/bepler-berger/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 593, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'LSTM' object has no attribute '_flat_weights'

These are the versions I installed (from pip freeze):

Cython==0.29.21
dataclasses==0.6
future==0.18.2
h5py==3.1.0
numpy==1.19.4
pkg-resources==0.0.0
torch==1.5.1
typing-extensions==3.7.4.3

I suspect this has something to do with the torch version. Would it be possible to make the model run on more recent pytorch versions?

tbepler commented 3 years ago

The easy solution is to use an older version of pytorch to calculate the embeddings. If you need a newer version, the permanent fix is probably to load the model under the older pytorch and save only its state_dict (the saved models I provide are pickles of the whole model object, not just the state_dict), then load that state_dict into the model under the new pytorch version.
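
A minimal sketch of that two-step conversion, assuming a hypothetical build_model() helper that re-instantiates the same architecture from this repo's model classes; the intermediate file name is a placeholder:

# Step 1: run under the old pytorch version that can still unpickle the model.
# The repo's src/ package must be importable, since the pickle references its classes.
import torch

model = torch.load('pretrained_models/ssa_L1_100d_lstm3x512_lm_i512_mb64_tau0.5_lambda0.1_p0.05_epoch100.sav', map_location='cpu')
torch.save(model.state_dict(), 'ssa_L1_100d_lstm3x512_state_dict.pt')  # weights only, no pickled module objects

# Step 2: run under the newer pytorch version.
import torch

model = build_model()  # hypothetical helper: rebuild the architecture from this repo's model classes
state = torch.load('ssa_L1_100d_lstm3x512_state_dict.pt', map_location='cpu')
model.load_state_dict(state)
model.eval()

embed_sequences.py would then need a small change to construct the model and call load_state_dict instead of unpickling the whole model with torch.load.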

konstin commented 3 years ago

Thank you for the quick response, that made it work!