jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Grad is very slow in combination with slice update + scan #17640

Open karunrao97 opened 1 year ago

karunrao97 commented 1 year ago

Description

Hi,

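I am running a scan over several thousand steps with a 2D array as the carry, where each step consists of (among other operations) inserting a value at some position in a row, shifting the remaining elements of that row one position to the right, and deleting the element at some position in a row, shifting the remaining elements of that row one position to the left.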

I know the maximum number of columns (let's call it N_COLS) ahead of time, so I can implement the above operations by allocating a 2D array of width 2 * N_COLS and always shifting a slice of size N_COLS left (for a delete) or right (for an insert). This works well for the forward pass, but gradient computations are very slow (roughly 100x slower than the forward pass). I think the gradient computation is triggering a copy of the entire carry instead of updating it in-place. I came up with a fairly strange workaround (called "combo2" in the MWE below) that seems to update in-place and achieve reasonable performance, but only when using JVP. I would like to know whether there is a better way to implement this, and whether the same can be achieved with VJP.
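To make the buffer layout concrete, here is a toy sketch of the same shift-based insert and delete on a single row, with a hypothetical N_COLS of 3. It mirrors insert_fn (the "slice2" mode) and delete_fn in the MWE, but it is an illustration, not the issue's code. Because the buffer is twice as wide as the live region, the shifted window stays in bounds for any column below N_COLS, so lax.dynamic_slice never has to clamp its start index.

```python
import jax.numpy as jnp
from jax import lax

N_COLS = 3  # toy stand-in for the issue's N_COLS

# A single-row buffer of width 2 * N_COLS; only the first few entries are "live".
state = jnp.array([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0]])


def insert(state, r, c, value):
    # Shift the N_COLS-wide window that starts at column c one step to the right,
    # then write the new value into the slot that was freed at column c.
    window = lax.dynamic_slice(state, (r, c), (1, N_COLS))
    state = lax.dynamic_update_slice(state, window, (r, c + 1))
    return state.at[r, c].set(value)


def delete(state, r, c):
    # Shift the N_COLS-wide window that starts at column c + 1 one step to the left,
    # overwriting the element at column c.
    window = lax.dynamic_slice(state, (r, c + 1), (1, N_COLS))
    return lax.dynamic_update_slice(state, window, (r, c))


inserted = insert(state, 0, 1, 9.0)  # row becomes [1, 9, 2, 3, 0, 0]
restored = delete(inserted, 0, 1)    # row is back to [1, 2, 3, 0, 0, 0]
print(inserted)
print(restored)
```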

Here is a minimal working example:

import timeit

from functools import partial

import jax
import numpy as np

from jax import lax
from jax import numpy as jnp

N_STEPS = 1_000
N_ROWS, N_COLS = 100, 500
N_ROW_REPEATS = 10

@partial(jax.jit, static_argnames="mode", donate_argnames="state")
def insert_fn(state, row, col, value, mode):
    if mode == "slice1":
        update = jnp.concatenate([jnp.array([[value]]), lax.dynamic_slice(state, start_indices=(row, col), slice_sizes=(1, N_COLS))], axis=-1)
        return lax.dynamic_update_slice(state, update=update, start_indices=(row, col))
    elif mode == "slice2":
        update = lax.dynamic_slice(state, start_indices=(row, col), slice_sizes=(1, N_COLS))
        state = lax.dynamic_update_slice(state, update=update, start_indices=(row, col + 1))
        return state.at[row, col].set(value)
    elif mode == "scatter":
        indices = col + np.arange(N_COLS + 1, dtype=np.int32)
        values = jnp.r_[value, lax.dynamic_slice(state, start_indices=(row, col), slice_sizes=(1, N_COLS)).ravel()]
        return state.at[row, indices].set(values, indices_are_sorted=True, unique_indices=True)
    elif mode == "combo1":
        update = lax.dynamic_slice(state, start_indices=(row, col), slice_sizes=(1, N_COLS))
        state = lax.dynamic_update_slice(state, update=update, start_indices=(row, col + 1))
        indices = col + np.arange(1, dtype=np.int32)
        values = jnp.array([value])
        return state.at[row, indices].set(values, indices_are_sorted=True, unique_indices=True)
    elif mode == "combo2":
        update = lax.dynamic_slice(state, start_indices=(row, col + 1), slice_sizes=(1, N_COLS - 1))
        state = lax.dynamic_update_slice(state, update=update, start_indices=(row, col + 2))
        indices = col + np.arange(2, dtype=np.int32)
        values = jnp.array([value, state[row, col]])
        return state.at[row, indices].set(values, indices_are_sorted=True, unique_indices=True)
    else:
        raise ValueError(f"unknown {mode=}")

@partial(jax.jit, donate_argnames="state")
def delete_fn(state, row, col):
    update = lax.dynamic_slice(state, start_indices=(row, col + 1), slice_sizes=(1, N_COLS))
    return lax.dynamic_update_slice(state, update=update, start_indices=(row, col))

@partial(jax.jit, static_argnames="mode", donate_argnames="state")
def update_fn(state, inputs, mode):
    (insert_row, delete_row), (insert_col, delete_col), value = inputs
    state = insert_fn(state, insert_row, insert_col, value, mode)
    state = delete_fn(state, delete_row, delete_col)
    return state, None

@partial(jax.jit, static_argnames="mode")
def loss_fn(w, mode, init_state, target_state, row_indices, col_indices, insert_values):
    state, _ = lax.scan(partial(update_fn, mode=mode), init_state, (row_indices, col_indices, insert_values**w))
    return ((target_state - state[:N_ROWS, :N_COLS]) ** 2).mean()

if __name__ == "__main__":
    np.random.seed(42)
    init_state = jnp.zeros(shape=(N_ROWS * N_ROW_REPEATS, N_COLS * 2), dtype=jnp.float32)
    loss_kwargs = dict(
        target_state=jnp.asarray(np.random.random(size=(N_ROWS, N_COLS)).astype(np.float32)),
        row_indices=jnp.asarray(np.random.randint(low=0, high=N_ROWS, size=(N_STEPS, 2), dtype=np.int32)),
        col_indices=jnp.asarray(np.random.randint(low=0, high=N_COLS, size=(N_STEPS, 2), dtype=np.int32)),
        insert_values=jnp.asarray(np.random.random(size=N_STEPS).astype(np.float32)),
    )

    modes = "slice1", "slice2", "scatter", "combo1", "combo2"
    loss_fns = {k: jax.jit(partial(loss_fn, mode=k, init_state=init_state.copy(), **loss_kwargs)) for k in modes}
    vjp_fns = {k: jax.jit(jax.grad(v)) for k, v in loss_fns.items()}
    jvp_fns = {k: jax.jit(partial(jax.jvp, fun=v, tangents=[jnp.float32(1)])) for k, v in loss_fns.items()}

    # Run all functions once to JIT-compile
    g = None
    w = jnp.float32(2)
    for k in loss_fns.keys():
        loss_fns[k](w).block_until_ready()
        g1 = vjp_fns[k](w).block_until_ready()
        g2 = jvp_fns[k](primals=[w])[1].block_until_ready()
        if g is None:
            g = g1
            print(f"{g=}")
        assert np.isclose(g, g1)
        assert np.isclose(g, g2)

    # Run benchmarks
    n = 10
    for k in loss_fns.keys():
        print()
        print(k)
        print(f"  loss: {timeit.timeit(lambda: loss_fns[k](w).block_until_ready(), number=n) / n : .5f}")
        print(f"   vjp: {timeit.timeit(lambda: vjp_fns[k](w).block_until_ready(), number=n) / n : .5f}")
        print(f"   jvp: {timeit.timeit(lambda: jvp_fns[k](primals=[w])[1].block_until_ready(), number=n) / n : .5f}")

Here is the output with N_ROW_REPEATS = 1:

g=Array(0.00051642, dtype=float32)

slice1
  loss:  0.00046
   vjp:  0.05405
   jvp:  0.05525

slice2
  loss:  0.00042
   vjp:  0.08886
   jvp:  0.05200

scatter
  loss:  0.00592
   vjp:  0.06749
   jvp:  0.01205

combo1
  loss:  0.00041
   vjp:  0.08841
   jvp:  0.05208

combo2
  loss:  0.00042
   vjp:  0.08825
   jvp:  0.00082

And here is the output with N_ROW_REPEATS = 10:

g=Array(0.00051642, dtype=float32)

slice1
  loss:  0.00060
   vjp:  1.08682
   jvp:  1.12666

slice2
  loss:  0.00054
   vjp:  1.62256
   jvp:  1.20784

scatter
  loss:  0.00603
   vjp:  1.20957
   jvp:  0.01141

combo1
  loss:  0.00048
   vjp:  2.23781
   jvp:  1.43395

combo2
  loss:  0.00052
   vjp:  1.87856
   jvp:  0.00160

I would have expected one of the two "slice" modes to be the most efficient, and this does seem to be the case for the forward pass. However, given the ~100x slower gradient computation (with both VJP and JVP), I tried some other methods. The "scatter" mode uses advanced indexing with the .at[...].set(...) notation, which is slower on the forward pass (as expected), but its gradient computation is much faster than the "slice" modes when using JVP. The faster gradient computation relies on unique_indices=True; without it, it is again very slow. It seems we might be benefiting from this code path: https://github.com/google/jax/blob/bcc545a69232e983ae31b0395f4972979f2789c0/jax/_src/lax/slicing.py#L2297

The "combo" modes use the above insight to combine a small scatter with a dynamic slice update, but this only seems to work well when the scatter update contains more than one element ("combo2"). I guess "combo1" reduces to a slice update instead of a scatter.

As the N_ROW_REPEATS = 10 case shows, when we enlarge the carry, all modes scale badly under VJP (roughly 20x slower for a carry with 10x as many rows), and all modes except "scatter" and "combo2" also scale badly under JVP, even though we only ever slice into a fraction of the rows of the carry (and the slices are identical to those with N_ROW_REPEATS = 1). This suggests that these modes are triggering a copy of the entire carry instead of updating it in-place as intended.
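One way to test the copy hypothesis directly is to look for copy instructions in the optimized HLO. The sketch below is a heuristic, not a profiler: count_copies is a hypothetical helper, body and scan_loss are toy stand-ins for the issue's update_fn and loss_fn, and the counts it reports will depend on the backend and XLA version. It only uses jax.jit(...).lower(...).compile().as_text(), which returns the compiled module as text.

```python
import jax
import jax.numpy as jnp
from jax import lax


def count_copies(fn, *args):
    """Heuristic: count `copy` instructions in the optimized HLO of jit(fn)(*args)."""
    hlo_text = jax.jit(fn).lower(*args).compile().as_text()
    return sum(1 for line in hlo_text.splitlines() if " copy(" in line)


def body(carry, idx):
    # Toy stand-in for update_fn: read a window of one row and write it back scaled.
    window = lax.dynamic_slice(carry, (idx, 0), (1, 4))
    return lax.dynamic_update_slice(carry, window * 2.0, (idx, 0)), None


def scan_loss(carry, idxs):
    out, _ = lax.scan(body, carry, idxs)
    return out.sum()


carry = jnp.ones((8, 8), dtype=jnp.float32)
idxs = jnp.arange(8, dtype=jnp.int32)
n_fwd = count_copies(scan_loss, carry, idxs)
n_grad = count_copies(jax.grad(scan_loss), carry, idxs)
print(f"copies in forward: {n_fwd}, copies in grad: {n_grad}")
```

Comparing the count for a given mode against the same program with a larger carry (or against a known-good variant such as combo2 under JVP) is a cheap way to check whether a rewrite removed the full-buffer copy.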

Some questions:

I'm new to JAX, so apologies if I'm missing something obvious. Thank you for your help!

What jax/jaxlib version are you using?

jax v0.4.14, jaxlib v0.4.14

Which accelerator(s) are you using?

CPU

Additional system info

python v3.9.9, linux


patrick-kidger commented 1 year ago

If I'm reading this correctly (your example is quite large!), I think this is another example of #10197. This is due to a bug in the XLA compiler, which indeed triggers spurious copies when doing a grad(scan(slice + in-place-update)). Can you check if this workaround works for you?

karunrao97 commented 1 year ago

Thank you for your reply!

I tried replacing lax.scan with equinox.internal.scan as follows:

import equinox.internal as eqxi
...

state, _ = eqxi.scan(
    partial(update_fn, mode=mode),
    init_state,
    (row_indices, col_indices, insert_values**w),
    buffers=lambda x: x,
    kind="checkpointed",
    checkpoints="all",
)

But it seems that eqxi.scan is incompatible with lax.dynamic_slice, and I get the following stack trace:

    update = lax.dynamic_slice(state, start_indices=(row, col + 1), slice_sizes=(1, N_COLS - 1))
AttributeError: '_Buffer' object has no attribute 'ndim'

I cannot use static slices because that blows up compile time to several hours in my case. I also need to do multiple operations including matmuls on slices of the buffer, not only .at[...].set(...), which is only here for the MWE.

Is there a key insight I can incorporate from eqxi.scan into my code? I followed the eqxi.scan code to https://github.com/patrick-kidger/equinox/blob/76a13bd100a8127c7e5176f2dee9d36c837b46ec/equinox/internal/_loop/common.py#L220, at which point I'm a bit lost and probably need to read up more on JAX internals. I don't need a while loop, and I don't think I need to prevent cond turning into select on vmap. Is there something simpler that might work for my use case?

Thanks again!

patrick-kidger commented 1 year ago

Can you try doing lax.dynamic_slice(state._array, ...)? (Here, state is the buffer object.) That's private API right now, but we can add a public API for dynamic slicing if it'd help. Though as noted, the Equinox workaround for this XLA bug only works if you do precisely one read, whether that is via state[i] or dynamic_slice(state._array, ...) -- if you have multiple reads in your actual use case, then you'll need to rewrite things to do them all in one go.
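As a sketch of what doing the reads "in one go" can look like, two reads from the same row (a scalar plus the slice after it, as in the combo2 mode above) can be merged into a single lax.dynamic_slice and then split with static indexing. The helper names and toy sizes are hypothetical, and the example uses a plain array rather than an Equinox buffer; it only shows that the rewrite is value-preserving, and whether it avoids the copy in a given program still has to be checked.

```python
import jax.numpy as jnp
from jax import lax

N_COLS = 4  # toy size


def two_reads(state, row, col):
    # Two separate reads from the same buffer.
    head = state[row, col]
    tail = lax.dynamic_slice(state, (row, col + 1), (1, N_COLS - 1))
    return head, tail


def one_read(state, row, col):
    # One dynamic_slice covering both regions; the split uses static indices only.
    window = lax.dynamic_slice(state, (row, col), (1, N_COLS))
    return window[0, 0], window[:, 1:]


state = jnp.arange(24.0).reshape(3, 8)
h1, t1 = two_reads(state, 1, 2)
h2, t2 = one_read(state, 1, 2)
print(h1, t1, h2, t2)
```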

As for what eqxi.scan is doing under the hood -- it actually combines several fixes for various JAX/XLA issues! (You've spotted the vmap(cond) one, which is another one.)

To explain the important part of what's going on here: when you have the program

y = xs[i]
xs = xs.at[j].set(y)

then on the backward pass this becomes

grad_y = grad_xs[j]
grad_xs = grad_xs.at[j].set(0)
grad_xs = grad_xs.at[i].add(grad_y)

and unfortunately, XLA thinks that the .at[].set(0) overwrites the memory in which grad_y is held, so it erroneously makes a copy of all of grad_xs. This bug triggers whenever you have a scan containing >=1 read and >=2 writes, all from the same buffer.
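The backward pass described above can be checked numerically against JAX's own transpose rule. In this sketch (toy values, hypothetical helper names) the read's cotangent is accumulated at index i with an add rather than a plain set; the two only coincide when i == j.

```python
import jax
import jax.numpy as jnp

i, j = 1, 3  # read index and write index (toy values)


def read_then_write(xs):
    # Forward pass from the example: y = xs[i]; xs = xs.at[j].set(y)
    y = xs[i]
    return xs.at[j].set(y)


def manual_backward(grad_xs):
    # Hand-written transpose of read_then_write.
    grad_y = grad_xs[j]                  # cotangent flowing into the written value
    grad_xs = grad_xs.at[j].set(0.0)     # the old xs[j] was overwritten
    grad_xs = grad_xs.at[i].add(grad_y)  # the read contributes additively at index i
    return grad_xs


xs = jnp.arange(5.0)
ct = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
_, vjp_fn = jax.vjp(read_then_write, xs)
(auto_grad,) = vjp_fn(ct)
manual_grad = manual_backward(ct)
print(auto_grad, manual_grad)  # both hold the values [1, 6, 3, 0, 5]
```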

As such, the key insight in the Equinox code is to create a custom primitive (the one you've linked to) to replace the xs.at[j].set(y) call, which when backpropagated skips the .at[].set(0). In doing so we get down to just a single write on the backward pass, and we no longer hit the bug. In terms of program correctness, this will be valid provided we never again read from grad_xs[j] on the backward pass, which will itself be valid provided we never write to xs[j] twice on the forward pass. Thus the requirement listed at the end of the workaround above.
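The "skip the zeroing" idea can be sketched with jax.custom_vjp instead of a custom primitive. This is not Equinox's implementation: set_skip_zero is a hypothetical helper, j is passed as a static Python int via nondiff_argnums, and the buffer starts as a constant so that leftover cotangents at written indices are simply discarded. The gradient matches a plain .at[j].set when each index is written at most once, and is wrong when index 0 is written twice, which is the requirement stated above.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def set_skip_zero(xs, j, y):
    # The forward pass is an ordinary xs.at[j].set(y).
    return xs.at[j].set(y)


def _set_skip_zero_fwd(xs, j, y):
    return xs.at[j].set(y), None  # no residuals needed


def _set_skip_zero_bwd(j, _residuals, ct):
    # The standard rule would return ct.at[j].set(0.0) for xs; here ct passes
    # through unchanged, so the backward pass performs no write at all.
    return ct, ct[j]


set_skip_zero.defvjp(_set_skip_zero_fwd, _set_skip_zero_bwd)


def plain_set(xs, j, y):
    return xs.at[j].set(y)


weights = jnp.array([1.0, 2.0, 3.0, 4.0])


def loss_once(theta, setter):
    xs = jnp.zeros(4)                # constant buffer: leftover cotangents are dropped
    xs = setter(xs, 0, theta)        # index 0 written once
    xs = setter(xs, 2, 3.0 * xs[0])  # index 2 written once
    return jnp.sum(weights * xs)


def loss_twice(theta, setter):
    xs = jnp.zeros(4)
    xs = setter(xs, 0, theta)
    xs = setter(xs, 0, 2.0 * theta)  # index 0 written a second time
    return jnp.sum(weights * xs)


theta = jnp.float32(1.0)
g_once_plain = jax.grad(lambda t: loss_once(t, plain_set))(theta)
g_once_skip = jax.grad(lambda t: loss_once(t, set_skip_zero))(theta)
g_twice_plain = jax.grad(lambda t: loss_twice(t, plain_set))(theta)
g_twice_skip = jax.grad(lambda t: loss_twice(t, set_skip_zero))(theta)
print(g_once_plain, g_once_skip)    # values 10 and 10: the two rules agree
print(g_twice_plain, g_twice_skip)  # values 2 and 3: skipping the zeroing is wrong
```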

I appreciate this is a fairly complicated problem to try and work around. :/

karunrao97 commented 1 year ago

Thank you very much for the detailed insights.

I tried using state._array, but it throws the following error:

AttributeError: DynamicJaxprTracer has no attribute _array

I inspected state, and it is indeed of type Traced<ShapedArray(float32[100,1000])>with<DynamicJaxprTrace(level=5/0)>, so it seems the buffer wrapping happens after this?

Nevertheless, reading through the equinox code that circumvents this issue was very helpful. I managed to write my own custom primitive for my insert operation using similar ideas, and I think I see the problem now. Do I understand correctly that the work-around using eqxi.scan buffers only works if each element of the buffer is written to only once throughout the entire scan? That is not the case for me -- I have unique index writes for each step of the scan, but across the scan there will be overlap.

I reduced the original MWE to something simpler below, which includes a working solution (fast_insert) even in the case where there are overlaps in write indices across steps. However, this doesn't work for my actual use case, which includes more writes per step -- still unique per step, but I cannot do them in a single slice update. I believe the bad_insert function does something similar to the buffers in eqxi.scan, and it works for a small number of N_STEPS while there are no repeated rows, but as soon as a row is repeated, it produces incorrect gradients. Both slow_insert and fast_insert work correctly, though the latter seems to be able to avoid copying the carry. I don't quite understand why slow_insert cannot update the carry in-place. This looks like a bug?

Note that while I'm using custom_vjp below, implementing a primitive (similar to how it's done in equinox) yields the same results, so I omitted that for the sake of brevity.

import timeit

from functools import partial

import jax
import numpy as np

from jax import lax
from jax import numpy as jnp

N_STEPS = 1_000
N_ROWS, N_COLS = 100, 500

def simple_insert(carry, inputs):
    row, col, value = inputs
    update = jnp.concatenate([jnp.array([[value]]), lax.dynamic_slice(carry, start_indices=(row, col), slice_sizes=(1, N_COLS))], axis=-1)
    return lax.dynamic_update_slice(carry, update=update, start_indices=(row, col))

@jax.custom_vjp
def bad_insert(carry, inputs):
    return simple_insert(carry, inputs)

def _bad_insert_fwd(carry, inputs):
    row, col, _ = inputs
    return bad_insert(carry, inputs), (row, col)

def _bad_insert_bwd(residuals, ct_out):
    row, col = residuals
    return ct_out, (0, 0, ct_out[row, col])  # Using 0 because None throws "ValueError: safe_zip() argument 2 is shorter than argument 1" in _flatten_bwd

@jax.custom_vjp
def slow_insert(carry, inputs):
    return simple_insert(carry, inputs)

def _slow_insert_fwd(carry, inputs):
    row, col, _ = inputs
    return slow_insert(carry, inputs), (row, col)

def _slow_insert_bwd(residuals, ct_out):
    row, col = residuals
    ct_value = ct_out[row, col].copy()
    update = lax.dynamic_slice(ct_out, start_indices=(row, col + 1), slice_sizes=(1, N_COLS))
    ct_out = lax.dynamic_update_slice(ct_out, update=update, start_indices=(row, col))
    ct_out = ct_out.at[row, col + N_COLS].set(0.0)
    return ct_out, (0, 0, ct_value)

@jax.custom_vjp
def fast_insert(carry, inputs):
    return simple_insert(carry, inputs)

def _fast_insert_fwd(carry, inputs):
    row, col, _ = inputs
    return fast_insert(carry, inputs), (row, col)

def _fast_insert_bwd(residuals, ct_out):
    row, col = residuals
    ct_out = ct_out.at[row, -1].set(ct_out[row, col])
    update = lax.dynamic_slice(ct_out, start_indices=(row, col + 1), slice_sizes=(1, N_COLS))
    ct_out = lax.dynamic_update_slice(ct_out, update=update, start_indices=(row, col))
    ct_out = ct_out.at[row, col + N_COLS].set(0.0)
    return ct_out, (0, 0, ct_out[row, -1])

bad_insert.defvjp(_bad_insert_fwd, _bad_insert_bwd)
slow_insert.defvjp(_slow_insert_fwd, _slow_insert_bwd)
fast_insert.defvjp(_fast_insert_fwd, _fast_insert_bwd)

@partial(jax.jit, static_argnames="fn")
def wrap(fn, *args):
    return fn(*args), None

@partial(jax.jit, static_argnames="fn")
def loss(w, fn, init_state, target_state, rows, cols, values):
    state, _ = lax.scan(partial(wrap, fn), init_state, (rows, cols, values**w))
    return ((target_state - state[:, :N_COLS]) ** 2).mean()

if __name__ == "__main__":
    np.random.seed(42)
    init_state = jnp.zeros(shape=(N_ROWS, N_COLS * 2 + 1), dtype=jnp.float32)
    loss_kwargs = dict(
        target_state=jnp.asarray(np.random.random(size=(N_ROWS, N_COLS)).astype(np.float32)),
        rows=jnp.asarray(np.random.randint(low=0, high=N_ROWS, size=N_STEPS, dtype=np.int32)),
        cols=jnp.asarray(np.random.randint(low=0, high=N_COLS, size=N_STEPS, dtype=np.int32)),
        values=jnp.asarray(np.random.random(size=N_STEPS).astype(np.float32)),
    )

    fns = simple_insert, bad_insert, slow_insert, fast_insert
    loss_fns = {fn.__name__: jax.jit(partial(loss, fn=fn, init_state=init_state.copy(), **loss_kwargs)) for fn in fns}
    grad_fns = {k: jax.jit(jax.grad(v)) for k, v in loss_fns.items()}

    # Run all functions once to JIT-compile
    g_true = None
    w = jnp.float32(2)
    for k in loss_fns.keys():
        loss_fns[k](w).block_until_ready()
        g = grad_fns[k](w).block_until_ready()
        if g_true is None:
            g_true = g
            print(f"{g_true=}")
        elif not np.isclose(g, g_true):
            print(f"{k} is incorrect, {g=}")
        else:
            print(f"{k} is correct")

    # Run benchmarks
    n = 10
    for k in loss_fns.keys():
        print()
        print(k)
        print(f"  loss: {timeit.timeit(lambda: loss_fns[k](w).block_until_ready(), number=n) / n : .5f}")
        print(f"  grad: {timeit.timeit(lambda: grad_fns[k](w).block_until_ready(), number=n) / n : .5f}")

Running it produces the following output:

g_true=Array(0.00052476, dtype=float32)
bad_insert is incorrect, g=Array(0.0017259, dtype=float32)
slow_insert is correct
fast_insert is correct

simple_insert
  loss:  0.00029
  grad:  0.02732

bad_insert
  loss:  0.00029
  grad:  0.00030

slow_insert
  loss:  0.00029
  grad:  0.05066

fast_insert
  loss:  0.00030
  grad:  0.00044

karunrao97 commented 1 year ago

Looking at the slow_insert example from the previous post, I believe this is not even an issue with grad, but just with scan + slice update.

Here's a (hopefully) simpler MWE to illustrate the issue without grad:

import timeit

from functools import partial

import jax
import numpy as np

from jax import lax
from jax import numpy as jnp

N_STEPS = 1_000
N_ROWS, N_COLS = 100, 500

@partial(jax.jit, donate_argnames="carry")
def slow_pop(carry, inputs):
    row, col = inputs
    value = carry[row, col].copy()
    update = lax.dynamic_slice(carry, (row, col + 1), (1, N_COLS))
    carry = lax.dynamic_update_slice(carry, update, (row, col))
    return carry, value

@partial(jax.jit, donate_argnames="carry")
def fast_pop(carry, inputs):
    row, col = inputs
    carry = carry.at[row, -1].set(carry[row, col])
    update = lax.dynamic_slice(carry, (row, col + 1), (1, N_COLS))
    carry = lax.dynamic_update_slice(carry, update, (row, col))
    return carry, carry[row, -1]

def scan_fn(fn, init_state, rows, cols):
    state, out = lax.scan(fn, init_state.copy(), (rows, cols))
    state.block_until_ready()
    out.block_until_ready()
    return state, out

if __name__ == "__main__":
    np.random.seed(42)
    init_state = jnp.asarray(np.random.random((N_ROWS, N_COLS * 2 + 1)).astype(jnp.float32))
    rows = jnp.asarray(np.random.randint(0, N_ROWS, size=N_STEPS, dtype=np.int32))
    cols = jnp.asarray(np.random.randint(0, N_COLS, size=N_STEPS, dtype=np.int32))

    # Run both functions once to JIT-compile
    state1, out1 = scan_fn(slow_pop, init_state, rows, cols)
    state2, out2 = scan_fn(fast_pop, init_state, rows, cols)
    assert (out1 == out2).all()
    assert (state1[:, :-1] == state2[:, :-1]).all()

    # Run benchmarks
    n = 10
    for fn in slow_pop, fast_pop:
        print(f"{fn.__name__}: {timeit.timeit(partial(scan_fn, fn, init_state, rows, cols), number=n) / n : .5f}")

And the output:

slow_pop:  0.05314
fast_pop:  0.00086

slow_pop does the intuitive thing: it copies the element at (row, col) into a separate value, shifts the remaining elements of that row left, and returns both the new carry and the value. fast_pop takes a more convoluted approach: it first writes the value at (row, col) into (row, -1), where the last column is reserved for this operation, then does the left shift, and finally returns the new carry together with carry[row, -1] as the value. The key seems to be to only return elements of the final state of carry, rather than extracting values from an earlier state of carry, which seems to disrupt the in-place optimizations.

Is this expected behaviour? Also, is it perhaps better to close this issue and open a separate one? The title of this one might be misleading.

patrick-kidger commented 1 year ago

Re: error with state._array

It looks like you're not setting eqxi.scan(..., buffers=...) to return this variable. (Although I think you did do so in https://github.com/google/jax/issues/17640#issuecomment-1725025108, so I think your code must have changed since then.)

Re: your use case

I have unique index writes for each step of the scan, but across the scan there will be overlap.

That is, you're overwriting information. In this case there's nothing anyone can do: you're trying to delete information that's needed for the backward pass. In this case JAX/XLA is doing the correct thing by making a copy -- it's silently giving you correctness rather than speed.

This is likely the reason for your statement "but as soon as a row is repeated, it produces incorrect gradients".

Re: fast_pop vs slow_pop

Okay, this is very interesting! Indeed, doing the read after all of the writes seems to avoid the XLA bug. So copying the information you need into another part of the buffer is a nice trick.

Indeed I would suggest opening a new issue for this particular thing, with just the simpler MWE.

karunrao97 commented 1 year ago

Thanks again for your reply!

Re: eqxi.scan

I think I am using it correctly, but perhaps trying to access ._array fails at trace time? As noted above, though, unfortunately this solution doesn't seem applicable to my use case anyway.

Re: my use case

I agree that JAX will need to save information that's overwritten. But each step here shifts a slice right by one index, and thus overwrites a single element immediately after the slice. I expected JAX to copy the tangents at that single index alone and shift the other tangents right, but it seems to be copying the entire carry (and/or perhaps the tangents for the entire carry). In fact, I'm pretty sure it is actually shifting the tangents under the hood, because my fast_insert implementation in the second example above does not actually save or shift any tangents in the forward pass, while it does shift the cotangents left in the backward pass in order to obtain the same gradients as simple_insert, although with much better performance. I used the same trick there as in fast_pop above, and in fact fast_pop is essentially the same as the backward pass of fast_insert in the previous example. Note that in my particular use case, I can get away with not saving any tangents in the forward pass because the element after the slice is always in an unused part of the carry buffer and it will always have zero gradients. Clearly, JAX cannot know that so it needs to save them, but it seems like it should be possible to create a general implementation of fast_insert that also saves the tangents at that single index for use in the backward pass.

Re: fast_pop vs slow_pop

Created a new issue.

Appreciate the feedback, and thank you for your help!

patrick-kidger commented 1 year ago

Ach -- I suppose this is what comes of asking folks to try using private implementation details! eqxi.scan is assuming the implementation is polymorphic with respect to whether it is run with buffers or not.

The body function is actually run twice: once just to find out the shape of the extensive output (in this case None), so that we know how much memory we need to allocate to save the output. In this form it is called without buffers. Then the body function is run again "for real", with buffers.

Hmm. I suspect that solving at the Equinox level means writing functions like

def dynamic_slice(operand, start_indices, slice_sizes):
    if isinstance(operand, _Buffer):
        operand = operand._array
    return lax.dynamic_slice(operand, start_indices, slice_sizes)

And then using that instead. You could do something similar here (hasattr(operand, "_array") should suffice for the check).
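A self-contained sketch of the hasattr variant follows. FakeBuffer is a hypothetical stand-in for Equinox's private buffer type, used only to show that the helper unwraps anything exposing an _array attribute and passes plain arrays (or tracers, as in the shape-discovery run of the body function) through unchanged.

```python
import jax.numpy as jnp
from jax import lax


def dynamic_slice(operand, start_indices, slice_sizes):
    # Unwrap buffer-like objects (anything exposing `_array`); plain arrays pass through.
    if hasattr(operand, "_array"):
        operand = operand._array
    return lax.dynamic_slice(operand, start_indices, slice_sizes)


class FakeBuffer:
    """Hypothetical stand-in for Equinox's private buffer type; only `_array` matters here."""

    def __init__(self, array):
        self._array = array


x = jnp.arange(12.0).reshape(3, 4)
plain = dynamic_slice(x, (1, 1), (1, 2))
wrapped = dynamic_slice(FakeBuffer(x), (1, 1), (1, 2))
print(plain, wrapped)  # both hold the values [[5, 6]]
```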

I appreciate this is all fairly non-ideal. I'm afraid you're bumping against one of the sharpest edges in JAX/XLA. So Equinox is doing what it can here. 😅


As for your actual use case -- I find a lot of this surprising as well. For now, performing in-place updates inside loops is still a case where we need to hold the compiler's hand.