dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

JAX backend? #2260

Open yuanqing-wang opened 3 years ago

yuanqing-wang commented 3 years ago

🚀 Feature

Is a JAX backend on your roadmap? Given JAX's increasing popularity and its XLA features, I think it would be very interesting to include it.

yzh119 commented 3 years ago

We are not sure whether JAX supports DLPack; if it does, supporting JAX would be easy.

As for XLA features, since sparse matrices are not well supported by XLA, I'm afraid we would not benefit much from it.

yuanqing-wang commented 3 years ago

I think JAX does support dlpack.

https://jax.readthedocs.io/en/latest/jax.dlpack.html
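For reference, a minimal round trip between JAX and PyTorch through DLPack would look roughly like this (a sketch based on the linked docs; note that jax.dlpack.to_dlpack takes ownership of the buffer):

```python
import jax.dlpack
import jax.numpy as jnp
import torch.utils.dlpack

x = jnp.arange(4, dtype=jnp.float32)
# to_dlpack consumes the JAX buffer, so `x` must not be used afterwards.
t = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
# ...and back: wrap a torch tensor as a JAX array without copying.
y = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
```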

I think there's another challenge: the differentiation operation is based on functions rather than tensors, so APIs like backward() and tensor.grad would be quite different.
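For example (a toy sketch with a dense layer standing in for a GNN): where PyTorch would call loss.backward() and read tensor.grad, JAX differentiates the loss function itself:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    pred = x @ w                      # forward pass
    return jnp.mean((pred - y) ** 2)  # scalar loss

w = jnp.ones((4, 2))
x = jnp.ones((8, 4))
y = jnp.zeros((8, 2))

# gradient w.r.t. the first argument; no .backward() or .grad attribute involved
dw = jax.grad(loss_fn)(w, x, y)
```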

yzh119 commented 3 years ago

Yes, you are right. JAX's autograd is based on source-code transformation (SCT) rather than tracing: a function is translated into a jaxpr and then transformed into its gradient function. It's not trivial to handle complex data structures such as DGLGraphs in jaxpr; a more feasible approach is to write GNN programs in pure tensor style, such as dgl.ops, where graphs are replaced by their corresponding COO/CSR representations.

JAX supports custom_vjp, which is quite similar to PyTorch's autograd.Function, so supporting message-passing operators in JAX should not be a burden.
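For illustration, a minimal sketch (not DGL code) of a copy-u-sum message-passing op with a hand-written gradient via jax.custom_vjp, assuming COO edge arrays src/dst and using segment_sum as a stand-in for a fused kernel:

```python
import jax
import jax.numpy as jnp

def make_copy_u_sum(src, dst, num_nodes):
    @jax.custom_vjp
    def copy_u_sum(x):
        # out[v] = sum of x[u] over all edges (u -> v)
        return jax.ops.segment_sum(x[src], dst, num_segments=num_nodes)

    def fwd(x):
        return copy_u_sum(x), None    # no residuals needed for this rule

    def bwd(_, g):
        # grad w.r.t. x scatters the upstream gradient back along the edges
        return (jax.ops.segment_sum(g[dst], src, num_segments=num_nodes),)

    copy_u_sum.defvjp(fwd, bwd)
    return copy_u_sum

src, dst = jnp.array([0, 0, 1]), jnp.array([1, 2, 2])
f = make_copy_u_sum(src, dst, num_nodes=3)
out, vjp_fn = jax.vjp(f, jnp.ones((3, 4)))
```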

yuanqing-wang commented 3 years ago

Ah ok I see!

Sounds like it's doable despite some challenges!

I would actually be very interested in seeing a JAX backend in DGL, since it would be tremendously useful in the biophysical modeling community where I come from. JAX has been shown to support gradient computation through long, unrolled Langevin dynamics simulations, and a graph-representation layer would enable many interesting applications, from parameter fitting to enhanced sampling.

Anyway, I started to play with this idea on a personal branch (https://github.com/yuanqing-wang/dgl). I'll see how far I can go.

VoVAllen commented 3 years ago

Thanks for your interest. If you run into any problems, please feel free to raise issues and we would be glad to help. I think we can keep this issue open for tracking, in case other people are interested in the same question.

yuanqing-wang commented 3 years ago

Thanks! @VoVAllen

The first blocking issue I had is that zerocopy_to_dlpack sometimes complains about the buffer being deleted.

This behavior is not observed when I make a copy of the jax.DeviceArray whenever it is converted to DLPack, but of course that is no longer zero-copy.

yuanqing-wang commented 3 years ago

I'm still trying to chase down the source of this error. Can I assume that we never call zerocopy_to_dlpack on the same tensor twice? Somehow doing so with torch does not cause errors.
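For reference, this matches the ownership semantics documented for jax.dlpack.to_dlpack: it consumes the buffer, so exporting (or otherwise touching) the same DeviceArray twice fails. A minimal repro consistent with the traceback below:

```python
import jax.dlpack
import jax.numpy as jnp

x = jnp.zeros(4)
cap = jax.dlpack.to_dlpack(x)  # takes ownership; x's buffer is now deleted
# A second export (or any further use of x) raises, e.g.
#   RuntimeError: Invalid argument: Cannot convert deleted/invalid buffer to DLPack tensor.
# jax.dlpack.to_dlpack(x)
```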

yuanqing-wang commented 3 years ago

_______________________________________________ test_batch_setter_getter[int32] ________________________________________________

idtype = <class 'jax.numpy.lax_numpy.int32'>

    @parametrize_dtype
    def test_batch_setter_getter(idtype):
        def _pfc(x):
            return list(F.zerocopy_to_numpy(x)[:,0])
        g = generate_graph(idtype)
        # set all nodes
        g.ndata['h'] = F.zeros((10, D))
        assert F.allclose(g.ndata['h'], F.zeros((10, D)))
        # pop nodes
        old_len = len(g.ndata)
        g.ndata.pop('h')
        assert len(g.ndata) == old_len - 1
        g.ndata['h'] = F.zeros((10, D))
        # set partial nodes
        u = F.tensor([1, 3, 5], g.idtype)
        g.nodes[u].data['h'] = F.ones((3, D))
        assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
        # get partial nodes
        u = F.tensor([1, 2, 3], g.idtype)
        assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]

        '''
        s, d, eid
        0, 1, 0
        1, 9, 1
        0, 2, 2
        2, 9, 3
        0, 3, 4
        3, 9, 5
        0, 4, 6
        4, 9, 7
        0, 5, 8
        5, 9, 9
        0, 6, 10
        6, 9, 11
        0, 7, 12
        7, 9, 13
        0, 8, 14
        8, 9, 15
        9, 0, 16
        '''
        # set all edges
        g.edata['l'] = F.zeros((17, D))
        assert _pfc(g.edata['l']) == [0.] * 17
        # pop edges
        old_len = len(g.edata)
        g.edata.pop('l')
        assert len(g.edata) == old_len - 1
        g.edata['l'] = F.zeros((17, D))
        # set partial edges (many-many)
        u = F.tensor([0, 0, 2, 5, 9], g.idtype)
        v = F.tensor([1, 3, 9, 9, 0], g.idtype)
>       g.edges[u, v].data['l'] = F.ones((5, D))

compute/test_basics.py:144: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../python/dgl/view.py:197: in __setitem__
    self._graph._set_e_repr(self._etid, self._edges, {key : val})
../python/dgl/heterograph.py:3886: in _set_e_repr
    eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
../python/dgl/utils/checks.py:97: in parse_edges_arg_to_eid
    eid = g.edge_ids(u, v, etype=g.canonical_etypes[etid])
../python/dgl/heterograph.py:2830: in edge_ids
    eid = self._graph.edge_ids_one(self.get_etype_id(etype), u, v)
../python/dgl/heterograph_index.py:440: in edge_ids_one
    self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v)))
../python/dgl/backend/__init__.py:92: in to_dgl_nd
    return zerocopy_to_dgl_ndarray(data)
../python/dgl/backend/jax/tensor.py:429: in zerocopy_to_dgl_ndarray
    return nd.from_dlpack(jax.dlpack.to_dlpack(data))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

x = DeviceArray(<jaxlib.xla_extension.Buffer object at 0x7ffe7606f1f0>, dtype=int32)

    def to_dlpack(x: xla.DeviceArray):
      """Returns a DLPack tensor that encapsulates a DeviceArray `x`.

      Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
      state.

      Args:
        x: a `DeviceArray`, on either CPU or GPU.
      """
      if not isinstance(x, xla.DeviceArray):
        raise TypeError("Argument to to_dlpack must be a DeviceArray, got {}"
                        .format(type(x)))
      buf = xla._force(x).device_buffer
>     return xla_client._xla.buffer_to_dlpack_managed_tensor(buf)
E     RuntimeError: Invalid argument: Cannot convert deleted/invalid buffer to DLPack tensor.

../../../../anaconda3/envs/pinot/lib/python3.7/site-packages/jax/dlpack.py:34: RuntimeError
======================================================= warnings summary =======================================================
compute/test_basics.py::test_compatible
  /Users/wangy1/Documents/GitHub/dgl/python/dgl/base.py:45: DGLWarning: Recommend creating graphs by `dgl.graph(data)` instead of `dgl.DGLGraph(data)`.
    return warnings.warn(message, category=category, stacklevel=1)

compute/test_basics.py::test_compatible
  /Users/wangy1/anaconda3/envs/pinot/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
    warnings.warn('No GPU/TPU found, falling back to CPU.')

compute/test_basics.py::test_compatible
  /Users/wangy1/Documents/GitHub/dgl/python/dgl/base.py:45: DGLWarning: DGLGraph.add_edge is deprecated. Please use DGLGraph.add_edges
    return warnings.warn(message, category=category, stacklevel=1)

-- Docs: https://docs.pytest.org/en/latest/warnings.html
=================================================== short test summary info ====================================================
FAILED compute/test_basics.py::test_batch_setter_getter[int32] - RuntimeError: Invalid argument: Cannot convert deleted/inval...
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
=========================================== 1 failed, 1 passed, 3 warnings in 2.86s ============================================

Attaching the traceback.

yuanqing-wang commented 3 years ago

Relevant bits here:

https://github.com/yuanqing-wang/dgl/blob/7037966db04487933a25dbcca9fa72305115f9f4/python/dgl/backend/jax/tensor.py#L423

I also submitted a PR, although it is still at a very early WIP stage: #2281

VoVAllen commented 3 years ago

This seems weird and is probably related to JAX's runtime. We may need some time to investigate JAX and this issue. Thanks again for your contribution!

mattjj commented 3 years ago

Hey, JAX developer here. I was skimming this thread (coming from google/jax#4636) and just wanted to amend something:

Yes, you are right. JAX's autograd is based on source-code transformation (SCT) rather than tracing: a function is translated into a jaxpr and then transformed into its gradient function.

Actually, JAX's autodiff is all based on tracing. JAX grew out of the original Autograd, which popularized trace-based autodiff in Python as well as the term "autograd" itself. It's still based on the same design principles, just generalized.

In the context of autodiff, a jaxpr is more or less the same as what other systems refer to as a "dynamic computation graph", i.e. it's the data structure built up by tracing the forward pass, and it's consumed during the backward pass of reverse-mode autodiff.

Thanks for the bug report about DLPack!

yzh119 commented 3 years ago

@mattjj thanks for pointing that out; it seems I misunderstood the mechanism of JAX's autograd.

I was expecting jaxpr to be an IR at the same level as Relay and MLIR. After going through your cookbook, I found that control flow defined in Python is inlined during tracing; users must explicitly use the structured control-flow primitives in lax (cond/while_loop/scan) to have them appear at the jaxpr level.
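For example, inspecting the traces with jax.make_jaxpr shows the difference (a small sketch): a Python loop is unrolled into repeated equations, while lax.scan stays a single structured primitive.

```python
import jax
import jax.numpy as jnp
from jax import lax

def unrolled(x):
    for _ in range(3):        # Python control flow: inlined during tracing
        x = x * 2.0
    return x

def structured(x):
    def body(carry, _):
        return carry * 2.0, None
    out, _ = lax.scan(body, x, xs=None, length=3)  # stays a `scan` in the jaxpr
    return out

print(jax.make_jaxpr(unrolled)(jnp.array(1.0)))    # three mul equations
print(jax.make_jaxpr(structured)(jnp.array(1.0)))  # one scan equation
```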

I wonder, do you have plans to support translating general Python functions to jaxpr? I know trace-based autodiff might be a deliberate design choice, but it also means the training engine depends heavily on the Python runtime (which is reasonable, though).

mattjj commented 3 years ago

thanks for pointing that out; it seems I misunderstood the mechanism of JAX's autograd

That's pretty understandable since we never document anything about JAX :P

I wonder, do you have plans to support translating general Python functions to jaxpr?

It's a good question, but we don't plan to add support for staging out general Python functions. Limitations of our simple tracing mechanism are just one side of the coin, but there's also the fact that our staged-out computations (namely XLA HLO programs) can't efficiently represent arbitrary Python code (e.g. Python's complex control flow would have to be mapped to HLO's structured control flow). We're not trying to make arbitrary Python compilable; instead we're optimizing for explicitness and predictability, which is why in lax_control_flow.py we surface structured control flow primitives that closely follow those in XLA HLO. By providing a relatively explicit embedding of XLA HLO in Python, without much magic, expert users can get what they want with few surprises, and maintain a decent mental model of what's efficient. Moreover, folks who want to build more automatic systems (e.g. consuming Python ASTs or something) could in principle do it on top of JAX.

WDYT?

yuanqing-wang commented 3 years ago

@VoVAllen @yzh119 In #2281, I think I managed to pass the majority of the test cases (skipping the ones that call .backward(), since the API is evidently different in JAX) and provided a minimal example of NN backpropagation here: https://github.com/yuanqing-wang/dgl/blob/b79dc0c36eb5ad79979ea7edfa6f382025115206/tests/jax/test_simple_graph_conv.py#L3

yuanqing-wang commented 3 years ago

I was thinking that down the road we might be able to jit the entire thing, so I didn't use the internal SPMM and SDDMM but hand-wrote them here: https://github.com/yuanqing-wang/dgl/blob/b79dc0c36eb5ad79979ea7edfa6f382025115206/python/dgl/backend/jax/sparse.py#L131, which is a lot slower as of now.
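For context, the appeal is that a pure-jnp aggregation composes directly with jax.jit, so XLA can fuse it with the surrounding dense ops. A toy sketch (not the code in the branch):

```python
import jax
import jax.numpy as jnp

src = jnp.array([0, 0, 1, 2])
dst = jnp.array([1, 2, 2, 0])

@jax.jit
def gcn_layer(x, w):
    # copy-u-sum aggregation written with plain segment ops
    agg = jax.ops.segment_sum(x[src], dst, num_segments=x.shape[0])
    return jax.nn.relu(agg @ w)

out = gcn_layer(jnp.ones((3, 4)), jnp.ones((4, 4)))
```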

yuanqing-wang commented 3 years ago

Even with the internal SPMM and SDDMM, the tests still pass but run much more slowly. I'll look into that in greater detail, but my guess is that more time is spent converting back and forth through DLPack.

yzh119 commented 3 years ago

@yuanqing-wang thanks for all the work you have done.

Yes, if you don't use the internal implementations of our GSpMM and GSDDMM, the speed will be much slower (and correspondingly more GPU memory will be used). PyTorch has a mechanism to register custom operators in TorchScript/JIT mode (https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html#using-the-torchscript-custom-operator-in-python), and I'm not sure whether JAX has similar functionality.

Converting a torch tensor to DLPack (and back) has zero cost (there is no need to manipulate the memory layout), and I'm not sure whether the same holds for JAX; you could use profiling tools to get fine-grained per-kernel overhead information.

yuanqing-wang commented 3 years ago

Yeah I think similar things are certainly possible in JAX.

https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
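Roughly following that notebook, the skeleton of a new primitive would look like this (a sketch with a hypothetical name; a real backend would dispatch to DGL's kernel in the impl and also register an XLA translation rule and a VJP):

```python
import numpy as np
from jax import core

spmm_p = core.Primitive("dgl_spmm")  # hypothetical primitive name

def dgl_spmm(adj, x):
    return spmm_p.bind(adj, x)

def _spmm_impl(adj, x):
    # eager fallback; a real backend would call DGL's fused GSpMM kernel here
    return np.asarray(adj) @ np.asarray(x)

def _spmm_abstract_eval(adj, x):
    # shape/dtype rule used when tracing (e.g. under jit)
    return core.ShapedArray((adj.shape[0], x.shape[1]), x.dtype)

spmm_p.def_impl(_spmm_impl)
spmm_p.def_abstract_eval(_spmm_abstract_eval)
```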

mattjj commented 3 years ago

There are examples of registering custom kernels in the jaxlib directory in the jax repository. That's how we call cuSOLVER, for example (see cusolver.cc in that directory). The mpi4jax repo has some nice minimal examples too. (Sorry for not sharing more helpful links; typing on my phone!)

mattjj commented 3 years ago

If you see speed issues, don't hesitate to open an issue on the JAX issue tracker!

yuanqing-wang commented 3 years ago

@mattjj Sorry for the confusion. I posted a comment about the difference in speed between the JAX and Torch backends on this benchmark of generalized sparse matrix multiplication: https://github.com/dglai/dgl-0.5-benchmark/blob/master/kernel/dgl-new.py, but quickly deleted it because I realized my claim of a huge difference was not accurate. I didn't realize you had already replied here. Apologies again.

Posting again the results of the comparison:

JAX:

Graph(num_nodes=232965, num_edges=114615892,
      ndata_schemes={'train_mask': Scheme(shape=(), dtype=dtype('bool')), 'val_mask': Scheme(shape=(), dtype=dtype('bool')), 'test_mask': Scheme(shape=(), dtype=dtype('bool')), 'feat': Scheme(shape=(602,), dtype=dtype('float32')), 'label': Scheme(shape=(), dtype=dtype('int64'))}
      edata_schemes={})
SPMM
----------------------------
hidden size: 1, avg time: 0.3649095467158726
hidden size: 2, avg time: 0.49075024468558176
hidden size: 4, avg time: 1.3259624753679549
hidden size: 8, avg time: 2.699452911104475
hidden size: 16, avg time: 5.220511300223214
hidden size: 32, avg time: 10.44725513458252
hidden size: 64, avg time: 15.337066343852452
hidden size: 128, avg time: 23.919097457613265
Graph(num_nodes=169343, num_edges=1166243,
      ndata_schemes={'year': Scheme(shape=(1,), dtype=dtype('int64')), 'feat': Scheme(shape=(128,), dtype=dtype('float32'))}
      edata_schemes={})
SPMM
----------------------------
hidden size: 1, avg time: 0.006763935089111328
hidden size: 2, avg time: 0.009878431047712053
hidden size: 4, avg time: 0.013012783867972237
hidden size: 8, avg time: 0.028874874114990234
hidden size: 16, avg time: 0.060634102140154154
hidden size: 32, avg time: 0.12363202231270927
hidden size: 64, avg time: 0.182865994317191
hidden size: 128, avg time: 0.345559869493757
Graph(num_nodes=132534, num_edges=79122504,
      ndata_schemes={'species': Scheme(shape=(1,), dtype=dtype('int64'))}
      edata_schemes={'feat': Scheme(shape=(8,), dtype=dtype('float32'))})
SPMM
----------------------------
hidden size: 1, avg time: 0.23015737533569336
hidden size: 2, avg time: 0.27006326402936665
hidden size: 4, avg time: 0.40189906529017855
hidden size: 8, avg time: 0.6493452617100307
hidden size: 16, avg time: 1.234830379486084
hidden size: 32, avg time: 2.861896276473999
hidden size: 64, avg time: 5.437532561165946
hidden size: 128, avg time: 10.53518911770412

PyTorch:

Graph(num_nodes=232965, num_edges=114615892,
      ndata_schemes={'train_mask': Scheme(shape=(), dtype=dtype('bool')), 'val_mask': Scheme(shape=(), dtype=dtype('bool')), 'test_mask': Scheme(shape=(), dtype=dtype('bool')), 'feat': Scheme(shape=(602,), dtype=dtype('float32')), 'label': Scheme(shape=(), dtype=dtype('int64'))}
      edata_schemes={})
SPMM
----------------------------
hidden size: 1, avg time: 0.3580136639731271
hidden size: 2, avg time: 0.43580879483904156
hidden size: 4, avg time: 0.6648304803030831
hidden size: 8, avg time: 1.8532922267913818
hidden size: 16, avg time: 4.275122914995466
hidden size: 32, avg time: 9.102126802716937
hidden size: 64, avg time: 13.36836998803275
hidden size: 128, avg time: 19.27591300010681
Graph(num_nodes=169343, num_edges=1166243,
      ndata_schemes={'year': Scheme(shape=(1,), dtype=torch.int64), 'feat': Scheme(shape=(128,), dtype=torch.float32)}
      edata_schemes={})
SPMM
----------------------------
hidden size: 1, avg time: 0.005928448268345424
hidden size: 2, avg time: 0.007018566131591797
hidden size: 4, avg time: 0.010804789406912667
hidden size: 8, avg time: 0.014909539903913225
hidden size: 16, avg time: 0.04121732711791992
hidden size: 32, avg time: 0.09793118068150111
hidden size: 64, avg time: 0.14246719224112375
hidden size: 128, avg time: 0.22136974334716797
Graph(num_nodes=132534, num_edges=79122504,
      ndata_schemes={'species': Scheme(shape=(1,), dtype=torch.int64)}
      edata_schemes={'feat': Scheme(shape=(8,), dtype=torch.float32)})
SPMM
----------------------------
hidden size: 1, avg time: 0.2294766562325614
hidden size: 2, avg time: 0.23444959095546178
hidden size: 4, avg time: 0.3821969713483538
hidden size: 8, avg time: 0.5904900687081474
hidden size: 16, avg time: 1.0325097356523787
hidden size: 32, avg time: 2.2940547125680104
hidden size: 64, avg time: 3.780037062508719
hidden size: 128, avg time: 8.389399664742607

mattjj commented 3 years ago

Ah, thanks for letting me know! And no worries; I just want to make sure JAX is working well for you :) Please drop us a line on our issue tracker as issues arise.

sooheon commented 3 years ago

@yuanqing-wang Sorry, does the above benchmark mean you have a working version of the jax backend? What's currently missing (other than NN implementations)?

yuanqing-wang commented 3 years ago

@sooheon See #2281; I have a rough implementation that passes most of the tests. The biggest challenge right now, I think, is to use the DGL internal implementations of g-SpMM and g-SDDMM while still letting the gradients be computed cleanly through JAX's custom_vjp and custom_jvp functions.

yuanqing-wang commented 3 years ago

The NN implementations, on the other hand, are a simpler task. Once we have all the kernels working, the only thing left is to translate each torch.nn.Module or tf.Module into a Flax or Haiku module.
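For instance, a rough Flax (linen) counterpart of a minimal GraphConv might look like this (a sketch with hypothetical names, using a plain segment-sum instead of DGL's fused kernels):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class GraphConv(nn.Module):
    out_feats: int

    @nn.compact
    def __call__(self, src, dst, x):
        # copy-u-sum aggregation followed by a dense update
        agg = jax.ops.segment_sum(x[src], dst, num_segments=x.shape[0])
        return nn.Dense(self.out_feats)(agg)

src, dst = jnp.array([0, 0, 1]), jnp.array([1, 2, 2])
x = jnp.ones((3, 4))
layer = GraphConv(out_feats=8)
params = layer.init(jax.random.PRNGKey(0), src, dst, x)
out = layer.apply(params, src, dst, x)
```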

yuanqing-wang commented 3 years ago

Posting the benchmark results for my crappy jax-native implementation of gspmm (https://github.com/yuanqing-wang/dgl/blob/587808db955c3289e560e4746b9a8f1b2225eb1e/python/dgl/backend/jax/sparse.py#L200) here for comparison:

Graph(num_nodes=232965, num_edges=114615892,
      ndata_schemes={'train_mask': Scheme(shape=(), dtype=dtype('bool')), 'val_mask': Scheme(shape=(), dtype=dtype('bool')), 'test_mask': Scheme(shape=(), dtype=dtype('bool')), 'feat': Scheme(shape=(602,), dtype=dtype('float32')), 'label': Scheme(shape=(), dtype=dtype('int64'))}
      edata_schemes={})
SPMM
----------------------------
hidden size: 1, avg time: 1.072848149708339
hidden size: 2, avg time: 1.2779584612165178
hidden size: 4, avg time: 1.7496437004634313
hidden size: 8, avg time: 2.9859710420881
hidden size: 16, avg time: 3.955272708620344
hidden size: 32, avg time: 6.034405946731567
hidden size: 64, avg time: 15.362382446016584
hidden size: 128, avg time: 24.198069402149745
Graph(num_nodes=169343, num_edges=1166243,
      ndata_schemes={'year': Scheme(shape=(1,), dtype=dtype('int64')), 'feat': Scheme(shape=(128,), dtype=dtype('float32'))}
      edata_schemes={})
SPMM
----------------------------
hidden size: 1, avg time: 0.008335283824375697
hidden size: 2, avg time: 0.010728052684238978
hidden size: 4, avg time: 0.014384269714355469
hidden size: 8, avg time: 0.020336287362234935
hidden size: 16, avg time: 0.030741419110979353
hidden size: 32, avg time: 0.051860298429216654
hidden size: 64, avg time: 0.13714558737618582
hidden size: 128, avg time: 0.2634686401912144
Graph(num_nodes=132534, num_edges=79122504,
      ndata_schemes={'species': Scheme(shape=(1,), dtype=dtype('int64'))}
      edata_schemes={'feat': Scheme(shape=(8,), dtype=dtype('float32'))})
SPMM
----------------------------
hidden size: 1, avg time: 0.6562830039433071
hidden size: 2, avg time: 0.7877050808497837
hidden size: 4, avg time: 0.8968813759940011
hidden size: 8, avg time: 1.1113482883998327
hidden size: 16, avg time: 1.3433971064431327
hidden size: 32, avg time: 2.189688273838588
hidden size: 64, avg time: 6.465246541159494
hidden size: 128, avg time: 11.838832139968872

yuanqing-wang commented 3 years ago

Note that it doesn't seem to be too bad when the dimension is large.

JosephDenman commented 1 year ago

Is there any update on this work? Is it a planned feature?