pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Support SparseTensor in torch script #7706

Open Joeyzhouqihui opened 1 year ago

Joeyzhouqihui commented 1 year ago

🚀 The feature, motivation and pitch

I am trying to deploy a PyG model with LibTorch, using TorchScript to trace the model. I mainly use GINConv. When the graph connectivity is given in COO format (edge_index), I can export the model successfully. However, I want to use sparse matrix multiplication (spmm) to speed up the computation, so I switched to SparseTensor, and that is not supported yet.
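To show why switching from COO to spmm is attractive, here is a minimal sketch using only plain PyTorch (torch.sparse, not torch_sparse's SparseTensor). It assumes PyG's convention that edge_index[0] holds source nodes and edge_index[1] holds target nodes. It computes the neighbor-sum step that GINConv performs before its MLP in two ways: a per-edge scatter-add over the COO indices, and a single sparse matrix product A @ x. The toy graph and variable names are illustrative only.

```python
import torch

# Toy graph with 4 nodes: edge_index[0] = source, edge_index[1] = target.
edge_index = torch.tensor([[0, 1, 2, 3, 0],
                           [1, 2, 3, 0, 2]])
num_nodes = 4
x = torch.arange(8, dtype=torch.float).view(num_nodes, 2)

# COO-style aggregation: gather the source features for every edge, then
# sum them into the target nodes (one message tensor row per edge).
out_coo = torch.zeros_like(x).index_add_(0, edge_index[1], x[edge_index[0]])

# spmm-style aggregation: build A with A[target, source] = 1 and compute
# A @ x, which yields the same neighbor sums without per-edge messages.
adj = torch.sparse_coo_tensor(
    edge_index.flip(0),              # rows = targets, columns = sources
    torch.ones(edge_index.size(1)),  # one unit weight per edge
    (num_nodes, num_nodes),
)
out_spmm = torch.sparse.mm(adj, x)

# Both formulations agree on every node's aggregated features.
assert torch.allclose(out_coo, out_spmm)
```

The spmm version avoids materializing a features tensor with one row per edge, which is the usual motivation for preferring a sparse adjacency representation on larger graphs.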


Alternatives

No response

Additional context

No response

rusty1s commented 1 year ago

What does your model look like? Note that you need to add type hints to the forward function, e.g. def forward(self, x: Tensor, adj: SparseTensor), as otherwise every argument is considered to be a PyTorch tensor.
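The rule behind this advice can be checked with plain PyTorch, without PyG or torch_sparse installed. The sketch below is an illustration, not PyG code: the class names Untyped and Typed are hypothetical, and a float argument stands in for adj: SparseTensor. An unannotated argument is compiled as a Tensor, so calling the scripted module with a non-Tensor value is rejected, while the annotated version accepts it. Whether a particular PyG release scripts GINConv with a SparseTensor argument is not shown here.

```python
import torch
from torch import Tensor


class Untyped(torch.nn.Module):
    # No type hints: TorchScript assumes every argument is a Tensor.
    def forward(self, x, scale):
        return x * scale


class Typed(torch.nn.Module):
    # Explicit hints tell TorchScript that `scale` is a Python float.
    def forward(self, x: Tensor, scale: float) -> Tensor:
        return x * scale


untyped = torch.jit.script(Untyped())
typed = torch.jit.script(Typed())

x = torch.ones(3)
print(typed(x, 2.0))  # accepted: `scale` matches the declared float type

untyped_rejected_float = False
try:
    untyped(x, 2.0)  # rejected: `scale` was compiled as a Tensor
except RuntimeError as err:
    untyped_rejected_float = True
    print("untyped call failed:", err)
```

In the GNN case the same reasoning applies to the adjacency argument: without the SparseTensor annotation, TorchScript expects a dense Tensor in that position.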