pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Convert molecular features to floating-point values #9341

Closed aehrlich1 closed 1 month ago

aehrlich1 commented 1 month ago

Change node feature data type from long to float.

codecov[bot] commented 1 month ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 86.98%. Comparing base (61c47ee) to head (0993f78). Report is 4 commits behind head on master.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #9341      +/-   ##
==========================================
- Coverage   87.33%   86.98%   -0.35%
==========================================
  Files         460      471      +11
  Lines       30385    30721     +336
==========================================
+ Hits        26536    26724     +188
- Misses       3849     3997     +148
```

:umbrella: View full report in Codecov by Sentry.

rusty1s commented 1 month ago

Thanks. This doesn't look correct to me. These are categorical features which are expected to be processed via torch.nn.EmbeddingBag.
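
For context, from_smiles encodes each node feature as an integer category index (one column per atom property), not as a continuous value. A quick way to see this, using ethanol as a toy example:

```python
from torch_geometric.utils import from_smiles

data = from_smiles('CCO')  # ethanol, just as a small example
print(data.x.dtype)   # torch.int64 -- categorical indices, not floats
print(data.x.shape)   # [num_atoms, 9]
```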

aehrlich1 commented 1 month ago

The issue arises when trying to pass the data object through a GCN.

Here is a small example where the problem comes up:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import from_smiles

smiles = 'COc1cccc(c2cccc(F)c2C(=O)[O-])c1'
data = from_smiles(smiles)

# Casting the features manually works around the error:
# data.x = data.x.float()
print(data.x.dtype, data.edge_attr.dtype, data.edge_index.dtype)
# all three are integer (torch.long) tensors


class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(9, 128)
        self.conv2 = GCNConv(128, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = F.relu(self.conv1(x, edge_index))  # fails here: x is torch.int64, but GCNConv expects float
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)


model = GCN()
model(data)

This fails with "RuntimeError: Found dtype Long but expected Float".

rusty1s commented 1 month ago

Yes, this is expected. The correct usage is:

self.emb = torch.nn.EmbeddingBag(100, 64)

x = self.emb(data.x)
x = conv(x, edge_index)
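
For reference, here is a fuller, runnable sketch along those lines. The embedding-table size of 100 and the layer widths are illustrative choices, not values fixed by the library:

```python
import torch
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
from torch_geometric.utils import from_smiles


class GCN(torch.nn.Module):
    def __init__(self, hidden_channels=64, out_channels=2):
        super().__init__()
        # Each node carries 9 categorical features; EmbeddingBag looks up one
        # embedding per feature and averages them into a single float vector
        # per node. 100 is an assumed upper bound on the category indices.
        self.emb = torch.nn.EmbeddingBag(100, hidden_channels)
        self.conv1 = GCNConv(hidden_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        x = self.emb(data.x)  # [num_nodes, 9] (long) -> [num_nodes, hidden_channels] (float)
        x = self.conv1(x, data.edge_index).relu()
        x = self.conv2(x, data.edge_index)
        return F.log_softmax(x, dim=1)


data = from_smiles('COc1cccc(c2cccc(F)c2C(=O)[O-])c1')
model = GCN()
out = model(data)  # no dtype error: the embedding output is float
print(out.shape)   # [num_nodes, 2]
```
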
aehrlich1 commented 1 month ago

Ahhh I see. I must have overlooked this in the documentation. My apologies for opening the PR. It can be closed.