Status: Closed (closed by osllogon 2 years ago)
I have saved the following model with TorchScript:
class CrystalModel(torch.nn.Module):
    def __init__(self, dim_node_features, dim_edge_features):
        super().__init__()
        self.gnn = torch_geometric.nn.CGConv(dim_node_features, dim_edge_features).jittable()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim_node_features, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 1)
        )

    def forward(self, x, adj):
        outputs = self.gnn(x, adj)
        outputs = self.mlp(outputs)
        return outputs
But when I run the loaded model with a SparseTensor, exactly as I did before saving it, the following error appears:
Traceback (most recent call last):
  File "/home/elloosc/Projects/reasoning_for_recommenders/oscar/pytorch_training/prueba.py", line 34, in <module>
    outputs1 = model_loaded(X.float(), A.detach().float())
  File "/home/elloosc/envs/urano/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: forward() Expected a value of type 'Tensor' for argument 'adj' but instead found type 'SparseTensor'.
Position: 2
Value: SparseTensor(row=tensor([ 0, 0, 0, ..., 228936, 228937, 228937]),
             col=tensor([ 1, 2, 3, ..., 113113, 107362, 113114]),
             val=tensor([[0.1355, 0.0000, 0.0000, 0.0000, 0.0000],
                         [0.1592, 0.0000, 0.0000, 0.0000, 0.0000],
                         [0.0370, 0.0605, 0.0000, 0.0000, 0.0000],
                         ...,
                         [0.0541, 0.0300, 0.0000, 0.0667, 0.0069],
                         [0.0715, 0.0260, 1.0000, 0.0714, 0.0085],
                         [0.0540, 0.0286, 0.0000, 0.0667, 0.0057]]),
             size=(228938, 228938, 5), nnz=2449723, density=0.00%)
Declaration: forward(__torch__.___torch_mangle_3.CrystalModel self, Tensor x, Tensor adj) -> (Tensor)
Cast error details: Unable to cast Python instance to C++ type (compile in debug mode for details)
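TorchScript treats every argument without a type annotation as a Tensor, which is why the scripted forward() rejects a SparseTensor. Annotating the signature as
def forward(self, x: Tensor, adj: SparseTensor) -> Tensor:
should fix this.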
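For anyone who wants to reproduce the behaviour without torch_geometric installed, here is a minimal sketch that uses only torch. It is an analogy, not the model from this issue: the class names `ScaleUnannotated` and `ScaleAnnotated` are made up for illustration, and a plain `float` stands in for the `SparseTensor` argument. The point it is meant to show is that an unannotated argument is recorded as `Tensor` in the scripted signature, and that a type hint changes this.

```python
import torch


class ScaleUnannotated(torch.nn.Module):
    # Hypothetical module: no type hints, so TorchScript records
    # both `x` and `scale` as Tensor in the scripted signature.
    def forward(self, x, scale):
        return x * scale


class ScaleAnnotated(torch.nn.Module):
    # Same body, but the hint tells TorchScript that `scale` is a float.
    def forward(self, x: torch.Tensor, scale: float) -> torch.Tensor:
        return x * scale


x = torch.ones(3)

# Passing a float where the scripted signature declares Tensor is rejected.
unannotated = torch.jit.script(ScaleUnannotated())
error_message = ""
try:
    unannotated(x, 2.0)
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# With the annotation, the same call is accepted.
annotated = torch.jit.script(ScaleAnnotated())
result = annotated(x, 2.0)
print(result)
```

In the model from this issue the same idea applies to `adj`. The annotated signature assumes `Tensor` is imported from `torch` and `SparseTensor` from `torch_sparse`, and the model has to be scripted and saved again after the change.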
Yes, that fixed it, thank you very much. Also, thank you for creating this awesome library.