pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

MessagePassing.jittable() remaining features and planning. #1283

lgray opened this issue 4 years ago

lgray commented 4 years ago

This is an issue to track follow-up on #1256, #1257, #1258, #1259.

We can annotate below with further issue and PR numbers as progress is made. Let's use this issue as a forum for discussion about what we want out of the jittable() interface.

During review, a few functionality issues and requests came up (box checked if completed, with PR reference):

Here's a list of jittable convolutional ops (box is checked if tested and confirmed):

lgray commented 4 years ago

@liaopeiyuan @pierthodo

lgray commented 4 years ago

@rusty1s did you already change "return a jittable copy of the conv layer instead of the new class" in one of the two PRs submitted so far? I didn't catch it if you did; could you point to a PR or commit reference?

rusty1s commented 4 years ago

See https://github.com/rusty1s/pytorch_geometric/commit/dc87faa6a9ae1b71dbbd6d34f834dc7053803fe6

rusty1s commented 4 years ago

Hi everyone, I worked on a follow-up PR for the JIT interface in https://github.com/rusty1s/pytorch_geometric/pull/1309 (ready to merge), with the following features:

All tests pass, but bipartite graph support/SparseTensors are only added for a couple of convs for now. I would like to merge this PR first before continuing to work on it.

Here is a basic example:

from typing import Union

from torch import Tensor
from torch.nn import Linear
from torch_sparse import SparseTensor, matmul

from torch_geometric.nn import MessagePassing


class MyConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        """"""
        super(MyConv, self).__init__(aggr='add')

        self.lin_l = Linear(in_channels, out_channels)
        self.lin_r = Linear(in_channels, out_channels)

    def forward(self, x: Tensor,
                edge_index: Union[Tensor, SparseTensor]) -> Tensor:
        """"""
        # propagate_type: (x: Tensor)
        out = self.propagate(edge_index, x=x, size=None)
        return self.lin_l(out) + self.lin_r(x)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        # Fused message + aggregation for the SparseTensor code path.
        return matmul(adj_t, x, reduce=self.aggr)
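
As a quick usage sketch (assuming the jittable() interface from the PRs above; shapes are made up), converting and scripting the layer looks like this:

import torch

conv = MyConv(in_channels=16, out_channels=32)
conv = conv.jittable()         # returns a jittable copy of the layer
conv = torch.jit.script(conv)  # TorchScript compilation

x = torch.randn(10, 16)                     # 10 nodes, 16 features each
edge_index = torch.randint(0, 10, (2, 30))  # 30 random edges
out = conv(x, edge_index)                   # shape [10, 32]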

And here is a more complex one that supports bipartite graphs:

from typing import Tuple, Union

from torch import Tensor
from torch.nn import Linear
from torch_sparse import SparseTensor, matmul

from torch_geometric.nn import MessagePassing
from torch_geometric.typing import OptTensor, Size


class MyConv(MessagePassing):
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int):
        """"""
        super(MyConv, self).__init__(aggr='add')

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels)
        self.lin_r = Linear(in_channels[1], out_channels)

    def forward(self,
                x: Union[Tensor, Tuple[Tensor, OptTensor]],
                edge_index: Union[Tensor, SparseTensor],
                edge_weight: OptTensor = None,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: Tuple[Tensor, OptTensor] = (x, x)

        # propagate_type: (x: Tuple[Tensor, OptTensor], edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=size)
        out = self.lin_l(out)

        x_r = x[1]
        if x_r is not None:
            out += self.lin_r(x_r)

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else x_j * edge_weight.view(-1, 1)

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: Tuple[Tensor, OptTensor]) -> Tensor:
        return matmul(adj_t, x[0], reduce=self.aggr)
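
A hypothetical usage sketch for the bipartite case (node counts and feature sizes are made up): 8 source nodes with 16 features, 4 destination nodes with 24 features, and the output lives on the destination nodes.

import torch

conv = MyConv(in_channels=(16, 24), out_channels=32)

x_src, x_dst = torch.randn(8, 16), torch.randn(4, 24)
edge_index = torch.stack([torch.randint(0, 8, (20,)),   # source node ids
                          torch.randint(0, 4, (20,))])  # destination node ids
out = conv((x_src, x_dst), edge_index, size=(8, 4))     # shape [4, 32]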

Let me know what you think!

lgray commented 4 years ago

Hey - this is really cool, but I am not a big fan of parsing code comments to drive functionality. Too much room for error, especially for users who might be new. Though given the constraints we have, it's a solid solution.

Maybe a better way to do this would be to specify a __propagate_signatures__ member variable containing a list of dicts of type hints.

class MyConv(MessagePassing):
    def __init__(self, *args):
        ...
        self.__propagate_signatures__ = [
            {'x': Tensor, 'edge_index': Tensor, ...},  # one dict per overload
        ]

This way it naturally grows when we need to overload things, it is properly part of the Python code, and it still gets rid of the tracing entirely.

What do you think?

rusty1s commented 4 years ago

I think so, too. One reason I opted for the current approach is that Python has a very similar type-comment interface (# type: (...) -> ...). Maybe I can implement both.
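
For reference, here is a sketch of what that older type-comment style looks like on a forward method (a minimal illustration, not the final PyG syntax):

def forward(self, x, edge_index):
    # type: (Tensor, Tensor) -> Tensor
    return self.propagate(edge_index, x=x, size=None)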

lgray commented 4 years ago

Ah, I thought they had deprecated that in favor of the newer system!

Let's go for both and see which one ends up feeling better in practice?

rusty1s commented 4 years ago

Done :)

chriss2401 commented 4 years ago

Hi @rusty1s and @lgray ,

This new JIT functionality is great! Do you also plan ONNX support? As far as I have seen, there are issues converting a jittable model to ONNX (https://github.com/pytorch/pytorch/issues/34002), and ideally I'd like to export a pytorch_geometric model to ONNX.

lgray commented 4 years ago

@chriss2401 ONNX is (quite) a bit less flexible than TorchScript, and would similarly need additional C++ bindings written to get all the ops into ONNX. I am not sure of PyTorch's long-term plans to keep supporting ONNX, nor to what degree they'd cover ONNX features. Similarly, TF has also been moving to its own serialization/JIT format recently.

Another way to approach this: what are you trying to achieve with the conversion to ONNX? There may be a TorchScript-friendly way to achieve the same thing.

Going back to adding the ops to ONNX:

  • it's probably best if those are contributed by those who need it and then they are refined prior to merging. Otherwise it goes onto a long backlog of things that need to get done. :-)

chriss2401 commented 4 years ago

@lgray thanks for the quick answer. I need to deploy my application in C#, and Microsoft has created a really nice library called onnxruntime for handling pre-trained models that are converted to ONNX for fast inference. What I like about ONNX is that it is agnostic to the training framework.

I am open to alternatives, but the only other thing I can think of at this point is to write a C# wrapper on top of PyTorch's C++ LibTorch library to run the jittable models (which is not that bad of a solution, it just needs some back and forth between C++ and C#).

lgray commented 4 years ago

@chriss2401 There appears to already be a somewhat mature solution for that: https://github.com/xamarin/TorchSharp

chriss2401 commented 4 years ago

@lgray great, I will try to run one of your test jittable models using this. Thanks!

lgray commented 4 years ago

@chriss2401 just to make sure you don't run into problems: you'll have to build the libtorchscatter/cluster/etc. libraries yourself via CMake for those packages. Make sure the script module can load them; I think they just need to be on LD_LIBRARY_PATH and you're good.
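
(A sketch of the loading step in Python; the paths are placeholders for wherever your CMake build puts the libraries:)

import torch

# Load the custom op libraries before loading the scripted model.
torch.ops.load_library('/path/to/libtorchscatter.so')
torch.ops.load_library('/path/to/libtorchsparse.so')

model = torch.jit.load('model.pt')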

dongb5 commented 3 years ago

@chriss2401 ONNX is (quite) a bit less flexible than TorchScript, and would similarly need additional C++ bindings written to get all the ops into ONNX. I am not sure of PyTorch's long-term plans to keep supporting ONNX, nor to what degree they'd cover ONNX features. Similarly, TF has also been moving to its own serialization/JIT format recently.

Another way to approach this: what are you trying to achieve with the conversion to ONNX? There may be a TorchScript-friendly way to achieve the same thing.

Going back to adding the ops to ONNX:

  • it's probably best if those are contributed by those who need it and then they are refined prior to merging. Otherwise it goes onto a long backlog of things that need to get done. :-)

@lgray Hi, sorry for replying to this old issue, but I am also looking for a way to export PyG models to ONNX, because I need TensorRT support. Based on tests with other network structures, inference with TensorRT is always a few times faster than with LibTorch. I do not know whether this result also holds for GNNs. Or is there any advice for TensorRT support?

rusty1s commented 3 years ago

Can you tell me more about some of the issues converting PyG models to ONNX? Which ops are not supported yet? We are also working on integrating torch-scatter directly to PyTorch, which might also lead to ONNX support in the long run.

dongb5 commented 3 years ago

Can you tell me more about some of the issues converting PyG models to ONNX? Which ops are not supported yet? We are also working on integrating torch-scatter directly to PyTorch, which might also lead to ONNX support in the long run.

@rusty1s Hi, sorry, but I have not tried any conversion yet. I tried to confirm PyG's ONNX support and found some discussion about it, but I did not find any clear conclusion.

Do you mean that PyG models can now be converted to ONNX? If so, I'll take some time to try it and report the results.

rusty1s commented 3 years ago

I think this depends on torch.scatter_add being supported by ONNX or not (which should give ONNX support for all PyG operators with add or mean aggregation). I haven't done any experiments with that either, so I'm looking forward to your findings.
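
For anyone probing this, here is a minimal sketch (shapes are made up) of what PyG's add aggregation boils down to in plain torch.scatter_add terms; this is the op that needs a faithful ONNX translation:

import torch

src = torch.randn(10, 8)            # messages, one row per edge
index = torch.randint(0, 4, (10,))  # target node of each edge
dim_size = 4                        # number of target nodes

out = torch.zeros(dim_size, src.size(1))
out = out.scatter_add(0, index.view(-1, 1).expand_as(src), src)
# out[i] now holds the sum of all messages sent to node i.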

EJShim commented 2 years ago

@rusty1s

Are torch_scatter.scatter and torch.scatter_add the same operation?

I have a problem converting my model, which contains torch_scatter.scatter. If I change it to torch.scatter_add, the output shape does not match, so I cannot test whether torch.scatter_add supports ONNX conversion.

dongb5 commented 2 years ago

I think this depends on torch.scatter_add being supported by ONNX or not (which should give ONNX support for all PyG operators with add or mean aggregation). I haven't done any experiments with that either, so I'm looking forward to your findings.

@rusty1s I think I'm stuck on another point, which is mini-batch inference... In my understanding, PyG processes a mini-batch as a "giant graph that holds multiple isolated subgraphs", and supports an arbitrary number of nodes for each sample. Since my final destination is TensorRT, batches in PyG may be treated as single samples with high size variation. This may cause weird problems in TensorRT inference.

Nevertheless, I have successfully converted some simple models, including an example message passing model, to ONNX. It seems the operations are supported by ONNX, at least. Furthermore, I think TensorRT inference is possible if the data is highly structured: for example, if all samples have the same number of nodes, the connections between nodes are the same, and only the node features differ, then we can organize the data in traditional mini-batch form.
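
For reference, a sketch of that giant-graph batching (a made-up example using PyG's Batch.from_data_list):

import torch
from torch_geometric.data import Batch, Data

# Two graphs with different node counts...
g1 = Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(5, 16), edge_index=torch.tensor([[0, 2, 3], [1, 3, 4]]))

# ...merged into one block-diagonal "giant" graph with shifted edge indices.
batch = Batch.from_data_list([g1, g2])
print(batch.num_nodes)  # 8
print(batch.batch)      # tensor([0, 0, 0, 1, 1, 1, 1, 1]), node -> sample map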

Don't know if I understand it right.

rusty1s commented 2 years ago

I sadly don't have experience with TensorRT. If there is a limitation that mini-batches need to have the same number of nodes and edges, then this is sadly not really well-suited for graph representation learning :(

dongb5 commented 2 years ago

I sadly don't have experience with TensorRT. If there is a limitation that mini-batches need to have the same number of nodes and edges, then this is sadly not really well-suited for graph representation learning :(

Yes, that's what I found so far. I don't know if there is any trick in TensorRT, or if LibTorch can handle this. Maybe NVIDIA can give some help :P

rusty1s commented 2 years ago

Is TensorRT restricted to use a PyTorch DataLoader with drop_last=True?

dongb5 commented 2 years ago

Is TensorRT restricted to use a PyTorch DataLoader with drop_last=True?

No, TensorRT does nothing with PyTorch's DataLoader. It only takes the definitions and weights of layers and the shapes of inputs and outputs, then converts them to a supported and optimized format.

SonyPony commented 2 years ago

I think this depends on torch.scatter_add being supported by ONNX or not (which should give ONNX support for all PyG operators with add or mean aggregation). I haven't done any experiments with that either, so I'm looking forward to your findings.

@rusty1s I think I'm stuck on another point, which is mini-batch inference... In my understanding, PyG processes a mini-batch as a "giant graph that holds multiple isolated subgraphs", and supports an arbitrary number of nodes for each sample. Since my final destination is TensorRT, batches in PyG may be treated as single samples with high size variation. This may cause weird problems in TensorRT inference.

Nevertheless, I have successfully converted some simple models, including an example message passing model, to ONNX. It seems the operations are supported by ONNX, at least. Furthermore, I think TensorRT inference is possible if the data is highly structured: for example, if all samples have the same number of nodes, the connections between nodes are the same, and only the node features differ, then we can organize the data in traditional mini-batch form.

Don't know if I understand it right.

@dongb5 check if the prediction is correct. I've been struggling with exporting the GN model to ONNX as well. Although torch.onnx.export runs without an error, the prediction is not correct due to the bad translation of torch.scatter_add to the ONNX graph (https://github.com/pytorch/pytorch/issues/32960) when the indices are not unique. Although scatter_add is now supported by ONNX in opset 16 (https://github.com/onnx/onnx/pull/3484), PyTorch currently supports only up to opset 14.
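
(For anyone reproducing this, a sketch of the check; model, x and edge_index are placeholders for your own export, and the input names are made up:)

import numpy as np
import onnxruntime as ort
import torch

# Export, then compare PyTorch and onnxruntime outputs to expose the silent
# scatter_add mistranslation.
torch.onnx.export(model, (x, edge_index), 'model.onnx',
                  input_names=['x', 'edge_index'], opset_version=14)

sess = ort.InferenceSession('model.onnx')
ref = model(x, edge_index).detach().numpy()
out = sess.run(None, {'x': x.numpy(), 'edge_index': edge_index.numpy()})[0]
print(np.allclose(ref, out, atol=1e-5))  # False indicates a bad translation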

ramkrishna1121 commented 2 years ago

@dongb5 @SonyPony Can you please share more details/examples of which GN models from PyG you were able to export to ONNX?