YouJiacheng opened this issue 2 years ago
I prototyped something similar using xla_python_gpu_callback, but it uses numpy.ndarray instead of dlpack.
Registering the function transformation rules doesn't seem very complicated to implement, but how to use dlpack is outside the scope of my current knowledge.
from functools import lru_cache
import jax
from jax import core
from jax.interpreters import mlir, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun
def foo(x):
    print('traced')
    print(type(x))
    return x + 1

def foo_lowered(x):
    print(type(x))
    return x + 1

@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    # Evaluate `fun` abstractly (on avals) to obtain its output avals.
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)

callback_p = core.Primitive('callback')
callback_p.def_impl(lambda *args, callback, **_: callback(*args))
callback_p.def_abstract_eval(lambda *args, callback, **_: abstract_eval_fun(callback, *args))

def callback_lowering(ctx, *args, callback, callback_lowered):
    # emit_python_callback expects the callable to return a sequence of results,
    # so wrap single-output callbacks in a tuple.
    try:
        iter(abstract_eval_fun(callback, *ctx.avals_in))
    except TypeError:
        f = lambda *args: (callback_lowered(*args),)
    else:
        f = callback_lowered
    result, keepalive = mlir.emit_python_callback(ctx.module_context.platform, f, args, ctx.avals_in, ctx.avals_out, False)
    ctx.module_context.add_keepalive(keepalive)
    return result

mlir.register_lowering(callback_p, mlir.cache_lowering(callback_lowering))

def bar(x):
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    return x
print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir()) # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string()) # post-compilation HLO
traced
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
{ lambda ; a:i32[]. let
    b:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] a
    c:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] b
    d:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] c
    e:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] d
  in (e,) }
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
5
module @jit_bar.1 {
  func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = call @callback(%arg0) : (tensor<i32>) -> tensor<i32>
    %1 = call @callback(%0) : (tensor<i32>) -> tensor<i32>
    %2 = call @callback(%1) : (tensor<i32>) -> tensor<i32>
    %3 = call @callback(%2) : (tensor<i32>) -> tensor<i32>
    return %3 : tensor<i32>
  }
  func.func private @callback(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<94158890291952> : tensor<i64>
    %1 = "mhlo.custom_call"(%0, %arg0) {api_version = 2 : i32, backend_config = "94158890291952", call_target_name = "xla_python_gpu_callback", called_computations = [], has_side_effect = false} : (tensor<i64>, tensor<i32>) -> tuple<tensor<i32>>
    %2 = "mhlo.get_tuple_element"(%1) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
    return %2 : tensor<i32>
  }
}
HloModule jit_bar.2, entry_computation_layout={(s32[])->s32[]}
ENTRY %main.26 (Arg_0.1: s32[]) -> s32[] {
%constant_0 = s64[] constant(94158882601360)
%Arg_0.1 = s32[] parameter(0)
%custom-call.0 = (s32[]) custom-call(s64[] %constant_0, s32[] %Arg_0.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
%get-tuple-element.0 = s32[] get-tuple-element((s32[]) %custom-call.0), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
%custom-call.1 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.0), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
%get-tuple-element.1 = s32[] get-tuple-element((s32[]) %custom-call.1), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
%custom-call.2 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
%get-tuple-element.2 = s32[] get-tuple-element((s32[]) %custom-call.2), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
%custom-call.3 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.2), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
ROOT %get-tuple-element.3 = s32[] get-tuple-element((s32[]) %custom-call.3), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
}
@sharadmv actually tried prototyping something like this already, but we got a bit stuck debugging a deadlock.
Something we did not try was a dlpack API; that seems like a reasonable thing to try.
Thanks for your reply!
Zero copy, or at least avoiding a d2h-then-h2d round trip, is crucial (can XLA automatically donate a buffer to a custom call?), so I think a dlpack interface would be very useful.
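As a concrete illustration of the kind of zero-copy exchange I mean, here is a minimal sketch using dlpack between JAX and CuPy (cupy_scale is a made-up name, and this assumes the callback could hand us device buffers; today it only works on concrete arrays outside of the callback, which receives numpy arrays instead):

import jax
import jax.dlpack
import cupy

def cupy_scale(x):
    # View the JAX device buffer as a CuPy array via a dlpack capsule (no copy).
    x_cp = cupy.fromDlpack(jax.dlpack.to_dlpack(x))
    y_cp = 2 * x_cp  # any CuPy / custom-kernel computation, stays on the GPU
    # Hand the result back to JAX through dlpack, again without a host round trip.
    return jax.dlpack.from_dlpack(y_cp.toDlpack())

This is exactly the d2h/h2d copy that a dlpack-aware callback interface could remove.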
I found something unusual: xla_python_gpu_callback execution time with an identity callback doesn't scale linearly with array size.
A (1, 1024, 1024) float32 array costs 0.20s, but an (8, 1024, 1024) float32 array costs 4.6s.
For comparison, host_callback.call with the outfeed implementation takes 0.25s and 1.5s respectively, while with the xla_python_gpu_callback implementation it takes 0.25s and 4.7s respectively.
That is unusual... could you provide us with the benchmark so we can repro and investigate?
Sure, the full code is here:
from functools import lru_cache
import jax
from jax import core
from jax.interpreters import mlir, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun
def foo(x):
    return x

def foo_lowered(x):
    return x

@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)

callback_p = core.Primitive('callback')
callback_p.def_impl(lambda *args, callback, **_: callback(*args))
callback_p.def_abstract_eval(lambda *args, callback, **_: abstract_eval_fun(callback, *args))

def callback_lowering(ctx, *args, callback, callback_lowered):
    try:
        iter(abstract_eval_fun(callback, *ctx.avals_in))
    except TypeError:
        f = lambda *args: (callback_lowered(*args),)
    else:
        f = callback_lowered
    result, keepalive = mlir.emit_python_callback(ctx.module_context.platform, f, args, ctx.avals_in, ctx.avals_out, False)
    ctx.module_context.add_keepalive(keepalive)
    return result

mlir.register_lowering(callback_p, mlir.cache_lowering(callback_lowering))

@jax.jit
def bar(x):
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    return x

from contextlib import contextmanager

@contextmanager
def timer():
    from time import perf_counter
    t = perf_counter()
    yield
    print(perf_counter() - t)

x = jax.numpy.zeros((8, 1024, 1024))

bar(x).block_until_ready()  # warm up (compile) before timing
with timer():
    for _ in range(100):
        bar(x).block_until_ready()

from jax.experimental import host_callback as hcb

@jax.jit
def bar_hcb(x):
    x = hcb.call(foo_lowered, x, result_shape=x)
    return x

bar_hcb(x).block_until_ready()  # warm up (compile) before timing
with timer():
    for _ in range(100):
        bar_hcb(x).block_until_ready()
I would like to have an enhanced version of jax.experimental.host_callback with:
- a dlpack interface (zero-copy if possible, both CPU and GPU)
- ad and batching (especially)
- minimal boilerplate code

This will unleash infinite interoperability between JAX and other high performance computing libraries.
For example, users could easily add a GPU kernel to JAX as a pseudo-primitive, in an inline and JIT-compatible manner, using CuPy, cuda-python, openai/triton, etc.
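For illustration, a rough sketch of how a CuPy elementwise kernel could be plugged into the callback prototype above (hypothetical usage: saxpy2, saxpy2_lowered and fused are made-up names; with the current numpy-based callback the cupy.asarray / .get() copies are still required):

import cupy

# A CuPy elementwise kernel exposed to JAX as a pseudo-primitive.
saxpy2 = cupy.ElementwiseKernel(
    'float32 x, float32 y', 'float32 z', 'z = 2.0f * x + y', 'saxpy2')

def saxpy2_lowered(x, y):
    # With the current callback the arguments arrive as numpy arrays,
    # so they must be copied to the device and back (h2d then d2h).
    return saxpy2(cupy.asarray(x), cupy.asarray(y)).get()

@jax.jit
def fused(x, y):
    return callback_p.bind(x, y,
                           callback=lambda x, y: 2 * x + y,
                           callback_lowered=saxpy2_lowered)

With a dlpack interface, saxpy2_lowered could receive CuPy views of the device buffers directly and return one, removing both copies and making the kernel a genuinely inline pseudo-primitive.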
Currently, adding a C/C++/CUDA extension to JAX needs more boilerplate code (e.g. numba4jax) than PyTorch or TensorFlow, while having far fewer tutorials and references, if any. Yes, JAX is more flexible than PyTorch and TensorFlow and relies less on extensions. However, extensibility is still important.
At the same time, https://github.com/patrick-kidger/equinox/pull/126 could be implemented more efficiently and elegantly (@patrick-kidger), https://github.com/google/jax/discussions/9813#discussioncomment-3086181 could be solved, and https://github.com/google/jax/discussions/11335 could be solved in a JIT-compatible manner.