Closed aehrlich1 closed 1 month ago
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 86.98%. Comparing base (61c47ee) to head (0993f78). Report is 4 commits behind head on master.
Thanks, but this doesn't look correct to me. These are categorical features, which are expected to be processed via `torch.nn.EmbeddingBag`.
The issue arises when trying to pass the resulting data object through a GCN. Here is a small example where the problem comes up:
```python
import torch
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
from torch_geometric.utils import from_smiles

smiles = 'COc1cccc(c2cccc(F)c2C(=O)[O-])c1'
data = from_smiles(smiles)
# data.x = data.x.float()
print(data.x.dtype, data.edge_attr.dtype, data.edge_index.dtype)

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(9, 128)
        self.conv2 = GCNConv(128, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        return F.log_softmax(x, dim=1)

model = GCN()
model(data)
```
This will throw the error `RuntimeError: Found dtype Long but expected Float`.
Yes, this is expected. The correct usage is:

```python
self.emb = torch.nn.EmbeddingBag(100, 64)
...
x = self.emb(data.x)
x = conv(x, edge_index)
```
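For context, that pattern can be sketched end to end with plain `torch`; the 9-column categorical matrix and the embedding sizes below are illustrative stand-ins for the `from_smiles` output, not values taken from PyG itself:

```python
import torch

# Stand-in for data.x from from_smiles: one row per atom,
# 9 integer-coded categorical features (dtype long).
num_nodes, num_features = 5, 9
x = torch.randint(0, 100, (num_nodes, num_features))  # dtype torch.int64

# EmbeddingBag treats each row as a "bag" of indices and pools
# (mode='mean' by default) their embeddings into one float vector,
# so the long-typed categorical row becomes a float feature vector.
emb = torch.nn.EmbeddingBag(100, 64)  # 100 categories -> 64-dim embeddings
h = emb(x)

print(x.dtype)   # torch.int64
print(h.dtype)   # torch.float32
print(h.shape)   # torch.Size([5, 64])
```

The float tensor `h` is what would then be fed to `GCNConv` in place of the raw long-typed `data.x`.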
Ahhh, I see. I must have overlooked this in the documentation. My apologies for opening the PR; it can be closed.
Change node feature data type from long to float.