dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

[RFC] Compiling user-defined message/reduce functions with `torch.fx` #3627

Open yzh119 opened 2 years ago

yzh119 commented 2 years ago

Motivation

DGL modules written w/ UDFs (User-Defined Functions) suffer from severe performance issues, even though we already recommend that users write their modules w/ built-in functions. However, built-in functions are less intuitive, and a lot of legacy code is written w/ DGL UDFs.

There are several papers on compiling message-passing UDFs to existing sparse kernels, e.g. Seastar and Graphiler. Seastar was built on the MindSpore ecosystem, and Graphiler was written on top of TorchScript to manipulate the computational graph. However, TorchScript was not designed for transforming IRs, and Graphiler requires users to slightly change the syntax of their UDFs.
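Here is a concrete version of what "compiling a UDF to an existing sparse kernel" means. A message function that scales source features by an edge weight, followed by a sum reduce, computes exactly a sparse-dense matrix product (SpMM). The notation below is ours, not taken from Seastar or Graphiler: $E$ is the edge set, $h_u$ the feature of node $u$, $w_{uv}$ the weight of edge $(u, v)$, and we assume there are no parallel edges.

```latex
\begin{aligned}
m_{uv} &= w_{uv}\, h_u && \text{(message on edge } (u,v)\text{)} \\
h'_v   &= \sum_{u \,:\, (u,v) \in E} m_{uv} && \text{(sum reduce at } v\text{)} \\
A_{vu} &= \begin{cases} w_{uv} & (u,v) \in E \\ 0 & \text{otherwise} \end{cases}
\;\Longrightarrow\; H' = A H
\end{aligned}
```

This holds because $(AH)_v = \sum_u A_{vu} h_u = \sum_{u:(u,v)\in E} w_{uv} h_u = h'_v$. A compiler that recognizes the UDF pair as this pattern can call an SpMM kernel instead of materializing every per-edge message.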

Recently, torch.fx has attracted a lot of attention because of its ability to symbolically trace and transform nn modules written in PyTorch. There is existing work on quantization and kernel fusion w/ torch.fx. More examples are available here.
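The sketch below is a rough illustration of what symbolic tracing provides. It traces a hypothetical UDF-style layer written in plain PyTorch, not DGL's actual UDF API. The message step gathers source features and scales them by an edge weight, and the reduce step sums them into destination nodes. `UDFStyleConv` and its `(h, src, dst, w)` edge-list inputs are assumptions made for this example.

```python
import torch
import torch.fx as fx


class UDFStyleConv(torch.nn.Module):
    """Hypothetical UDF-style layer over an edge list (src[i] -> dst[i])."""

    def forward(self, h, src, dst, w):
        # Message step: gather source-node features and scale by edge weight.
        msg = h[src] * w.unsqueeze(-1)
        # Reduce step: sum incoming messages into each destination node.
        return torch.zeros_like(h).index_add(0, dst, msg)


# symbolic_trace runs forward() on Proxy objects, so no real input is needed.
gm = fx.symbolic_trace(UDFStyleConv())
for node in gm.graph.nodes:
    print(node.op, node.target)

# The traced GraphModule is still an ordinary module that runs on real tensors.
h = torch.randn(4, 3)
src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 3])
w = torch.rand(4)
print(torch.allclose(gm(h, src, dst, w), UDFStyleConv()(h, src, dst, w)))
```

Each printed line is one IR node. The output lists placeholders for the inputs, then `call_function`/`call_method` nodes for the indexing, the multiply and `index_add`, and finally the output. A compilation pass would look for message/reduce patterns at this level.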

Pitch

I think torch.fx is a great fit for implementing Graphiler-style UDF compilation in DGL, for the following reasons:

  1. Easy to use: most modules are exposed on the Python side and well documented (see GraphModule and Transformer), so we can easily manipulate the IR in pure Python.
  2. torch.fx supports symbolic transformation without knowing the input: we just need to subclass torch.fx.Transformer to perform the message-passing UDF (mp-UDF) compilation. The other mode, Interpreter, requires input tensors, which we can ignore for now.
  3. Highly configurable: to support torch NN modules that take DGLGraphs as input, a feasible solution is to customize the tracer. We can also enable graph-aware tracing to unlock more optimizations (e.g. AOT kernel fusion and graph-aware kernel compilation w/ TVM TensorIR).
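Here is a minimal sketch of item 2. A `torch.fx.Transformer` subclass walks a traced graph symbolically and redirects the reduce step to a different implementation. `UDFStyleConv` is a plain-PyTorch stand-in for a UDF-based layer. `scatter_sum_kernel` stands in for a real sparse kernel and is implemented with `scatter_add` only so the code runs. Both are hypothetical names introduced for illustration. A real pass would match the whole message/reduce pattern rather than a single `index_add` call.

```python
import torch
import torch.fx as fx


class UDFStyleConv(torch.nn.Module):
    """Hypothetical UDF-style layer over an edge list (src[i] -> dst[i])."""

    def forward(self, h, src, dst, w):
        msg = h[src] * w.unsqueeze(-1)  # message step
        return torch.zeros_like(h).index_add(0, dst, msg)  # sum reduce


def scatter_sum_kernel(out, dim, index, src):
    """Hypothetical stand-in for a dedicated sparse reduce kernel.

    Uses scatter_add only so the sketch runs; assumes dim == 0 and 2-D
    tensors, which is all this example produces.
    """
    assert dim == 0 and src.dim() == 2
    return out.scatter_add(0, index.unsqueeze(-1).expand_as(src), src)


class ReduceToKernel(fx.Transformer):
    """Rewrites every `.index_add(...)` call into a call to scatter_sum_kernel."""

    def call_method(self, target, args, kwargs):
        if target == "index_add":
            # Emit a call_function node that targets the kernel instead.
            return super().call_function(scatter_sum_kernel, args, kwargs)
        return super().call_method(target, args, kwargs)


traced = fx.symbolic_trace(UDFStyleConv())
# transform() runs symbolically: each node sees Proxy values, not real tensors.
compiled = ReduceToKernel(traced).transform()
print(compiled.code)

h = torch.randn(4, 3)
src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 3])
w = torch.rand(4)
print(torch.allclose(compiled(h, src, dst, w), traced(h, src, dst, w)))
```

Overriding `call_method` is enough here because `index_add` is traced as a method call. `transform()` never sees real tensors, which is the property item 2 relies on.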

I'm working on a prototype PR showing how to customize a tracer that recognizes DGLGraph and compiles a simple GCN module written w/ UDFs.
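The tracer customization can be sketched with `torch.fx.Tracer.is_leaf_module`. By default, FX inlines user-defined submodules. A custom tracer can instead keep a graph-aware operation as a single opaque `call_module` node for a later pass to lower. `SpMMAggregate`, `GCNLayer` and `GraphAwareTracer` are hypothetical. The sketch also passes an edge list rather than a DGLGraph, because recognizing a DGLGraph argument would need further tracer work that is out of scope here.

```python
import torch
import torch.fx as fx


class SpMMAggregate(torch.nn.Module):
    """Hypothetical graph-aware op: sum source features into destination nodes."""

    def forward(self, h, src, dst):
        return torch.zeros_like(h).index_add(0, dst, h[src])


class GCNLayer(torch.nn.Module):
    """Hypothetical GCN-like layer: aggregate, then a linear projection."""

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.aggregate = SpMMAggregate()
        self.linear = torch.nn.Linear(in_feats, out_feats)

    def forward(self, h, src, dst):
        return torch.relu(self.linear(self.aggregate(h, src, dst)))


class GraphAwareTracer(fx.Tracer):
    """Keeps SpMMAggregate opaque; defers to the default rule otherwise."""

    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, SpMMAggregate):
            return True
        return super().is_leaf_module(m, module_qualified_name)


model = GCNLayer(4, 2)
graph = GraphAwareTracer().trace(model)
gm = fx.GraphModule(model, graph)
call_modules = [n.target for n in gm.graph.nodes if n.op == "call_module"]
print(call_modules)

h = torch.randn(4, 4)
src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 0])
print(torch.allclose(gm(h, src, dst), model(h, src, dst)))
```

With plain `fx.symbolic_trace`, the aggregation would be inlined into gather/`index_add` nodes, and only `linear` would remain as a `call_module`.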

Discussion is welcome @VoVAllen @BarclayII @Rhett-Ying @jermainewang @xiezhq-hermann .

VoVAllen commented 2 years ago

Why is torch.fx preferable to TorchScript in the UDF scenario?

yzh119 commented 2 years ago

I'm not sure whether TorchScript supports customizing the tracer. @VoVAllen, I remember you have done some work on JIT-compiling DGL models?

yzh119 commented 2 years ago

btw, I believe PyTorch Geometric is using torch.fx to handle heterographs:

They can easily extend a model defined for homogeneous graphs to heterogeneous graphs with a single to_hetero API call, which is elegant IMO.

jermainewang commented 2 years ago

TorchScript does support custom data types, so theoretically it can compile a program involving DGLGraphs. I think FX and TorchScript are designed for different purposes. FX is a source-to-source transpiler, so it suits scenarios where relying on the Python runtime is fine. What you've listed are good examples. TorchScript can lower a program into TorchScript IR to bypass the Python runtime completely, which is suitable when ultimate efficiency is needed. I think DGL should be compatible with both, but the TorchScript route seems to need more effort than torch.fx at the moment.
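A small module makes this distinction visible. `GraphModule.code` is plain Python source that still runs under the Python interpreter. `torch.jit.script` lowers the same traced module to TorchScript IR. `UDFStyleConv` is a hypothetical plain-PyTorch stand-in for a UDF-based layer. This is a sketch, not a statement about how DGL would integrate either route.

```python
import torch
import torch.fx as fx


class UDFStyleConv(torch.nn.Module):
    """Hypothetical UDF-style layer over an edge list (src[i] -> dst[i])."""

    def forward(self, h, src, dst, w):
        msg = h[src] * w.unsqueeze(-1)  # message step
        return torch.zeros_like(h).index_add(0, dst, msg)  # sum reduce


gm = fx.symbolic_trace(UDFStyleConv())
# FX route: the result is ordinary Python source, executed by the Python runtime.
print(gm.code)

# TorchScript route: scripting the same GraphModule lowers it to TorchScript IR,
# which can later run without the Python interpreter (e.g. after torch.jit.save).
scripted = torch.jit.script(gm)
print(scripted.graph)

h = torch.randn(4, 3)
src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 3])
w = torch.rand(4)
print(torch.allclose(scripted(h, src, dst, w), gm(h, src, dst, w)))
```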

For optimizing UDFs, I agree that the current bottleneck can mostly be solved via source code translation, so FX is definitely interesting here.

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you