jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Custom GPU ops #623

Closed proteneer closed 3 months ago

proteneer commented 5 years ago

I'm looking to implement custom GPU ops, similar to how TensorFlow allows for defining custom ops with custom jvps. Is there a similar tutorial/guide, and how feasible would this be with JAX?

mattjj commented 5 years ago

Do you want to be able to call your own hand-written CUDA kernels, or instead do you want to be able to control how some of your functions act under transformations (like forward-mode or reverse-mode differentiation), even if they're just implemented in Python in terms of jax.numpy? Could you give an example?

proteneer commented 5 years ago

Call my own hand-written CUDA kernels.

I have a function F(X(p), p) and I'm interested in using autodiff to compute the total derivative:

dF/dp = dF/dX * dX/dp + dF/dp, where the dF/dp on the left is the total derivative, the dF/dp on the right is the partial derivative at fixed X, and dX/dp is the input grad.

Technically, dF/dX is an NxN symmetric Hessian of some energy function E, and dF/dp is a second-order mixed partial. The * operator calls into a cuBLAS Level 3 symmetric GEMM to make everything super fast.

In reality, F=B(X(p),p) + N(X(p),p) is the result of multiple different types of forces summed together (bonded, non bonded, etc.). So I've written custom CUDA kernels for each:

dF/dp = dF/dX * dX/dp + dF/dp

can be expanded into

dF/dp = d(B+N)/dX * dX/dp + d(B+N)/dp
dF/dp = (dB/dX+dN/dX) * dX/dp + dB/dp + dN/dp <-- much faster than:
dF/dp = (dB/dX * dX/dp + dB/dp) + (dN/dX * dX/dp + dN/dp)

I have custom ops written for

dB/dX
dN/dX
dB/dp
dN/dp

I'd like to be able to re-use them in JAX as a custom op in conjunction with the GEMM.

So in JVP notation, the J is Jacobian of a sum of different energy functions (B, N, etc.), and the V is supplied dX/dp term.
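As a rough illustration of the fused form above, here is a minimal JAX sketch. The functions `B`, `N`, and `X_of_p` are hypothetical toy stand-ins, not the real kernels (their Jacobians are not symmetric Hessians), and `jax.jacfwd` simply plays the role of the hand-written derivative ops. It only checks that summing the Jacobians before the matrix product gives the same total derivative as differentiating the composite function.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy terms: each maps (X, p) to a force-like vector of shape (3,).
def B(X, p):
    return p[0] * X ** 2

def N(X, p):
    return p[1] * jnp.sin(X) * jnp.sum(X)

# Hypothetical dependence of the coordinates X on the parameters p.
def X_of_p(p):
    return jnp.array([1.0, 2.0, 3.0]) * p[0] + p[1]

def F_total(p):
    X = X_of_p(p)
    return B(X, p) + N(X, p)

p = jnp.array([0.5, -1.5])
X = X_of_p(p)
dX_dp = jax.jacfwd(X_of_p)(p)                       # shape (3, 2)

# Partial derivatives of each term with respect to X and p.
dB_dX, dB_dp = jax.jacfwd(B, argnums=(0, 1))(X, p)  # shapes (3, 3) and (3, 2)
dN_dX, dN_dp = jax.jacfwd(N, argnums=(0, 1))(X, p)

# Fused form: sum the Jacobians first, then do a single matrix product.
fused = (dB_dX + dN_dX) @ dX_dp + dB_dp + dN_dp
# Term-by-term form: one matrix product per term.
separate = (dB_dX @ dX_dp + dB_dp) + (dN_dX @ dX_dp + dN_dp)
# Reference: differentiate the composite function directly.
reference = jax.jacfwd(F_total)(p)
```

With the real kernels, the fused form needs one large matrix product instead of one per force term, which is the speedup described above.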

For an example of what dN/dX and dN/dp look like

https://github.com/proteneer/timemachine/blob/master/timemachine/cpu_functionals/electrostatics.cuh

Note that these kernels are designed to compute both terms simultaneously for speed reasons. I'd be okay with separating them out if needed.

(There are a lot of peculiarities in terms of optimizing the calculations to be fully warp-asynchronous by abusing the __shfl intrinsic to death.)

proteneer commented 5 years ago

To add: I'm currently using AD systems as a "reference" platform to check my highly optimized production code against. If it were possible to directly incorporate the kernels as custom ops, then it's one less set of code that I'd have to maintain.

nottombrown commented 5 years ago

I would also like to be able to call hand-written CUDA kernels.

Say for example that I wanted to use some of Scott Gray's efficient blocksparse kernels (https://github.com/openai/blocksparse). Is there a story for how I would get that working in JAX?

mattjj commented 5 years ago

Not yet, but it's possible. We may need an XLA:GPU feature to be further developed though.

There's an XLA HLO named CustomCall that in principle allows jumping into custom code from an XLA computation. @hawkinsp used it with the XLA:CPU backend to set up special CPU-specific translation rules for linear algebra calls so that they jump into LAPACK code on CPU (and currently fall back to HLO-level implementations on GPU and TPU). That same technique works on CPU for jumping into custom cython or whatever. But my very limited understanding is that CustomCall is not yet sufficiently developed on XLA:GPU to support this kind of thing.

I believe we already have a feature request in with XLA:GPU to improve CustomCall so that we can jump into cuSOLVER / MAGMA routines for optimized linear algebra on GPU. I'm fuzzy on the details but I think that would also let JAX expose a way to stitch your custom GPU kernels into an XLA:GPU-compiled program. In the JAX layer you could also attach your own rules for the other transformations, like differentiation, so it'd compose nicely.

@jlebar and @hawkinsp understand this much better and so may be able to shed more light, but I suspect the bottom line is that CustomCall on XLA:GPU needs to be fleshed out, and that the hard-working XLA:GPU team is aware but is also balancing tons of other important work.

mattjj commented 5 years ago

By the way, all this development we're talking about is on open-source code, so if you know any crack GPU developers, contributions welcome! :)

jlebar commented 5 years ago

Implementing proper custom ops in XLA:GPU would be pretty simple in theory (famous last words), but as @mattjj says, we have other higher-importance things on our plate at the moment. I would be happy to advise and review patches if anyone wanted to take this on.

> I believe we already have a feature request in with XLA:GPU to improve CustomCall so that we can jump into cuSOLVER / MAGMA routines for optimized linear algebra on GPU.

I believe the plan of record for these specifically is to implement them in XLA itself, i.e. not using CustomCall. This way everyone who wants to use cuSOLVER doesn't have to reimplement these same custom ops.

proteneer commented 5 years ago

In theory, if you guys already support XLA:CPU and we write our own CPU code that then calls into the various CUDA kernels, would that cause issues?

jlebar commented 5 years ago

> would that cause issues?

Yes, because a single XLA computation runs on one type of device -- either CPU or a particular GPU model -- so you wouldn't be able to mix them. Also all of the inputs to a CPU custom op would need to be in CPU memory.
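To make the memory-placement point concrete, here is a small sketch in current JAX (these calls postdate this comment). It assumes nothing beyond a working JAX install: the array and the lambda are arbitrary placeholders, and the example only shows how an input can be placed in CPU memory explicitly before a computation consumes it.

```python
import jax
import jax.numpy as jnp

# The CPU backend is always available, even on a machine with GPUs.
cpu = jax.devices("cpu")[0]

x = jnp.arange(4.0)              # lives on the default device (a GPU if one is present)
x_cpu = jax.device_put(x, cpu)   # explicit copy into CPU memory

# A jitted computation applied to the CPU-resident input.
y = jax.jit(lambda a: a * 2.0)(x_cpu)
```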

proteneer commented 5 years ago

A little more about our use case. We have an extremely optimized set of GPU kernels for doing physics simulations of molecules that took us many years to write:

https://github.com/pandegroup/openmm/tree/master/platforms/cuda/src/kernels
https://github.com/proteneer/timemachine/tree/master/timemachine/cpu_functionals

Some of these kernels are themselves JITed (at the CUDA source level, not the PTX level), though we do it purely symbolically on a very restricted set of functional forms.

We'd love to be able to bring all that code into our JAX workflows somehow and it'd definitely help convince more of us old-school physics-types to get more involved with the project. Currently JAX/XLA is about 100x slower than hand-written CUDA code (for first and second order derivatives) in forward mode.

mattjj commented 5 years ago

The amazing @jlebar just landed this exciting commit in XLA:GPU:

https://github.com/tensorflow/tensorflow/commit/acb84a010d1859b15fc21b64978faf20bd62956e

mattjj commented 5 years ago

The docs for CustomCall are really nice.

shoyer commented 5 years ago

I opened https://github.com/google/jax/issues/766 for the related issue of JIT for custom CPU ops.

proteneer commented 5 years ago

Just pinging this again to see how hard it would be to:

1) define a custom op, and
2) define a custom defjvp (similar to how TensorFlow can do this)

hawkinsp commented 5 years ago

You can get custom jvps right now: https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms

All the pieces for custom GPU ops are now there in principle since XLA now supports CustomCall on GPU, but we need to plug them together and expose them to users.

proteneer commented 5 years ago

To clarify:

If I were to import custom python-wrapped C++ CPU/GPU code in a function decorated with custom_transforms, is the expectation that while they won't be JITTable, I should still be able to run them normally in op-by-op mode?

proteneer commented 5 years ago

Is there a way to define jvps that avoids calling the original f(x) code? I ask because I have a very expensive function where I compute the primal and tangent in one pass, and I would like to use it directly instead of going through a separate pass:

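import jax
import jax.numpy as np

# Uses the jax.custom_transforms / jax.defjvp API available at the time of this comment.
@jax.custom_transforms
def f(x):
    print("calling f(x)")
    return np.sin(x ** 2)

def jvp_f(g, ans, x):
    # g is the input tangent, ans is the primal output f(x), x is the primal input.
    print("calling jvp_f")
    # I have an expensive function here that also computes the primals in addition to the tangents
    return 8. * g + ans

jax.defjvp(f, jvp_f)

out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))

proteneer commented 5 years ago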

Never mind. I keep posting comments where I find the right solution immediately afterwards (sorry guys). Looks like I just needed to use defjvp_all instead:

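import jax
import jax.numpy as np

# Uses the jax.custom_transforms / jax.defjvp_all API available at the time of this comment.
@jax.custom_transforms
def f(x):
    print("calling f(x)")
    return np.sin(x ** 2)

def jvp_f(ps, ts):
    # ps and ts are tuples of primals and tangents; return (primal_out, tangent_out) in one pass.
    print("calling jvp_f")
    return np.sin(ps[0] ** 2), 8. * ts[0]

jax.defjvp_all(f, jvp_f)

# jax.defjvp(f, jvp_f)

out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))

print(out_primal)
# 0.4121185
print(out_tangent)
# 16.0 (8. * 2.; the 16.412119 originally pasted here matches the earlier defjvp version, 8. * g + ans)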
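
For readers on a recent JAX release, the same "primal and tangent in one pass" idea can be sketched with `jax.custom_jvp`, the interface that later replaced `custom_transforms`. This is only an illustrative sketch: the name `f_jvp` is arbitrary, and it uses the true derivative of sin(x ** 2) rather than the placeholder `8. * ts[0]` from the snippet above.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x ** 2)

@f.defjvp
def f_jvp(primals, tangents):
    # One rule returns both the primal output and the tangent output,
    # so an expensive fused kernel could compute them together.
    (x,), (t,) = primals, tangents
    primal_out = jnp.sin(x ** 2)
    tangent_out = 2.0 * x * jnp.cos(x ** 2) * t
    return primal_out, tangent_out

out_primal, out_tangent = jax.jvp(f, (3.0,), (2.0,))
```

Under `jax.jvp`, only `f_jvp` runs; the undecorated body of `f` is used when no differentiation is requested.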

jekbradbury commented 5 years ago

> To clarify:
>
> If I were to import custom python-wrapped C++ CPU/GPU code in a function decorated with custom_transforms, is the expectation that while they won't be JITTable, I should still be able to run them normally in op-by-op mode?

If you do this using your own Python wrappers and FFI, that would be the case. But if you plug into the XLA CustomCall infrastructure, there should be a way to make it work even under @jit (though it will likely require some new code in JAX itself).
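In recent JAX releases (well after this comment), one way to call Python-wrapped host code from inside `@jit` without writing a CustomCall is `jax.pure_callback`. The sketch below is an assumption-laden illustration, not the approach discussed in this thread: `host_energy` and `host_grad` are hypothetical NumPy stand-ins for externally wrapped kernels, and the callback runs on the host, so it does not keep data on the GPU.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-ins for Python-wrapped external code (plain NumPy here).
def host_energy(x):
    x = np.asarray(x)
    return np.asarray(np.sum(x ** 2), dtype=x.dtype)

def host_grad(x):
    x = np.asarray(x)
    return np.asarray(2.0 * x, dtype=x.dtype)

@jax.custom_jvp
def energy(x):
    # Declare the output shape/dtype so JAX can trace through the opaque call.
    out = jax.ShapeDtypeStruct((), x.dtype)
    return jax.pure_callback(host_energy, out, x)

@energy.defjvp
def energy_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    g = jax.pure_callback(host_grad, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    # The tangent is linear in t, so reverse mode can transpose it.
    return energy(x), jnp.vdot(g, t)

x = jnp.arange(3.0)
value, grad = jax.jit(jax.value_and_grad(energy))(x)
```

The callback itself is opaque to differentiation, which is why the derivative rule is attached separately with `jax.custom_jvp`.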

proteneer commented 5 years ago

Thank you for the clarification. I realized that I actually ended up implementing not jvps in my custom op, but rather vmap/batched jvps over the parameters (for computational efficiency purposes), so it will take a little more work for me to get this to work within the jax ecosystem.
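For context, a "batched jvp over the parameters" can be expressed in JAX by mapping `jax.jvp` over a set of tangent directions with `jax.vmap`. The sketch below is a toy under stated assumptions: `forces` is a hypothetical function of three parameters, and the tangent directions are the standard basis vectors, so the result is the transposed Jacobian.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function of the parameters p, returning a vector of shape (4,).
def forces(p):
    x = jnp.linspace(0.0, 1.0, 4)
    return p[0] * jnp.sin(p[1] * x) + p[2] * x ** 2

p = jnp.array([1.0, 2.0, 3.0])

def jvp_along(t):
    # Tangent output for a single direction t in parameter space.
    return jax.jvp(forces, (p,), (t,))[1]

# One JVP per parameter, evaluated as a single batched computation.
batched_tangents = jax.vmap(jvp_along)(jnp.eye(p.size))   # shape (3, 4)
```

A hand-written kernel that already computes all of these directions at once corresponds to the batched result here, not to a single `jvp_along` call.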

hawkinsp commented 5 years ago

There's no documentation for this yet, but the following should be most of the information you need for defining a custom GPU op, because the cuSOLVER ops there are implemented as custom-call ops:

https://github.com/google/jax/blob/master/jaxlib/cusolver.py
https://github.com/google/jax/blob/master/jaxlib/cusolver.cc
https://www.tensorflow.org/xla/custom_call#custom-call_on_gpu

i.e., all the pieces are now there; all that is missing now is documentation.

proteneer commented 5 years ago

Awesome thanks - I'll take a look

sharadmv commented 2 years ago

#12632 proposes a new JAX API for custom ops. Anyone interested should give us feedback!

dfm commented 3 months ago

Now that https://github.com/google/jax/pull/21925 has been merged we have a recommended/documented workflow for this: https://jax.readthedocs.io/en/latest/ffi.html