ayushkarnawat / profit

Exploring evolutionary protein fitness landscapes
MIT License

[RuntimeError] Unable to train EmbeddedGCN model #79

Closed: ayushkarnawat closed this issue 4 years ago

ayushkarnawat commented 4 years ago

When we try to train the EmbeddedGCN model on the 3gb1 example, we get a RuntimeError (full output below).

https://github.com/ayushkarnawat/profit/blob/1f8cb98986646259cea45ac79cf004a9e43a2bc4/examples/3gb1/train.py#L1-L99

Current behavior

Loading preprocessed data from cache `data/3gb1/processed/egcn_fitness/tertiary5.mdb`
Training...
Traceback (most recent call last):
  File "examples/3gb1/train.py", line 71, in <module>
    train_y_pred = model([atoms, adjms, dists])
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/modules/module.py", line 540, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/ayushkarnawat/Documents/dev/python_workspace/profit/profit/models/pytorch/egcn.py", line 1120, in forward
    sc_s = self.hidden_layers[f"s_to_s_{i}"](sc)
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/modules/module.py", line 540, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/ayushkarnawat/Documents/dev/python_workspace/profit/profit/models/pytorch/egcn.py", line 646, in forward
    scalar_features = torch.matmul(scalar_features, self.weight)
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat2' in call to _th_mm
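The message says `torch.matmul` received a float64 (Double) tensor as its first operand (`scalar_features`) and a float32 (Float) tensor as `mat2` (`self.weight`). `torch.matmul` does not promote mixed dtypes, so both operands must match. A plausible cause, assumed here and not confirmed by the issue, is that the cached features come from NumPy arrays (float64 by default) while the layer weights use PyTorch's default float32. The sketch below reproduces the mismatch on a toy `torch.nn.Linear` layer and shows two fixes. `cast_inputs` is a hypothetical helper, not part of `profit`.

```python
import numpy as np
import torch


def cast_inputs(inputs, model):
    """Cast floating-point tensors in `inputs` to the dtype of `model`'s parameters.

    Hypothetical helper: non-floating tensors (e.g. integer index tensors)
    are returned unchanged.
    """
    dtype = next(model.parameters()).dtype
    return [t.to(dtype) if t.is_floating_point() else t for t in inputs]


# Parameters are created in the default dtype, float32.
layer = torch.nn.Linear(8, 4)

# Hypothetical stand-ins for cached features: NumPy defaults to float64,
# and torch.from_numpy keeps that dtype.
features = torch.from_numpy(np.random.rand(5, 8))
mask = torch.ones(5, dtype=torch.long)

# Reproduce the mismatch: float64 input times float32 weight.
try:
    torch.matmul(features, layer.weight.t())
    mismatch_raised = False
except RuntimeError:
    # Exact wording of the error differs across PyTorch versions.
    mismatch_raised = True

# Fix 1: cast the inputs to the model's dtype (float32).
features32, mask_cast = cast_inputs([features, mask], layer)
out32 = layer(features32)

# Fix 2: convert the model's parameters to float64 instead.
layer64 = torch.nn.Linear(8, 4).double()
out64 = layer64(features)
```

In `train.py` the first fix would amount to casting `atoms`, `adjms` and `dists` before calling `model([atoms, adjms, dists])`, or storing them as float32 when the cache is built. Staying in float32 halves memory use and is usually faster on GPUs; calling `.double()` on the model is the alternative if float64 precision is actually needed.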

Expected behavior

The model compiles and trains, and the training and validation losses decrease over time.
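One way to check this expected behavior is to record the training and validation loss every epoch and compare the first and last values. The sketch below does this for a toy linear regression on hypothetical synthetic data, standing in for EmbeddedGCN and the 3gb1 data. Every tensor is created as float32 so it matches the default parameter dtype.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy regression data, created as float32 to match the
# default parameter dtype and avoid the Double/Float mismatch above.
true_w = torch.randn(8, 1)
x_train = torch.randn(64, 8, dtype=torch.float32)
y_train = x_train @ true_w + 0.1 * torch.randn(64, 1)
x_val = torch.randn(32, 8, dtype=torch.float32)
y_val = x_val @ true_w + 0.1 * torch.randn(32, 1)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = torch.nn.MSELoss()

train_losses, val_losses = [], []
for epoch in range(100):
    # One full-batch gradient step on the training set.
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x_train), y_train)
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())

    # Evaluate on the held-out set without tracking gradients.
    model.eval()
    with torch.no_grad():
        val_losses.append(loss_fn(model(x_val), y_val).item())

print(f"train: {train_losses[0]:.3f} -> {train_losses[-1]:.3f}")
print(f"val:   {val_losses[0]:.3f} -> {val_losses[-1]:.3f}")
```

Each recorded training loss is measured before that epoch's update, while each validation loss is measured after it. That offset is fine for a coarse "is it decreasing" check.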