jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.62k stars 2.82k forks source link

Need `xla_python_gpu_callback` with `dlpack` interface #11371

Open YouJiacheng opened 2 years ago

YouJiacheng commented 2 years ago

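I would like to have an enhanced version of `jax.experimental.host_callback`, built on `xla_python_gpu_callback` with a `dlpack` interface, so that device buffers can be handed to the callback without copying.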

This would unlock virtually unlimited interoperability between JAX and other high-performance computing libraries.

For example, users could easily add a GPU kernel to JAX as a pseudo-primitive, written inline and usable under JIT, with CuPy, cuda-python, openai/triton, etc.

Currently, adding a C/C++/CUDA extension to JAX requires more boilerplate code (e.g. numba4jax) than in PyTorch or TensorFlow, and there are far fewer tutorials and references, if any. Yes, JAX is more flexible than PyTorch and TensorFlow and relies less on extensions; however, extensibility is still important.

At the same time: https://github.com/patrick-kidger/equinox/pull/126 could be implemented more efficiently and elegantly (@patrick-kidger); https://github.com/google/jax/discussions/9813#discussioncomment-3086181 could be solved; and https://github.com/google/jax/discussions/11335 could be solved in a JIT-compatible manner.

YouJiacheng commented 2 years ago

I prototyped something similar using `xla_python_gpu_callback`, but it passes `numpy.ndarray`s instead of using dlpack. Registering function-transformation rules does not seem very complicated to implement, but how to use dlpack is beyond my current knowledge.

from functools import lru_cache

import jax
from jax import core
from jax.interpreters import mlir, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun

def foo(x):
    """Callback body used for tracing and eager evaluation.

    Prints a marker and the type of ``x`` (a tracer while tracing, a concrete
    value when run eagerly), then returns ``x + 1``.
    """
    print('traced')
    print(type(x))
    incremented = x + 1
    return incremented

def foo_lowered(x):
    """Host-side callback body executed by the compiled program.

    Prints the type of the value it receives, then returns ``x + 1``.
    """
    print(type(x))
    incremented = x + 1
    return incremented

@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    """Abstractly evaluate ``fun`` and return its output avals in ``fun``'s output structure.

    The positional/keyword arguments are flattened into a single leaf list,
    ``fun`` is wrapped to operate on that flat list, and the flat output avals
    are rebuilt into the original output pytree.

    Results are memoized by ``lru_cache``, so ``fun`` and every argument must
    be hashable. Because of the cache, ``fun`` is only actually traced once per
    distinct (fun, args) combination.
    """
    flat_inputs, input_treedef = tree_flatten((args, kwargs))
    flat_fun, output_treedef_thunk = flatten_fun(lu.wrap_init(fun), input_treedef)
    flat_out_avals = pe.abstract_eval_fun(flat_fun.call_wrapped, *flat_inputs)
    # The output tree is only known after ``flat_fun`` has been evaluated.
    return tree_unflatten(output_treedef_thunk(), flat_out_avals)

callback_p = core.Primitive('callback')


def _callback_impl(*args, callback, **_):
    """Eager rule for ``callback_p``: call ``callback`` on the inputs directly."""
    return callback(*args)


def _callback_abstract_eval(*args, callback, **_):
    """Abstract-eval rule for ``callback_p``: derive output avals from ``callback``."""
    return abstract_eval_fun(callback, *args)


callback_p.def_impl(_callback_impl)
callback_p.def_abstract_eval(_callback_abstract_eval)

def callback_lowering(ctx, *args, callback, callback_lowered):
    """MLIR lowering rule for ``callback_p``: emit a host Python callback.

    ``callback`` is only abstractly evaluated here, to find out whether it
    produces a single (non-iterable) output. ``callback_lowered`` is the
    function actually invoked on the host when the compiled program runs.

    Args:
      ctx: lowering context passed in by ``mlir``; ``ctx.avals_in`` /
        ``ctx.avals_out`` describe operands/results and ``ctx.module_context``
        is the module being built.
      *args: lowered operand values.
      callback: function used for abstract evaluation (the primitive's
        ``callback`` param).
      callback_lowered: function run on the host at execution time.

    Returns:
      The result values produced by ``mlir.emit_python_callback``.
    """
    out_aval = abstract_eval_fun(callback, *ctx.avals_in)
    # Only the iterability probe belongs inside the ``try``: a TypeError raised
    # while abstractly evaluating ``callback`` must propagate rather than be
    # mistaken for "callback returns a single output".
    try:
        iter(out_aval)
    except TypeError:
        # Single output: return it as a 1-tuple, since the emitted custom call
        # produces a tuple that is then unpacked with get_tuple_element.
        def host_fn(*operands):
            return (callback_lowered(*operands),)
    else:
        host_fn = callback_lowered
    # The trailing ``False`` shows up as ``has_side_effect = false`` on the
    # emitted custom call.
    result, keepalive = mlir.emit_python_callback(
        ctx.module_context.platform, host_fn, args,
        ctx.avals_in, ctx.avals_out, False)
    # Register ``keepalive`` with the module context (presumably so the Python
    # callback object stays alive as long as the compiled module).
    ctx.module_context.add_keepalive(keepalive)
    return result

# Wrapping with cache_lowering lets repeated callback_p applications with the
# same params reuse one private lowered function instead of re-emitting it.
_cached_callback_lowering = mlir.cache_lowering(callback_lowering)
mlir.register_lowering(callback_p, _cached_callback_lowering)

def bar(x):
    """Apply the ``foo`` callback primitive to ``x`` four times in sequence."""
    for _ in range(4):
        x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    return x

# Trace only: prints the jaxpr. ``foo`` sees a tracer here, and thanks to the
# lru_cache in abstract_eval_fun it is traced once for all four binds.
print(jax.make_jaxpr(bar)(1))
# Compile and run: ``foo_lowered`` executes on the host once per callback and
# receives numpy.ndarray values.
print(jax.jit(bar)(1))
# Each ``.lower(1)`` below lowers ``bar`` again; kept as separate calls on purpose.
print(jax.jit(bar).lower(1).compiler_ir()) # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string()) # post-compilation HLO
traced
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
{ lambda ; a:i32[]. let
    b:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] a
    c:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] b
    d:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] c
    e:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] d
  in (e,) }
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
5
module @jit_bar.1 {
  func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = call @callback(%arg0) : (tensor<i32>) -> tensor<i32>
    %1 = call @callback(%0) : (tensor<i32>) -> tensor<i32>
    %2 = call @callback(%1) : (tensor<i32>) -> tensor<i32>
    %3 = call @callback(%2) : (tensor<i32>) -> tensor<i32>
    return %3 : tensor<i32>
  }
  func.func private @callback(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<94158890291952> : tensor<i64>
    %1 = "mhlo.custom_call"(%0, %arg0) {api_version = 2 : i32, backend_config = "94158890291952", call_target_name = "xla_python_gpu_callback", called_computations = [], has_side_effect = false} : (tensor<i64>, tensor<i32>) -> tuple<tensor<i32>>
    %2 = "mhlo.get_tuple_element"(%1) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
    return %2 : tensor<i32>
  }
}

HloModule jit_bar.2, entry_computation_layout={(s32[])->s32[]}

ENTRY %main.26 (Arg_0.1: s32[]) -> s32[] {
  %constant_0 = s64[] constant(94158882601360)
  %Arg_0.1 = s32[] parameter(0)
  %custom-call.0 = (s32[]) custom-call(s64[] %constant_0, s32[] %Arg_0.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.0 = s32[] get-tuple-element((s32[]) %custom-call.0), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.1 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.0), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.1 = s32[] get-tuple-element((s32[]) %custom-call.1), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.2 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.2 = s32[] get-tuple-element((s32[]) %custom-call.2), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.3 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.2), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  ROOT %get-tuple-element.3 = s32[] get-tuple-element((s32[]) %custom-call.3), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
}
hawkinsp commented 2 years ago

@sharadmv actually tried prototyping something like this already but we got a bit stuck debugging a deadlock.

Something we did not try was a dlpack API; that seems like a reasonable thing to try.

YouJiacheng commented 2 years ago

Thanks for your reply! Zero-copy, or at least avoiding a device-to-host then host-to-device round trip, is crucial (can XLA automatically donate buffers to a custom call?), so I think a dlpack interface would be very useful.

YouJiacheng commented 2 years ago

I found something unusual: the execution time of `xla_python_gpu_callback` with an identity callback doesn't scale linearly with array size. A (1, 1024, 1024) float32 array takes 0.20 s, but an (8, 1024, 1024) float32 array takes 4.6 s. For comparison, `host_callback.call` takes 0.25 s and 1.5 s respectively with the outfeed implementation, and 0.25 s and 4.7 s respectively with the `xla_python_gpu_callback` implementation.

sharadmv commented 2 years ago

That is unusual... could you provide us with the benchmark so we can reproduce and investigate?

YouJiacheng commented 2 years ago

Sure, here is the full code:

from functools import lru_cache

import jax
from jax import core
from jax.interpreters import mlir, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun

def foo(x):
    """Identity callback used for tracing / eager evaluation."""
    unchanged = x
    return unchanged

def foo_lowered(x):
    """Identity callback executed on the host by the compiled program."""
    unchanged = x
    return unchanged

@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    """Abstractly evaluate ``fun`` and return its output avals in ``fun``'s output structure.

    Arguments are flattened to leaves, ``fun`` is wrapped to work on the flat
    leaves, and the flat output avals are rebuilt into the output pytree.
    Memoized with ``lru_cache``: ``fun`` and all arguments must be hashable.
    """
    flat_inputs, input_treedef = tree_flatten((args, kwargs))
    flat_fun, output_treedef_thunk = flatten_fun(lu.wrap_init(fun), input_treedef)
    flat_out_avals = pe.abstract_eval_fun(flat_fun.call_wrapped, *flat_inputs)
    # The output tree is only available after ``flat_fun`` has run.
    return tree_unflatten(output_treedef_thunk(), flat_out_avals)

callback_p = core.Primitive('callback')


def _callback_impl(*args, callback, **_):
    """Eager rule for ``callback_p``: call ``callback`` on the inputs directly."""
    return callback(*args)


def _callback_abstract_eval(*args, callback, **_):
    """Abstract-eval rule for ``callback_p``: derive output avals from ``callback``."""
    return abstract_eval_fun(callback, *args)


callback_p.def_impl(_callback_impl)
callback_p.def_abstract_eval(_callback_abstract_eval)

def callback_lowering(ctx, *args, callback, callback_lowered):
    """MLIR lowering rule for ``callback_p``: emit a host Python callback.

    ``callback`` is only abstractly evaluated here, to find out whether it
    produces a single (non-iterable) output. ``callback_lowered`` is the
    function actually invoked on the host when the compiled program runs.

    Args:
      ctx: lowering context passed in by ``mlir``; ``ctx.avals_in`` /
        ``ctx.avals_out`` describe operands/results and ``ctx.module_context``
        is the module being built.
      *args: lowered operand values.
      callback: function used for abstract evaluation (the primitive's
        ``callback`` param).
      callback_lowered: function run on the host at execution time.

    Returns:
      The result values produced by ``mlir.emit_python_callback``.
    """
    out_aval = abstract_eval_fun(callback, *ctx.avals_in)
    # Only the iterability probe belongs inside the ``try``: a TypeError raised
    # while abstractly evaluating ``callback`` must propagate rather than be
    # mistaken for "callback returns a single output".
    try:
        iter(out_aval)
    except TypeError:
        # Single output: return it as a 1-tuple for the tuple-typed custom call.
        def host_fn(*operands):
            return (callback_lowered(*operands),)
    else:
        host_fn = callback_lowered
    # The trailing ``False`` is passed as the callback's side-effect flag
    # (emitted as ``has_side_effect = false`` on the custom call).
    result, keepalive = mlir.emit_python_callback(
        ctx.module_context.platform, host_fn, args,
        ctx.avals_in, ctx.avals_out, False)
    # Register ``keepalive`` with the module context (presumably so the Python
    # callback object stays alive as long as the compiled module).
    ctx.module_context.add_keepalive(keepalive)
    return result

# Wrapping with cache_lowering lets repeated callback_p applications with the
# same params reuse one private lowered function instead of re-emitting it.
_cached_callback_lowering = mlir.cache_lowering(callback_lowering)
mlir.register_lowering(callback_p, _cached_callback_lowering)

@jax.jit
def bar(x):
    """Jitted identity round trip through one ``callback_p`` host callback."""
    return callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)

from contextlib import contextmanager
@contextmanager
def timer():
    """Print the wall-clock seconds spent inside the ``with`` body.

    Nothing is printed if the body raises, because the exception propagates
    out of the ``yield`` before the print is reached.
    """
    import time
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    print(elapsed)

# Benchmark input: an (8, 1024, 1024) array of zeros.
x = jax.numpy.zeros((8, 1024, 1024))
# Warm-up call so that jit compilation is excluded from the timed loop.
bar(x).block_until_ready()
# Time 100 calls; block_until_ready() waits for each result, so the measured
# time covers execution rather than just asynchronous dispatch.
with timer():
    for _ in range(100):
        bar(x).block_until_ready()

from jax.experimental import host_callback as hcb

@jax.jit
def bar_hcb(x):
    """Jitted identity round trip through ``host_callback.call`` (for comparison with ``bar``)."""
    return hcb.call(foo_lowered, x, result_shape=x)

# Warm-up call so that jit compilation is excluded from the timed loop.
bar_hcb(x).block_until_ready()
# Time 100 calls of the host_callback-based version on the same input as above.
with timer():
    for _ in range(100):
        bar_hcb(x).block_until_ready()