jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Feature request: disabling/detecting out-of-place `.at[].set()` updates. #9132

Open patrick-kidger opened 2 years ago

patrick-kidger commented 2 years ago

I have a program in which XLA is performing an .at[].set() update out-of-place rather than in-place. At the moment it's actually not clear to me whether this is a bug in XLA*, or a fault in my program that is preventing the optimisation.

What I'd really like is either a way to disable out-of-place updates (so that an error is raised when one occurs), or a way to detect when an update is compiled out-of-place.

AFAIK there's no way to do this in JAX at the moment. In my head I'm imagining something like an environment variable similar to the JAX_DEBUG_NANS one used to catch NaNs.

(More generally I would remark that ways to properly introspect/debug/understand the compiled XLA would be really nice.)

* I know of at least one example of this being the case for XLA:CPU (#8192), but I don't think I'm running into that particular bug here.
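One currently-available way to introspect what XLA actually compiled (a rough sketch using the ahead-of-time lowering API; this assumes a JAX version that provides jax.jit(f).lower(...).compile().as_text(), and it is not the flag requested above) is to dump the optimized HLO and check how the update is emitted:

import jax
import jax.numpy as jnp

def f(x):
    return x.at[0].set(1.0)

x = jnp.zeros(1_000_000)

# Print the post-optimization HLO; an out-of-place update generally shows up
# as an extra copy producing a fresh buffer, while an in-place update fuses
# into a dynamic-update-slice that reuses the input buffer.
print(jax.jit(f).lower(x).compile().as_text())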

shoyer commented 2 years ago

Have you seen the mode argument to .at[].set()? https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at

There is not (yet) a debugging option for raising errors, but at least you have a couple of customization options.
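For reference, a minimal sketch of those customization options from the linked docs (they control out-of-bounds index handling and let you give the compiler extra guarantees about the indices; they don't force an in-place update):

import jax.numpy as jnp

x = jnp.zeros(5)

# Drop out-of-bounds updates instead of clamping the index.
y = x.at[10].set(1.0, mode='drop')

# Promise the indices are unique and sorted, which can allow a cheaper scatter.
idx = jnp.array([1, 3])
z = x.at[idx].set(1.0, unique_indices=True, indices_are_sorted=True)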

patrick-kidger commented 2 years ago

Thanks for the link -- indeed I didn't know about them. They don't help with resolving this issue, but they're still good to know.

phc27x commented 2 years ago

+1 on more transparency as to whether it compiles to in-place or out-of-place. I suspect the massive runtime slowdown in my discussion question (https://github.com/google/jax/discussions/9116 ) might come from at[].set() triggering a copy, but i don't know how I can check or debug that at all.

StoneT2000 commented 2 years ago

I'm also having this issue, and it seems that the in-place update may only happen if the array is created within jit. I tried the following as a simple example and benchmarked the speeds:

import jax
import jax.numpy as jnp

N = 1_000
jaxa = jnp.zeros(10_000_000)

@jax.jit
def bench_jax_scan(a):
    # array created outside jit and passed in as an argument
    def f(i, x):
        x = x.at[i].set(i)
        return x
    a = jax.lax.fori_loop(0, N, f, a)
    return a

@jax.jit
def bench_jax_scan_2():
    # array created inside jit
    a = jnp.zeros(10_000_000)
    def f(i, x):
        x = x.at[i].set(i)
        return x
    a = jax.lax.fori_loop(0, N, f, a)
    return a
%time bench_jax_scan(jaxa)
%timeit b=bench_jax_scan(jaxa)
CPU times: user 43.7 ms, sys: 1.8 ms, total: 45.5 ms
Wall time: 43.5 ms
4.54 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%time bench_jax_scan_2()
%timeit b=bench_jax_scan_2()
CPU times: user 44.7 ms, sys: 2.18 ms, total: 46.9 ms
Wall time: 39.1 ms
1.34 ms ± 61.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

And for reference, creating a = jnp.zeros(10_000_000) inside jit takes about 1.28 ms, so the in-place updates on an array created inside jit account for roughly 1.34 - 1.28 = 0.06 ms = 60 µs, which is about the speed of the equivalent NumPy in-place op.

shoyer commented 2 years ago

it seems that the in-place may only happen if the array is created within jit

This is correct. Array updates can only be in-place if performed within a JIT-decorated function.

StoneT2000 commented 2 years ago

So is it not possible for an array created outside of a jit decorated function to then be passed into a jit decorated function and be updated in place?

jakevdp commented 2 years ago

So is it not possible for an array created outside of a jit decorated function to then be passed into a jit decorated function and be updated in place?

No - JAX arrays are immutable by design. This means that if you create a concrete buffer-backed DeviceArray object, there is no way that you can then modify its contents in-place.
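A minimal illustration of that immutability outside of jit:

import jax.numpy as jnp

x = jnp.zeros(3)
y = x.at[0].set(1.0)  # returns a new array; x itself is never modified

print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]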

StoneT2000 commented 2 years ago

I see, thanks for the clarification!

hr0nix commented 2 years ago

No - JAX arrays are immutable by design. This means that if you create a concrete buffer-backed DeviceArray object, there is no way that you can then modify its contents in-place.

It should be possible to perform an in-place update of an input to a jit-compiled function without violating the design if we donate the argument using the donate_argnums argument of jax.jit. However, I was able to verify that JAX currently doesn't update the argument in-place even if it's donated. Is this something that can potentially be implemented in JAX in the future?

This is an important matter for me because implementing this feature will open some interesting possibilities. One example is a replay buffer library fully implemented in JAX that I'm currently working on (https://github.com/hr0nix/dejax). The intended use case is to have a jit-compiled train step function, which updates both the train state and the replay buffer state:

train_state, replay_buffer_state = jax.jit(train_step, donate_argnums=(1,))(
    train_state, replay_buffer_state, trajectory_batch
)

However, if the replay buffer is large and is not updated in place, all the performance benefits of having the replay buffer inside jit-compiled code vanish.

jakevdp commented 2 years ago

I was able to verify that JAX currently doesn't update the argument in-place even if it's donated

Can you give more details about how you verified this?

hr0nix commented 2 years ago

Can you give more details about how you verified this?

Sure! I've used the same benchmarking-based approach as @StoneT2000. Here's the code:

import time
from functools import partial

import jax
import jax.numpy as jnp
import pytest

def benchmark_update_func(init_func, update_func, num_iters):
    # To pre-compile the update_func in case it uses jit
    update_func_input = init_func()
    update_func(update_func_input)

    elapsed_time = 0.0
    for _ in range(num_iters):
        update_func_input = init_func()
        jax.block_until_ready(update_func_input)

        start_time = time.time()
        output = update_func(update_func_input)
        jax.block_until_ready(output)
        elapsed_time += time.time() - start_time

    time_per_iter = elapsed_time / num_iters
    print(f'{time_per_iter * 1000.0:.2f}ms per iteration')

def make_item(n):
    return jnp.full((64,), n)

def make_item_batch(n, batch_size):
    return jnp.full((batch_size, 64,), n)

@pytest.mark.parametrize('buffer_size', [1_000, 10_000, 100_000, 1_000_000])
def test_modify_inplace_donate_performance(buffer_size):
    def init_func():
        return make_item_batch(0, buffer_size)

    @partial(jax.jit, donate_argnums=(0,))
    def update_func(state):
        item = make_item(0)
        return state.at[10].set(item)

    benchmark_update_func(init_func, update_func, num_iters=100)

The output of this test (on CPU, I should say) is:

PASSED [ 25%]0.01ms per iteration
PASSED [ 50%]0.26ms per iteration
PASSED [ 75%]2.24ms per iteration
PASSED [100%]39.61ms per iteration

So the update cost grows approximately linearly with the buffer size.

jakevdp commented 2 years ago

Buffer donation is not implemented on CPU (see https://jax.readthedocs.io/en/latest/faq.html#buffer-donation); when you execute this on CPU you should see warnings that look like

UserWarning: Some donated buffers were not usable: ShapedArray(int32[10]).
Donation is not implemented for cpu.
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  _module_unique_id = itertools.count()

Two instantiated jax arrays (such as the input and output of your test function) cannot share the same memory unless their buffers are donated. So I don't think this test tells us anything about whether updates are done in-place within JIT.

jakevdp commented 2 years ago

When I run your test on a Colab GPU runtime, I see the following:

%%writefile bench.py
import jax
import jax.numpy as jnp
import pytest
from functools import partial
import time

def benchmark_update_func(init_func, update_func, num_iters):
    # To pre-compile the update_func in case it uses jit
    update_func_input = init_func()
    update_func(update_func_input)

    elapsed_time = 0.0
    for _ in range(num_iters):
        update_func_input = init_func()
        jax.block_until_ready(update_func_input)

        start_time = time.time()
        output = update_func(update_func_input)
        jax.block_until_ready(output)
        elapsed_time += time.time() - start_time

    time_per_iter = elapsed_time / num_iters
    print(f'{time_per_iter * 1000.0:.2f}ms per iteration')

def make_item(n):
    return jnp.full((64,), n)

def make_item_batch(n, batch_size):
    return jnp.full((batch_size, 64,), n)

@pytest.mark.parametrize('buffer_size', [1_000, 10_000, 100_000, 1_000_000])
def test_modify_inplace_donate_performance(buffer_size):
    def init_func():
        return make_item_batch(0, buffer_size)

    @partial(jax.jit, donate_argnums=(0,))
    def update_func(state):
        item = make_item(0)
        return state.at[10].set(item)

    benchmark_update_func(init_func, update_func, num_iters=100)
!python -m pytest --durations=5 bench.py
============================= test session starts ==============================
platform linux -- Python 3.7.13, pytest-3.6.4, py-1.11.0, pluggy-0.7.1
rootdir: /content, inifile:
plugins: typeguard-2.7.1
collected 4 items                                                              

bench.py ....                                                            [100%]

=========================== slowest 5 test durations ===========================
1.13s call     bench.py::test_modify_inplace_donate_performance[1000]
0.31s call     bench.py::test_modify_inplace_donate_performance[1000000]
0.18s call     bench.py::test_modify_inplace_donate_performance[100000]
0.17s call     bench.py::test_modify_inplace_donate_performance[10000]
0.00s setup    bench.py::test_modify_inplace_donate_performance[1000]
=========================== 4 passed in 3.04 seconds ===========================

... which looks to me like it's consistent with updates being in-place when buffer donation is available. What do you think?

hr0nix commented 2 years ago

Buffer donation is not implemented on CPU (see https://jax.readthedocs.io/en/latest/faq.html#buffer-donation)

Wow, good catch, I did not know that. And it seems that pytest muted this useful warning.

I've just run this code on GPU and can confirm that updates with donation are indeed in-place.
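As a side note on the muted warning: one way to make it hard to miss (a sketch, assuming the default pytest warning handling) is to escalate UserWarning to an error for the run:

python -m pytest -W error::UserWarning --durations=5 bench.py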