rusty1s / pytorch_sparse

PyTorch Extension Library of Optimized Autograd Sparse Matrix Operations
MIT License

Incompatibility between SparseTensor and TorchScript #242

Closed: osllogon closed this issue 2 years ago

osllogon commented 2 years ago

I have saved the following model with TorchScript:

import torch
import torch_geometric


class CrystalModel(torch.nn.Module):

    def __init__(self, dim_node_features, dim_edge_features):
        super().__init__()

        self.gnn = torch_geometric.nn.CGConv(dim_node_features, dim_edge_features).jittable()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim_node_features, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 1)
        )

    def forward(self, x, adj):
        outputs = self.gnn(x, adj)
        outputs = self.mlp(outputs)
        return outputs

However, when I run the loaded model with a SparseTensor, exactly as I did before saving it, the following error appears:

Traceback (most recent call last):
  File "/home/elloosc/Projects/reasoning_for_recommenders/oscar/pytorch_training/prueba.py", line 34, in <module>
    outputs1 = model_loaded(X.float(), A.detach().float())
  File "/home/elloosc/envs/urano/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: forward() Expected a value of type 'Tensor' for argument 'adj' but instead found type 'SparseTensor'.
Position: 2
Value: SparseTensor(row=tensor([     0,      0,      0,  ..., 228936, 228937, 228937]),
             col=tensor([     1,      2,      3,  ..., 113113, 107362, 113114]),
             val=tensor([[0.1355, 0.0000, 0.0000, 0.0000, 0.0000],
                           [0.1592, 0.0000, 0.0000, 0.0000, 0.0000],
                           [0.0370, 0.0605, 0.0000, 0.0000, 0.0000],
                           ...,
                           [0.0541, 0.0300, 0.0000, 0.0667, 0.0069],
                           [0.0715, 0.0260, 1.0000, 0.0714, 0.0085],
                           [0.0540, 0.0286, 0.0000, 0.0667, 0.0057]]),
             size=(228938, 228938, 5), nnz=2449723, density=0.00%)
Declaration: forward(__torch__.___torch_mangle_3.CrystalModel self, Tensor x, Tensor adj) -> (Tensor)
Cast error details: Unable to cast Python instance to C++ type (compile in debug mode for details)
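rusty1s commented 2 years ago

TorchScript assumes that every unannotated argument is a Tensor, so adj needs an explicit type annotation. Changing the signature to

def forward(self, x: Tensor, adj: SparseTensor) -> Tensor:

should fix this.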
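Because TorchScript treats any unannotated forward argument as a Tensor, it can help to check for missing annotations before scripting a module. The following is only a minimal sketch using the standard library: `unannotated_params` and the stand-in classes `Before` and `After` are hypothetical names invented for illustration, not part of torch, torch_sparse, or torch_geometric, and the annotations are written as strings so that nothing needs to be imported.

```python
import inspect


def unannotated_params(fn):
    """Return the names of parameters (other than self) without a type annotation."""
    params = inspect.signature(fn).parameters.values()
    return [
        p.name
        for p in params
        if p.name != "self" and p.annotation is inspect.Parameter.empty
    ]


# Hypothetical stand-in for the original model: no annotations on forward.
class Before:
    def forward(self, x, adj):
        return x


# Hypothetical stand-in for the fixed model: every argument is annotated.
class After:
    def forward(self, x: "Tensor", adj: "SparseTensor") -> "Tensor":
        return x


print(unannotated_params(Before.forward))  # ['x', 'adj']
print(unannotated_params(After.forward))   # []
```

Any name reported by the helper is one that TorchScript would type as Tensor by default, so an argument such as adj that is meant to receive a SparseTensor should not appear in that list.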

osllogon commented 2 years ago

Yes, that fixed it, thank you very much. Also, thank you for creating this awesome library.