pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

from_smiles function should output x and edge_attr as dtype float #5949

Closed TShimko126 closed 1 year ago

TShimko126 commented 1 year ago

🚀 The feature, motivation and pitch

Feature

Thank you for continued development of and support for the PyG package!

I'd like to request a small change to the implementation of the from_smiles function in torch_geometric.utils.

Specifically, I'd like to request a change for the dtype of the x and edge_attr features of the data object returned by the from_smiles function from Long to Float.

The relevant lines in the code base are line 96 for x and line 113 for edge_attr:

x = torch.tensor(xs, dtype=torch.long).view(-1, 9)
edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

Motivation

Under the current implementation of the function, the node and edge features (stored in x and edge_attr, respectively) are returned with dtype Long. However, most if not all of the layers that act upon these attributes (convolution, pooling, etc.) expect the Float dtype. This makes it impossible to use the direct output of this function in the forward pass of most PyG-based models.

In the example below, I try to feed the output of from_smiles directly into a 1-layer GCN model.

Example

import torch.nn as nn
import torch_geometric as pyg
import torch_geometric.nn as pygnn
from torch_geometric.utils import from_smiles

# Simple one layer GCN + pool + linear -> out
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.gcn = pygnn.GCNConv(in_channels=9, out_channels=16)
        self.lin = nn.Linear(in_features=16, out_features=1)

    def forward(self, batch):
        x = self.gcn(batch.x, batch.edge_index)
        x = pygnn.global_add_pool(x, batch=batch.batch)
        return self.lin(x)

# Init the model
net = Model()

# Read in the smiles string
smiles = from_smiles('CCCCCC')

# Run forward pass
net(smiles)

which raises the following error:

RuntimeError: Found dtype Long but expected Float

Alternatives

Current alternative

Right now, the best alternative is to manually convert the dtype of the x and edge_attr attributes of the generated data object to floats, as shown below. This can, of course, be done during initial parsing of the SMILES string, as a transform in the dataset/dataloader, or on the forward pass of the model.

import torch
import torch.nn as nn
import torch_geometric as pyg
import torch_geometric.nn as pygnn
from torch_geometric.utils import from_smiles

# Simple one layer GCN + pool + linear -> out
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.gcn = pygnn.GCNConv(in_channels=9, out_channels=16)
        self.lin = nn.Linear(in_features=16, out_features=1)

    def forward(self, batch):
        x = self.gcn(batch.x, batch.edge_index)
        x = pygnn.global_add_pool(x, batch=batch.batch)
        return self.lin(x)

# Init the model
net = Model()

# Read in the smiles string
smiles = from_smiles('CCCCCC')

# Convert x to float
smiles.x = smiles.x.float()

# Run forward pass
net(smiles)

which returns:

tensor([[4.8576]], grad_fn=<AddmmBackward0>)

Implementation alternative

One alternative, if it's desirable to keep the current behavior for backward compatibility, is to let the user specify the dtype of x and edge_attr through a dtype keyword, as shown in the condensed example below.

def from_smiles(smiles: str, with_hydrogen: bool = False, kekulize: bool = False, dtype: torch.dtype = torch.long):
    ...
    x = torch.tensor(xs, dtype=dtype).view(-1, 9)
    ...
    edge_attr = torch.tensor(edge_attrs, dtype=dtype).view(-1, 3)
    ...
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)

Additional context

Thank you for considering this change. If I am missing the reason for the long return type, please do let me know.

rusty1s commented 1 year ago

The returned features are categorical. Simply converting them to floating point will wrongly treat them as numerical. What you need to do is input the features into an Embedding layer to learn representations for each category, similar to what we do here. Alternatively, I am open to letting from_smiles return one-hot vector representations, e.g., via from_smiles(one_hot=True). Let me know if you have interest in contributing this.
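A minimal sketch of the two options mentioned above, under stated assumptions: CategoricalEncoder is a hypothetical helper (not part of PyG or OGB), the per-column category counts are placeholders rather than the real vocabulary sizes used by from_smiles, and x is a hand-built stand-in for the x tensor that from_smiles returns. It only illustrates treating each integer column as a category index instead of a number.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CategoricalEncoder(nn.Module):
    """Hypothetical helper: one Embedding table per categorical column, summed."""

    def __init__(self, num_categories, emb_dim):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [nn.Embedding(n, emb_dim) for n in num_categories]
        )

    def forward(self, x):
        # x: [num_nodes, num_columns], dtype long, each column holds category indices
        out = 0
        for i, emb in enumerate(self.embeddings):
            out = out + emb(x[:, i])
        return out


def one_hot_columns(x, num_categories):
    """Hypothetical helper: one-hot encode each column and concatenate as float."""
    parts = [F.one_hot(x[:, i], num_classes=n) for i, n in enumerate(num_categories)]
    return torch.cat(parts, dim=-1).float()


# Placeholder category counts for the 9 node-feature columns (assumption)
num_categories = [119, 16, 16, 16, 16, 16, 16, 2, 2]

# Stand-in for from_smiles('CCCCCC').x: 6 nodes, 9 long-valued columns
x = torch.zeros(6, 9, dtype=torch.long)
x[:, 0] = 6  # first column set to carbon's atomic number, for illustration

encoder = CategoricalEncoder(num_categories, emb_dim=16)
h = encoder(x)                               # learned float features, [6, 16]
x_onehot = one_hot_columns(x, num_categories)  # fixed float features, [6, 219]
```

Either h or x_onehot is a float tensor that can be passed to a layer such as GCNConv, with in_channels set to 16 or to sum(num_categories) respectively. The embedding variant learns a representation per category, while the one-hot variant is parameter-free but wider.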

TShimko126 commented 1 year ago

Hi Matthias - thanks for clarifying! Good to know about the AtomEncoder and BondEncoder from OGB. That should be sufficient for my purposes.

I've previously used the MolGraphConvFeaturizer from DeepChem, which returns one-hot encodings. It may be something I would be open to adding in the future, but unfortunately I do not have the time currently. I'll go ahead and close this issue for now.

Thanks for your help and continued work on PyTorch Geometric!