lgray opened this issue 4 years ago
@liaopeiyuan @pierthodo
@rusty1s did you already change "return a jittable copy of the conv layer instead of the new class" in one of the two PRs submitted so far? I didn't catch it if you did; could you post a PR or commit reference?
Hi everyone, I worked on a follow-up PR for the JIT interface in https://github.com/rusty1s/pytorch_geometric/pull/1309 (ready to merge), with the following features:
- jittable does not need to trace anymore; instead, users specify the types passed to propagate explicitly, e.g., via # propagate_type: (x: Tensor, edge_weight: Optional[Tensor])
- Union support. Note that Union is not natively supported by PyTorch. Instead, I cast each Union combination into its own @overload type.
All tests pass, but bipartite graph support/SparseTensors are only added for a couple of convs for now. I would like to merge this PR first before continuing to work on it.
Here is a basic example:
from typing import Union

from torch import Tensor
from torch.nn import Linear
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing

class MyConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        """"""
        super(MyConv, self).__init__(aggr='add')
        self.lin_l = Linear(in_channels, out_channels)
        self.lin_r = Linear(in_channels, out_channels)

    def forward(self, x: Tensor,
                edge_index: Union[Tensor, SparseTensor]) -> Tensor:
        """"""
        # propagate_type: (x: Tensor)
        out = self.propagate(edge_index, x=x, size=None)
        return self.lin_l(out) + self.lin_r(x)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)
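To use it, one calls jittable() to obtain a TorchScript-compatible copy of the conv before scripting; a minimal sketch, assuming the MyConv defined above:

import torch

conv = MyConv(16, 32)
conv = conv.jittable()         # returns a jittable copy of the conv layer
conv = torch.jit.script(conv)  # scripts without any tracing

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
out = conv(x, edge_index)      # shape [4, 32]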
And here is a more complex one that supports bipartite graphs.
from typing import Tuple, Union

from torch import Tensor
from torch.nn import Linear
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptTensor, Size

class MyConv(MessagePassing):
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int):
        """"""
        super(MyConv, self).__init__(aggr='add')
        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.lin_l = Linear(in_channels[0], out_channels)
        self.lin_r = Linear(in_channels[1], out_channels)

    def forward(self,
                x: Union[Tensor, Tuple[Tensor, OptTensor]],
                edge_index: Union[Tensor, SparseTensor],
                edge_weight: OptTensor = None,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: Tuple[Tensor, OptTensor] = (x, x)
        # propagate_type: (x: Tuple[Tensor, OptTensor], edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=size)
        out = self.lin_l(out)
        x_r = x[1]
        if x_r is not None:
            out += self.lin_r(x_r)
        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else x_j * edge_weight.view(-1, 1)

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: Tuple[Tensor, OptTensor]) -> Tensor:
        return matmul(adj_t, x[0], reduce=self.aggr)
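In the bipartite case, the layer would be called with a tuple of source/target features and an explicit size; a minimal sketch, assuming the bipartite MyConv above (the channel sizes are arbitrary):

import torch

conv = MyConv((16, 24), 32)

x_src = torch.randn(6, 16)  # features of the 6 source nodes
x_dst = torch.randn(4, 24)  # features of the 4 target nodes
edge_index = torch.tensor([[0, 2, 5],   # source node indices
                           [1, 0, 3]])  # target node indices

out = conv((x_src, x_dst), edge_index, size=(6, 4))  # shape [4, 32]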
Let me know what you think!
Hey - this is really cool, but I am not a big fan of parsing code comments to drive functionality. Too much room for error, especially for users who might be new. Though given the constraints we have, it's a solid solution.
Maybe a better way to do this would be to specify a __propagate_signatures__ member variable containing a list of dicts of type hints.
class MyConv(MessagePassing):
    def __init__(self, ...):
        ...
        self.__propagate_signatures__ = [
            {'x': Tensor, 'edge_index': Tensor, ...},
        ]
This way it naturally grows when we need to overload things, it is properly part of the Python code, and it still gets rid of the tracing entirely.
What do you think?
I think so, too. One reason I opted for the current approach is that Python has a really similar type-comment interface with # type: (...) -> .... Maybe I can implement both.
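For readers unfamiliar with it, that refers to Python's PEP 484 type-comment syntax, which the # propagate_type: annotation mirrors:

def add(x, y):
    # type: (int, int) -> int
    return x + y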
Ah, I thought they had deprecated that in favor of the newer system!
Let's go for both and see which one ends up feeling better in practice?
Done :)
Hi @rusty1s and @lgray,
This new JIT functionality is great! Do you guys also plan ONNX support? As far as I have seen, there are issues converting a jittable model to ONNX (https://github.com/pytorch/pytorch/issues/34002), and ideally I'd like to export a pytorch_geometric model to ONNX.
@chriss2401 ONNX is (quite) a bit less flexible than TorchScript, and would similarly need additional C++ bindings written to get all the ops into ONNX. I am not sure of PyTorch's long-term plans to keep supporting ONNX, nor to what degree they'd cover ONNX features. Similarly, TF has also been moving to its own serialization/JIT format more recently.
Another way to approach this: what are you trying to achieve with conversion to ONNX? There may be a TorchScript-friendly way to achieve the same thing.
Going back to adding the ops to ONNX:
- it's probably best if those are contributed by those who need it, and then refined prior to merging. Otherwise it goes onto a long backlog of things that need to get done. :-)
@lgray thanks for the quick answer. I need to deploy my application in C#, and Microsoft has created a really nice library called onnxruntime for handling pre-trained models that are converted to ONNX for fast inference. What I like about ONNX is that it is agnostic to the training framework.
I am open to alternatives, but the only other thing I can think of at this point is to write a C# wrapper on top of PyTorch's C++ LibTorch library to run the jittable models (which is not that bad of a solution, it just needs some back and forth between C++ and C#).
@chriss2401 There appears to already be a somewhat mature solution for that: https://github.com/xamarin/TorchSharp
@lgray great, I will try to run one of your test jittable models using this. Thanks!
@chriss2401 just to make sure you don't run into problems: you'll have to build the libtorchscatter/cluster/etc. libraries yourself via CMake for those packages. Make sure the script module can load them; I think they just need to be in LD_LIBRARY_PATH and you're good.
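A quick sanity check of that on the Python side might look like the sketch below (the library paths and names are assumptions; use whatever your CMake build produced):

import torch

# Load the custom-op extension libraries before loading the model.
torch.ops.load_library("build/libtorchscatter.so")
torch.ops.load_library("build/libtorchsparse.so")

# The saved TorchScript module can now resolve the custom scatter ops.
model = torch.jit.load("model.pt")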
@lgray Hi, sorry for replying to this old issue, but I am also looking for a way to export PyG models to ONNX, because I need TensorRT support. Based on tests of other network structures, inference on TensorRT is always several times faster than on LibTorch. I do not know whether this result also holds for GNNs. Or, is there any advice for TensorRT support?
Can you tell me more about some of the issues converting PyG models to ONNX? Which ops are not supported yet? We are also working on integrating torch-scatter directly into PyTorch, which might also lead to ONNX support in the long run.
@rusty1s Hi, sorry, but I have not tried any conversion yet. I tried to confirm the ONNX support of PyG and found some discussion about it, but I did not find any clear evidence of ONNX support.
Do you mean that PyG models can now be converted to ONNX? If so, I'll take some time to try it and report the result.
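For context, such an experiment would look roughly like the sketch below; MyConv, the shapes, and the dynamic axes are illustrative assumptions, and whether the resulting graph is translated correctly is exactly what is in question here:

import torch

model = MyConv(16, 32)  # e.g. the example conv from above, unscripted

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])

# Mark node/edge counts as dynamic, since graph sizes vary per sample.
torch.onnx.export(model, (x, edge_index), "model.onnx",
                  input_names=['x', 'edge_index'],
                  dynamic_axes={'x': {0: 'num_nodes'},
                                'edge_index': {1: 'num_edges'}})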
I think this depends on torch.scatter_add being supported by ONNX or not (which should give ONNX support for all PyG operators with add or mean aggregation). I haven't done any experiments with that either, so I'm looking forward to your findings.
@rusty1s Are torch_scatter.scatter and torch.scatter_add the same operation?
I have a problem converting my model, which contains torch_scatter.scatter, but if I change it to torch.scatter_add, the output shape does not match, so I cannot test whether torch.scatter_add supports ONNX conversion.
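For reference, the two are closely related but not drop-in replacements: torch.Tensor.scatter_add_ needs a preallocated output and an index broadcast to the shape of src, which is a plausible source of the shape mismatch. A minimal sketch (the shapes are illustrative assumptions):

import torch
from torch_scatter import scatter

src = torch.randn(5, 8)  # 5 messages with 8 features each
index = torch.tensor([0, 1, 0, 2, 1])
num_nodes = 3

# torch_scatter allocates the output and broadcasts the index internally.
out1 = scatter(src, index, dim=0, dim_size=num_nodes, reduce='sum')

# With torch.scatter_add_, both must be prepared by hand.
out2 = torch.zeros(num_nodes, 8).scatter_add_(
    0, index.view(-1, 1).expand(-1, 8), src)

assert torch.allclose(out1, out2)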
@rusty1s I think I'm stuck on another point, which is mini-batch inference... In my understanding, PyG processes a mini-batch as a "giant graph that holds multiple isolated subgraphs", and supports an arbitrary number of nodes for each sample. Since my final destination is TensorRT, batches in PyG may be treated as single samples with high size variation. This may cause weird problems in TensorRT inference.
Nevertheless, I have successfully converted some simple models, including an example message passing model, to ONNX. It seems the operations are supported by ONNX, at least. Furthermore, I think TensorRT inference is possible if the data is highly structured: for example, if all samples have the same number of nodes and the same connectivity, and only the node features differ, we can organize the data in the traditional mini-batch form.
Don't know if I understand it right.
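For reference, the "giant graph" batching described above is what torch_geometric.data.Batch does; a small illustration with two graphs of different sizes (the numbers are arbitrary):

import torch
from torch_geometric.data import Batch, Data

g1 = Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(5, 16), edge_index=torch.tensor([[0, 2, 3], [1, 3, 4]]))

batch = Batch.from_data_list([g1, g2])
print(batch.num_nodes)  # 8: one block-diagonal "giant graph"
print(batch.batch)      # tensor([0, 0, 0, 1, 1, 1, 1, 1]): node-to-graph map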
I sadly don't have experience with TensorRT. If there is a limitation that mini-batches need to have the same number of nodes and edges, then this is sadly not really well-suited for graph representation learning :(
Yes, that's what I found so far. Don't know if there is any trick in TensorRT, or if LibTorch can handle this. Maybe NVIDIA can give some help :P
Is TensorRT restricted to using a PyTorch DataLoader with drop_last=True?
No, TensorRT does nothing with PyTorch's DataLoader. It only takes the definitions and weights of layers and the shapes of inputs and outputs, then converts them to a supported and optimized format.
@dongb5 check if the prediction is correct. I've been struggling with exporting the GN model to ONNX as well. Although torch.onnx.export runs without an error, the prediction is not correct due to the bad translation of torch.scatter_add to the ONNX graph (https://github.com/pytorch/pytorch/issues/32960) if the indices are not unique. Despite this now being supported by ONNX in opset 16 (https://github.com/onnx/onnx/pull/3484), PyTorch currently supports only opset 14.
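A minimal illustration of the duplicate-index pitfall described above (the values are arbitrary):

import torch

index = torch.tensor([0, 0, 1])  # note the duplicated index 0
src = torch.ones(3)

out = torch.zeros(2).scatter_add_(0, index, src)
print(out)  # tensor([2., 1.]): PyTorch accumulates overlapping writes

# Before opset 16, ONNX ScatterElements had no 'add' reduction, so an
# exported graph could overwrite instead of accumulate and silently
# produce tensor([1., 1.]) for the same inputs.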
@dongb5 @SonyPony Can you please share more details/examples on which GN models from PyG you were able to export to ONNX?
This is an issue to track follow-up on #1256, #1257, #1258, #1259.
We can annotate below with further issue and PR #s as progress is made. Let's use this issue as a forum for discussion about what we want out of the jittable() interface.
During review there were a few functionality issues or requests that came up (box checked if completed, with PR reference):
Here's a list of jittable convolutional ops (box is checked if tested and confirmed):
propagate types