patrick-kidger opened this issue 2 years ago.
Have you seen the mode argument to .at[].set()? https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at
There is not (yet) a debugging option for raising errors, but at least you have a couple of customization options.
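A minimal sketch of those mode options, assuming a JAX release where .at[] accepts mode and .get() accepts fill_value. They control how out-of-bounds indices are handled. None of them reports whether an update ends up in place or out of place.

```python
import jax.numpy as jnp

x = jnp.arange(5)  # [0, 1, 2, 3, 4], integer dtype

# Index 10 is out of bounds for a length-5 array; `mode` decides what happens.
dropped = x.at[10].set(99, mode="drop")  # the out-of-bounds update is ignored
clipped = x.at[10].set(99, mode="clip")  # the index is clamped to the last valid position (4)
filled = x.at[10].get(mode="fill", fill_value=-1)  # an out-of-bounds read returns fill_value

print(dropped.tolist())  # [0, 1, 2, 3, 4]
print(clipped.tolist())  # [0, 1, 2, 3, 99]
print(int(filled))       # -1
```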
Thanks for the link -- indeed I didn't know about them. They don't help with resolving this issue, but they're still good to know.
+1 on more transparency as to whether it compiles to in-place or out-of-place. I suspect the massive runtime slowdown in my discussion question (https://github.com/google/jax/discussions/9116 ) might come from at[].set() triggering a copy, but I don't know how I can check or debug that at all.
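One partial way to check today is sketched below. It assumes a JAX version with the ahead-of-time API (jax.jit(f).lower(...).compile()). The idea is to dump the optimized HLO that XLA will run and search it for copy instructions. The function name update is illustrative. Finding or not finding a copy is only a heuristic, not a guarantee about buffer reuse. When an argument is donated, the module header also typically contains an input_output_alias entry.

```python
import jax
import jax.numpy as jnp

def update(x):
    # A single indexed update; whether XLA reuses a buffer is decided at compile time.
    return x.at[0].set(1.0)

x = jnp.zeros(1_000)

# Lower and compile ahead of time, then get the optimized HLO as text.
compiled = jax.jit(update).lower(x).compile()
hlo = compiled.as_text()

# Heuristic: explicit `copy` instructions in the optimized HLO often indicate
# that data is being copied rather than updated in place.
copy_lines = [line.strip() for line in hlo.splitlines() if " copy(" in line]
print(f"{len(copy_lines)} copy instruction(s) found")
for line in copy_lines:
    print(line)
```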
I'm also having this issue, and it seems that the in-place update may only happen if the array is created within jit. I tried the following simple example and benchmarked the speeds:
import jax
import jax.numpy as jnp

N = 1_000
jaxa = jnp.zeros(10_000_000)

@jax.jit
def bench_jax_scan(a):
    def f(i, x):
        x = x.at[i].set(i)
        return x
    a = jax.lax.fori_loop(0, N, f, a)
    return a

@jax.jit
def bench_jax_scan_2():
    a = jnp.zeros(10_000_000)
    def f(i, x):
        x = x.at[i].set(i)
        return x
    a = jax.lax.fori_loop(0, N, f, a)
    return a
%time bench_jax_scan(jaxa)
%timeit b=bench_jax_scan(jaxa)
CPU times: user 43.7 ms, sys: 1.8 ms, total: 45.5 ms
Wall time: 43.5 ms
4.54 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%time bench_jax_scan_2()
%timeit b=bench_jax_scan_2()
CPU times: user 44.7 ms, sys: 2.18 ms, total: 46.9 ms
Wall time: 39.1 ms
1.34 ms ± 61.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
For reference, creating a = jnp.zeros(10_000_000) inside jit takes about 1.28 ms. That means the in-place updates on an array created inside jit cost about 1.34 - 1.28 = 0.06 ms = 60 µs, which is roughly the speed of an equivalent NumPy in-place operation.
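One caveat when reading these numbers is that JAX dispatches work asynchronously. The %timeit lines above do not wait for the result, so they may partly measure dispatch rather than execution. The sketch below shows a timing loop that blocks on every result. It uses a smaller array so it runs quickly, and fill_first_n and time_call are hypothetical names.

```python
import time

import jax
import jax.numpy as jnp

N = 1_000

@jax.jit
def fill_first_n(a):
    # Same loop as the benchmark above: write i into position i for i in [0, N).
    def body(i, x):
        return x.at[i].set(i)
    return jax.lax.fori_loop(0, N, body, a)

def time_call(fn, *args, repeats=10):
    # Warm-up call triggers compilation; block so it is excluded from timing.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Block on each result so we time execution, not just enqueueing work.
        out = jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats, out

a = jnp.zeros(100_000)
seconds, result = time_call(fill_first_n, a)
print(f"{seconds * 1e3:.3f} ms per call")
```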
it seems that the in-place update may only happen if the array is created within jit
This is correct. Array updates can only be in-place if performed within a JIT-decorated function.
So is it not possible for an array created outside of a jit decorated function to then be passed into a jit decorated function and be updated in place?
No - JAX arrays are immutable by design. This means that if you create a concrete buffer-backed DeviceArray object, there is no way that you can then modify its contents in-place.
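Here is a small illustration of what immutability means at the Python level. The sketch shows that .at[].set() leaves the original array untouched and that item assignment raises a TypeError. Inside jit, XLA is still free to reuse buffers it owns. The example only concerns the user-visible array object.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# .at[].set() does not mutate x; it returns a new array with the update applied.
y = x.at[1].set(5.0)

print(x.tolist())  # [0.0, 0.0, 0.0]  (original unchanged)
print(y.tolist())  # [0.0, 5.0, 0.0]

# Direct item assignment is rejected because JAX arrays are immutable.
assignment_failed = False
try:
    x[1] = 5.0
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)
```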
I see, thanks for the clarification!
It should be possible to perform an in-place update of an input to a jit-compiled function without violating the design if we donate the argument using the donate_argnums argument of jax.jit. However, I was able to verify that JAX currently doesn't update the argument in-place even if it's donated. Is this something that could potentially be implemented in JAX in the future?
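The sketch below shows argument donation with donate_argnums. Calling jax.Array.is_deleted() on the input after the call shows whether the backend actually consumed the donated buffer. On backends without donation support, JAX emits a warning instead and the input stays alive. The function name update is illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def update(state):
    # With donation, XLA may write the result into state's buffer instead of a fresh one.
    return state.at[10].set(1.0)

state = jnp.zeros((100_000,))
new_state = update(state)

# True if the backend consumed (donated) the input buffer; False if donation was
# not applied, in which case JAX warns that the donated buffer was not usable.
donated = state.is_deleted()
print("input buffer donated:", donated)
print("updated value:", float(new_state[10]))
```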
This is an important matter for me because implementing this feature will open some interesting possibilities. One example is a replay buffer library fully implemented in JAX that I'm currently working on (https://github.com/hr0nix/dejax). The intended use case is to have a jit-compiled train step function, which updates both the train and the replay buffer state:
train_state, replay_buffer_state = jax.jit(train_step, donate_argnums=(1,))(
train_state, replay_buffer_state, trajectory_batch
)
However if the replay buffer is large and is not updated in place, all the performance benefits of having a replay buffer inside jit-compiled code vanish.
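To make the use case concrete, here is a hypothetical ring-buffer insert that donates its state, in the spirit of the train step above. BufferState, init_buffer and push are illustrative names, not the dejax API. The point is that, on backends that support donation, the large storage array can be overwritten in place instead of being copied on every call.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp

class BufferState(NamedTuple):
    # Hypothetical ring-buffer state (not the dejax API).
    storage: jax.Array     # shape (capacity, item_dim)
    insert_pos: jax.Array  # scalar int32: next slot to overwrite

def init_buffer(capacity, item_dim):
    return BufferState(jnp.zeros((capacity, item_dim)), jnp.array(0, dtype=jnp.int32))

@partial(jax.jit, donate_argnums=(0,))
def push(state, item):
    # Overwrite one row; donating `state` lets XLA reuse the storage buffer.
    storage = state.storage.at[state.insert_pos].set(item)
    capacity = state.storage.shape[0]
    return BufferState(storage, (state.insert_pos + 1) % capacity)

state = init_buffer(capacity=4, item_dim=3)
for n in range(6):
    # Always rebind `state`: the donated input must not be used again.
    state = push(state, jnp.full((3,), float(n)))

print(state.storage[:, 0].tolist())  # [4.0, 5.0, 2.0, 3.0]
print(int(state.insert_pos))         # 2
```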
I was able to verify that JAX currently doesn't update the argument in-place even if it's donated
Can you give more details about how you verified this?
Sure! I've used the same benchmarking-based approach as @StoneT2000. Here's the code:
import time
from functools import partial

import jax
import jax.numpy as jnp
import pytest

def benchmark_update_func(init_func, update_func, num_iters):
    # To pre-compile the update_func in case it uses jit
    update_func_input = init_func()
    update_func(update_func_input)
    elapsed_time = 0.0
    for _ in range(num_iters):
        update_func_input = init_func()
        jax.block_until_ready(update_func_input)
        start_time = time.time()
        output = update_func(update_func_input)
        jax.block_until_ready(output)
        elapsed_time += time.time() - start_time
    time_per_iter = elapsed_time / num_iters
    print(f'{time_per_iter * 1000.0:.2f}ms per iteration')

def make_item(n):
    return jnp.full((64,), n)

def make_item_batch(n, batch_size):
    return jnp.full((batch_size, 64,), n)

@pytest.mark.parametrize('buffer_size', [1_000, 10_000, 100_000, 1_000_000])
def test_modify_inplace_donate_performance(buffer_size):
    def init_func():
        return make_item_batch(0, buffer_size)

    @partial(jax.jit, donate_argnums=(0,))
    def update_func(state):
        item = make_item(0)
        return state.at[10].set(item)

    benchmark_update_func(init_func, update_func, num_iters=100)
The output of this test (on CPU, I should say) is:
PASSED [ 25%]0.01ms per iteration
PASSED [ 50%]0.26ms per iteration
PASSED [ 75%]2.24ms per iteration
PASSED [100%]39.61ms per iteration
So there is an approximately linear dependency between the buffer size and update costs.
Buffer donation is not implemented on CPU (see https://jax.readthedocs.io/en/latest/faq.html#buffer-donation); when you execute this on CPU you should see warnings that look like
UserWarning: Some donated buffers were not usable: ShapedArray(int32[10]).
Donation is not implemented for cpu.
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
_module_unique_id = itertools.count()
Two instantiated jax arrays (such as the input and output of your test function) cannot share the same memory unless their buffers are donated. So I don't think this test tells us anything about whether updates are done in-place within JIT.
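Test runners may collect warnings without showing them prominently. One hedged way to make a rejected donation visible is to record warnings around the call and look for the donation message. Matching on the text "donated buffers" is an assumption based on the warning quoted above. With pytest, running with -W error::UserWarning would instead turn such warnings into test failures.

```python
from functools import partial
import warnings

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def update(state):
    return state.at[10].set(1.0)

with warnings.catch_warnings(record=True) as caught:
    # Record every warning so nothing is silently swallowed.
    warnings.simplefilter("always")
    out = jax.block_until_ready(update(jnp.zeros((1_000,))))

donation_warnings = [str(w.message) for w in caught if "donated buffers" in str(w.message)]
if donation_warnings:
    print("Donation was not applied:", donation_warnings[0])
else:
    print("No donation warning recorded.")
```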
When I run your test on a Colab GPU runtime, I see the following:
%%writefile bench.py
import jax
import jax.numpy as jnp
import pytest
from functools import partial
import time

def benchmark_update_func(init_func, update_func, num_iters):
    # To pre-compile the update_func in case it uses jit
    update_func_input = init_func()
    update_func(update_func_input)
    elapsed_time = 0.0
    for _ in range(num_iters):
        update_func_input = init_func()
        jax.block_until_ready(update_func_input)
        start_time = time.time()
        output = update_func(update_func_input)
        jax.block_until_ready(output)
        elapsed_time += time.time() - start_time
    time_per_iter = elapsed_time / num_iters
    print(f'{time_per_iter * 1000.0:.2f}ms per iteration')

def make_item(n):
    return jnp.full((64,), n)

def make_item_batch(n, batch_size):
    return jnp.full((batch_size, 64,), n)

@pytest.mark.parametrize('buffer_size', [1_000, 10_000, 100_000, 1_000_000])
def test_modify_inplace_donate_performance(buffer_size):
    def init_func():
        return make_item_batch(0, buffer_size)

    @partial(jax.jit, donate_argnums=(0,))
    def update_func(state):
        item = make_item(0)
        return state.at[10].set(item)

    benchmark_update_func(init_func, update_func, num_iters=100)
!python -m pytest --durations=5 bench.py
============================= test session starts ==============================
platform linux -- Python 3.7.13, pytest-3.6.4, py-1.11.0, pluggy-0.7.1
rootdir: /content, inifile:
plugins: typeguard-2.7.1
collected 4 items
bench.py .... [100%]
=========================== slowest 5 test durations ===========================
1.13s call bench.py::test_modify_inplace_donate_performance[1000]
0.31s call bench.py::test_modify_inplace_donate_performance[1000000]
0.18s call bench.py::test_modify_inplace_donate_performance[100000]
0.17s call bench.py::test_modify_inplace_donate_performance[10000]
0.00s setup bench.py::test_modify_inplace_donate_performance[1000]
=========================== 4 passed in 3.04 seconds ===========================
... which looks to me like it's consistent with updates being in-place when buffer donation is available. What do you think?
Buffer donation is not implemented on CPU (see https://jax.readthedocs.io/en/latest/faq.html#buffer-donation)
Wow, good catch, I did not know that. And it seems that pytest muted this useful warning.
I've just run this code on GPU and can confirm that updates with donation are indeed in-place.
I have a program in which XLA is treating a .at[].set() update out-of-place rather than in-place. At the moment it's actually not clear to me whether this is a bug in XLA*, or a fault in my program that is preventing the optimisation. What I'd really like is either:
AFAIK there's no way to do this in JAX at the moment. In my head I'm imagining something like an environment variable similar to the JAX_DEBUG_NANS one used to catch NaNs. (More generally, I would remark that ways to properly introspect/debug/understand the compiled XLA would be really nice.)
* I know of at least one example of this being the case for XLA:CPU (#8192), but I don't think I'm running into that particular bug here.