Ramprasad-Group / polygnn

polyGNN is a Python library to automate ML model training for polymer informatics.

Use graph_feats in the forward pass. #14

Closed rishigurnani closed 1 year ago

rishigurnani commented 1 year ago

Prior to this PR, graph_feats was not used in the forward pass. As a result, predictions were identical for polymers with the same SMILES and selector, even if their graph_feats differed. This PR fixes that. One change is that an argument was added to the polyGNN class to specify the dimension of graph_feats.
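A minimal sketch of the idea, assuming a readout that first pools each polymer graph into a fixed-width embedding: graph_feats are concatenated onto that embedding before the final MLP, so the MLP's first layer has to be sized for the combined width. That is why the dimension of graph_feats must be known when the model is built. `GraphFeatsHead`, `emb_dim`, and the layer sizes are hypothetical names for illustration only; this is not polyGNN's actual class or forward pass.

```python
import torch
from torch import nn


class GraphFeatsHead(nn.Module):
    """Hypothetical readout head: pooled graph embedding + graph_feats -> prediction."""

    def __init__(self, emb_dim: int, graph_feats_dim: int, hidden: int = 64):
        super().__init__()
        # The first Linear layer must accept emb_dim + graph_feats_dim columns.
        self.final_mlp = nn.Sequential(
            nn.Linear(emb_dim + graph_feats_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, graph_emb: torch.Tensor, graph_feats: torch.Tensor) -> torch.Tensor:
        # graph_emb: (batch, emb_dim); graph_feats: (batch, graph_feats_dim)
        x = torch.cat((graph_emb, graph_feats), dim=1)
        return self.final_mlp(x)


torch.manual_seed(0)
head = GraphFeatsHead(emb_dim=128, graph_feats_dim=6)
emb = torch.randn(1, 128)  # same graph embedding ("same SMILES") for both inputs
y_a = head(emb, torch.zeros(1, 6))
y_b = head(emb, torch.ones(1, 6))
# With graph_feats in the forward pass, the two predictions can now differ.
print(torch.allclose(y_a, y_b))
```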

I also added a test to make sure this bug does not recur.
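One way such a regression test could look, as a sketch: feed two inputs that share the same graph embedding but differ only in graph_feats, and require the predictions to differ. `check_graph_feats_used`, `IgnoresFeats`, and `UsesFeats` are hypothetical stand-ins for the pre-PR and post-PR behaviour, not the test actually added in this PR.

```python
import torch
from torch import nn


def check_graph_feats_used(model, graph_emb, feats_a, feats_b, atol=1e-6):
    """Return True if changing only graph_feats changes the model's output."""
    model.eval()
    with torch.no_grad():
        y_a = model(graph_emb, feats_a)
        y_b = model(graph_emb, feats_b)
    return not torch.allclose(y_a, y_b, atol=atol)


class IgnoresFeats(nn.Module):
    # Mimics the bug: graph_feats are accepted but never used.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 1)

    def forward(self, graph_emb, graph_feats):
        return self.lin(graph_emb)


class UsesFeats(nn.Module):
    # Mimics the fix: graph_feats are concatenated before the readout layer.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4 + 2, 1)

    def forward(self, graph_emb, graph_feats):
        return self.lin(torch.cat((graph_emb, graph_feats), dim=1))


torch.manual_seed(0)
emb = torch.randn(3, 4)
a, b = torch.zeros(3, 2), torch.ones(3, 2)
print(check_graph_feats_used(IgnoresFeats(), emb, a, b))  # False
print(check_graph_feats_used(UsesFeats(), emb, a, b))
```

In a pytest suite, the second call would become `assert check_graph_feats_used(model, ...)`, so a model that silently drops graph_feats fails the test.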

oliverhvidsten commented 1 year ago

I just tested the same dataset on main and on the new branch. The run on main completed, but the run on the new branch raised an error.

Traceback pasted below:

..Training model 0 (with capacity 2) of 4
Traceback (most recent call last):
  File "GNN_CV_training.py", line 362, in <module>
    optimal_capacity = session.choose_model_size_by_overfit()
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/nndebugger/dl_debug.py", line 387, in choose_model_size_by_overfit
    start=start,
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/nndebugger/torch_utils.py", line 72, in default_per_epoch_trainer
    output = model(data)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/oliver/polygnn/polygnn/models.py", line 83, in forward
    data.x = self.final_mlp(data.x)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/polygnn_trainer/layers.py", line 106, in forward
    x = layer(x)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/polygnn_trainer/layers.py", line 41, in forward
    return self.dropout(self.activation(self.linear(x)))
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (50x134 and 128x64)

rishigurnani commented 1 year ago

> One change is that an argument was added to the polyGNN class to specify the dimension of graph_feats.

Did you specify the dimension of graph_feats using the new argument graph_feats_dim?
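The shapes in the first traceback are consistent with this: `F.linear` multiplies the 50x134 input by the transpose of a 64x128 weight (reported as 128x64), so a Linear layer inside `final_mlp` was built for 128 input columns but received 134. The sketch below reproduces that mismatch with plain PyTorch, assuming only that 6 extra columns were concatenated onto a 128-wide input; which tensors make up those 6 columns is not stated in the thread.

```python
import torch
from torch import nn

batch, emb_dim, extra = 50, 128, 6  # 134 - 128 = 6 columns the layer was not built for
x = torch.randn(batch, emb_dim + extra)  # the 50x134 input from the traceback

# A Linear layer expecting 128 inputs: its weight is 64x128, so the matmul
# pairs the 50x134 input with a 128x64 matrix and fails.
too_narrow = nn.Linear(emb_dim, 64)
caught = None
try:
    too_narrow(x)
except RuntimeError as err:
    caught = str(err)
print(caught)  # exact wording varies across PyTorch versions

# Sizing the layer for the full concatenated width removes the mismatch.
sized = nn.Linear(emb_dim + extra, 64)
print(sized(x).shape)  # torch.Size([50, 64])
```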

rishigurnani commented 1 year ago

Actually, I realized there's another issue. I'll let you know when it's fixed.

rishigurnani commented 1 year ago

@oliverhvidsten OK, now please try again. You'll need the new polygnn_trainer version (v0.5.0) as well. Do you have it? If not, please run poetry update polygnn_trainer.
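To see which polygnn_trainer version is installed, `poetry show polygnn_trainer` reports it from the command line. The following is a small Python sketch of the same check using the standard-library `importlib.metadata` (Python 3.8+; on Python 3.7 the `importlib_metadata` backport provides the same API). `parse_version` and `has_min_version` are hypothetical helpers that only look at the leading digits of each `X.Y.Z` component, so they are not a full PEP 440 comparison.

```python
import re

try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: requires `pip install importlib_metadata`
    import importlib_metadata as metadata


def parse_version(v: str) -> tuple:
    """Map 'X.Y.Z' (also 'X.Y' or '0.5.0rc1') to a 3-tuple of ints.
    Only the leading digits of each part are kept; this is not full PEP 440."""
    parts = []
    for piece in v.split(".")[:3]:
        m = re.match(r"\d+", piece)
        parts.append(int(m.group()) if m else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def has_min_version(dist: str, minimum: str) -> bool:
    """True if distribution `dist` is installed at a version >= `minimum`."""
    try:
        installed = metadata.version(dist)
    except metadata.PackageNotFoundError:
        # Not installed (or registered under a differently spelled name).
        return False
    return parse_version(installed) >= parse_version(minimum)


print(has_min_version("polygnn_trainer", "0.5.0"))
```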

oliverhvidsten commented 1 year ago

I'm getting this error. Any thoughts?

Traceback (most recent call last):
  File "GNN_CV_training.py", line 369, in <module>
    optimal_capacity = session.choose_model_size_by_overfit()
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/nndebugger/dl_debug.py", line 387, in choose_model_size_by_overfit
    start=start,
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/nndebugger/torch_utils.py", line 72, in default_per_epoch_trainer
    output = model(data)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/oliver/polygnn/polygnn/models.py", line 82, in forward
    data.x = self.assemble_data(data)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/polygnn_trainer/std_module.py", line 59, in assemble_data
    return cat((data.yhat, data.graph_feats, data.selector), dim=1)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/torch_geometric/data/data.py", line 441, in __getattr__
    return getattr(self._store, key)
  File "/home/oliver/.cache/pypoetry/virtualenvs/polygnn-5wmT02iB-py3.7/lib/python3.7/site-packages/torch_geometric/data/storage.py", line 82, in __getattr__
    f"'{self.__class__.__name__}' object has no attribute '{key}'")
AttributeError: 'GlobalStorage' object has no attribute 'yhat'
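The failing line concatenates `data.yhat`, `data.graph_feats`, and `data.selector` along `dim=1`, so each of those attributes must exist on the batch and share its first (batch) dimension. The sketch below performs the same concatenation with an up-front check that names any missing attribute. `assemble_features` is a hypothetical helper, and `types.SimpleNamespace` stands in for a PyTorch Geometric batch; the tensor widths are arbitrary.

```python
from types import SimpleNamespace

import torch


def assemble_features(data, keys=("yhat", "graph_feats", "selector")):
    """Concatenate per-graph tensors column-wise, failing early with a clear
    message if any expected attribute is absent (hypothetical helper)."""
    missing = [k for k in keys if not hasattr(data, k)]
    if missing:
        raise AttributeError(f"batch is missing attribute(s): {missing}")
    pieces = [getattr(data, k) for k in keys]
    # torch.cat along dim=1 needs matching batch sizes; the widths add up.
    if len({p.shape[0] for p in pieces}) != 1:
        raise ValueError("all pieces must share the same batch dimension")
    return torch.cat(pieces, dim=1)


batch = SimpleNamespace(
    yhat=torch.randn(8, 16),
    graph_feats=torch.randn(8, 3),
    selector=torch.randn(8, 2),
)
print(assemble_features(batch).shape)  # torch.Size([8, 21])
```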

rishigurnani commented 1 year ago

It's not obvious to me why that is happening, since example2.py runs without issue. Can you send a reproducible example?

oliverhvidsten commented 1 year ago

Oops, I forgot to pull the new changes. That error no longer occurs.

oliverhvidsten commented 1 year ago

The plots are no longer invariant across the features contained in graph_feats. The accuracy of the predictions I just made was not great, but that is likely because I used only a very small part of the overall dataset. I will test again with the full dataset and report back. This will take a while, though.

rishigurnani commented 1 year ago


Awesome! I'll go ahead and close #13 then. Once you have parity plots generated with the new code and the full dataset, could you do me a favor and add them to #13 for posterity? You can also reopen that issue then if necessary.
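For the parity plots, here is a sketch of two summary numbers commonly reported alongside them: the root-mean-square error (RMSE) and the coefficient of determination R², computed in plain Python. `parity_metrics` is a hypothetical helper, not part of polyGNN or polygnn_trainer.

```python
import math


def parity_metrics(y_true, y_pred):
    """Return (RMSE, R^2) for paired true/predicted values."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    n = len(y_true)
    mean_true = sum(y_true) / n
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total sum of squares
    rmse = math.sqrt(ss_res / n)
    # R^2 is undefined when every true value is identical.
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return rmse, r2


rmse, r2 = parity_metrics([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 5.0])
print(rmse, r2)  # 0.5 0.8
```

Points lying exactly on the y = x line of a parity plot give an RMSE of 0 and an R² of 1.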