Open karunrao97 opened 1 year ago
If I'm reading this correctly (your example is quite large!), I think this is another example of #10197. This is due to a bug in the XLA compiler, which indeed triggers spurious copies when doing a grad(scan(slice + in-place-update)). Can you check if this workaround works for you?
Thank you for your reply!
I tried replacing lax.scan with equinox.internal.scan as follows:
```python
import equinox.internal as eqxi

...

state, _ = eqxi.scan(
    partial(update_fn, mode=mode),
    init_state,
    (row_indices, col_indices, insert_values**w),
    buffers=lambda x: x,
    kind="checkpointed",
    checkpoints="all",
)
```
But it seems that eqxi.scan is incompatible with lax.dynamic_slice, and I get the following stack trace:

```
    update = lax.dynamic_slice(state, start_indices=(row, col + 1), slice_sizes=(1, N_COLS - 1))
AttributeError: '_Buffer' object has no attribute 'ndim'
```
I cannot use static slices because that blows up compile time to several hours in my case. I also need to do multiple operations, including matmuls, on slices of the buffer, not only .at[...].set(...), which is only here for the MWE.
Is there a key insight I can incorporate from eqxi.scan into my code? I followed the eqxi.scan code to https://github.com/patrick-kidger/equinox/blob/76a13bd100a8127c7e5176f2dee9d36c837b46ec/equinox/internal/_loop/common.py#L220, at which point I'm a bit lost and probably need to read up more on JAX internals. I don't need a while loop, and I don't think I need to prevent cond turning into select on vmap. Is there something simpler that might work for my use case?

Thanks again!
Can you try doing lax.dynamic_slice(state._array, ...)? (Here, state is the buffer object.) That's private API right now, but we can add a public API for dynamic slicing if it'd help. Though as noted, the Equinox workaround for this XLA bug only works if you do precisely one read, whether that is via state[i] or dynamic_slice(state._array, ...) -- if you have multiple reads in your actual use case, then you'll need to rewrite things to do them all in one go.
As for what eqxi.scan is doing under the hood -- it actually combines several fixes for various JAX/XLA issues! (You've spotted the vmap(cond) one, which is another one.)
To explain the important part of what's going on here: when you have the program

```python
y = xs[i]
xs = xs.at[j].set(y)
```
then on the backward pass this becomes

```python
grad_y = grad_xs[j]
grad_xs = grad_xs.at[j].set(0)
grad_xs = grad_xs.at[i].set(grad_y)
```
and unfortunately, XLA thinks that the .at[].set(0) overwrites the memory in which grad_y is held, so it erroneously makes a copy of all of grad_xs. This bug triggers whenever you have a scan containing >=1 read and >=2 writes, all from the same buffer.
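For concreteness, here is a minimal sketch of that pattern -- an illustrative reproducer, not code from this issue. The forward body does one read and one in-place write per step, so by the transformation above its backward pass contains one read and two writes to the same buffer:

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(xs, ij):
    i, j = ij
    y = xs[i]             # one read from the carried buffer
    xs = xs.at[j].set(y)  # one write; its transpose introduces the .at[j].set(0) shown above
    return xs, y


def f(xs, idx):
    xs, ys = lax.scan(body, xs, idx)
    return ys.sum()


xs0 = jnp.arange(8.0)
idx = (jnp.array([0, 1]), jnp.array([2, 3]))
grads = jax.grad(f)(xs0, idx)  # the backward scan is where the spurious copy of xs can appear
```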
As such, the key insight in the Equinox code is to create a custom primitive (the one you've linked to) to replace the xs.at[j].set(y) call, which when backpropagated skips the .at[].set(0). In doing so we get down to just a single write on the backward pass, and we no longer hit the bug. In terms of program correctness, this will be valid provided we never again read from grad_xs[j] on the backward pass, which will itself be valid provided we never write to xs[j] twice on the forward pass. Thus the requirement listed at the end of the workaround above.
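The same idea can be sketched with jax.custom_vjp rather than a custom primitive -- a hedged illustration of the trick described above (the name write_once is made up here, and this is not the Equinox implementation). It is only valid under the stated condition that each index is written at most once on the forward pass:

```python
import jax


@jax.custom_vjp
def write_once(xs, inputs):
    j, y = inputs
    return xs.at[j].set(y)


def _write_once_fwd(xs, inputs):
    j, _ = inputs
    return write_once(xs, inputs), j


def _write_once_bwd(j, ct):
    # The standard transpose would be: ct_y = ct[j]; ct_xs = ct.at[j].set(0).
    # Skipping the zeroing write leaves a single write on the backward pass,
    # which is safe only if ct[j] is never read again downstream.
    # (0 is returned as the cotangent for the integer index j.)
    return ct, (0, ct[j])


write_once.defvjp(_write_once_fwd, _write_once_bwd)
```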
I appreciate this is a fairly complicated problem to try and work around. :/
Thank you very much for the detailed insights.

I tried using state._array, but it throws the following error:

```
AttributeError: DynamicJaxprTracer has no attribute _array
```

I inspected state, and it is indeed of type Traced<ShapedArray(float32[100,1000])>with<DynamicJaxprTrace(level=5/0)>, so it seems the buffer wrapping happens after this?
Nevertheless, reading through the equinox code that circumvents this issue was very helpful. I managed to write my own custom primitive for my insert operation using similar ideas, and I think I see the problem now. Do I understand correctly that the work-around using eqxi.scan buffers only works if each element of the buffer is written to only once throughout the entire scan? That is not the case for me -- I have unique index writes for each step of the scan, but across the scan there will be overlap.
I reduced the original MWE to something simpler below, which includes a working solution (fast_insert) even in the case where there are overlaps in write indices across steps. However, this doesn't work for my actual use case, which includes more writes per step -- still unique per step, but I cannot do them in a single slice update. I believe the bad_insert function does something similar to the buffers in eqxi.scan, and it works for a small number of N_STEPS while there are no repeated rows, but as soon as a row is repeated, it produces incorrect gradients. Both slow_insert and fast_insert work correctly, though the latter seems to be able to avoid copying the carry. I don't quite understand why slow_insert cannot update the carry in-place. This looks like a bug?

Note that while I'm using custom_vjp below, implementing a primitive (similar to how it's done in equinox) yields the same results, so I omitted that for the sake of brevity.
```python
import timeit
from functools import partial

import jax
import numpy as np
from jax import lax
from jax import numpy as jnp

N_STEPS = 1_000
N_ROWS, N_COLS = 100, 500


def simple_insert(carry, inputs):
    row, col, value = inputs
    update = jnp.concatenate([jnp.array([[value]]), lax.dynamic_slice(carry, start_indices=(row, col), slice_sizes=(1, N_COLS))], axis=-1)
    return lax.dynamic_update_slice(carry, update=update, start_indices=(row, col))


@jax.custom_vjp
def bad_insert(carry, inputs):
    return simple_insert(carry, inputs)


def _bad_insert_fwd(carry, inputs):
    row, col, _ = inputs
    return bad_insert(carry, inputs), (row, col)


def _bad_insert_bwd(residuals, ct_out):
    row, col = residuals
    return ct_out, (0, 0, ct_out[row, col])  # Using 0 because None throws "ValueError: safe_zip() argument 2 is shorter than argument 1" in _flatten_bwd


@jax.custom_vjp
def slow_insert(carry, inputs):
    return simple_insert(carry, inputs)


def _slow_insert_fwd(carry, inputs):
    row, col, _ = inputs
    return slow_insert(carry, inputs), (row, col)


def _slow_insert_bwd(residuals, ct_out):
    row, col = residuals
    ct_value = ct_out[row, col].copy()
    update = lax.dynamic_slice(ct_out, start_indices=(row, col + 1), slice_sizes=(1, N_COLS))
    ct_out = lax.dynamic_update_slice(ct_out, update=update, start_indices=(row, col))
    ct_out = ct_out.at[row, col + N_COLS].set(0.0)
    return ct_out, (0, 0, ct_value)


@jax.custom_vjp
def fast_insert(carry, inputs):
    return simple_insert(carry, inputs)


def _fast_insert_fwd(carry, inputs):
    row, col, _ = inputs
    return fast_insert(carry, inputs), (row, col)


def _fast_insert_bwd(residuals, ct_out):
    row, col = residuals
    ct_out = ct_out.at[row, -1].set(ct_out[row, col])
    update = lax.dynamic_slice(ct_out, start_indices=(row, col + 1), slice_sizes=(1, N_COLS))
    ct_out = lax.dynamic_update_slice(ct_out, update=update, start_indices=(row, col))
    ct_out = ct_out.at[row, col + N_COLS].set(0.0)
    return ct_out, (0, 0, ct_out[row, -1])


bad_insert.defvjp(_bad_insert_fwd, _bad_insert_bwd)
slow_insert.defvjp(_slow_insert_fwd, _slow_insert_bwd)
fast_insert.defvjp(_fast_insert_fwd, _fast_insert_bwd)


@partial(jax.jit, static_argnames="fn")
def wrap(fn, *args):
    return fn(*args), None


@partial(jax.jit, static_argnames="fn")
def loss(w, fn, init_state, target_state, rows, cols, values):
    state, _ = lax.scan(partial(wrap, fn), init_state, (rows, cols, values**w))
    return ((target_state - state[:, :N_COLS]) ** 2).mean()


if __name__ == "__main__":
    np.random.seed(42)
    init_state = jnp.zeros(shape=(N_ROWS, N_COLS * 2 + 1), dtype=jnp.float32)
    loss_kwargs = dict(
        target_state=jnp.asarray(np.random.random(size=(N_ROWS, N_COLS)).astype(np.float32)),
        rows=jnp.asarray(np.random.randint(low=0, high=N_ROWS, size=N_STEPS, dtype=np.int32)),
        cols=jnp.asarray(np.random.randint(low=0, high=N_COLS, size=N_STEPS, dtype=np.int32)),
        values=jnp.asarray(np.random.random(size=N_STEPS).astype(np.float32)),
    )
    fns = simple_insert, bad_insert, slow_insert, fast_insert
    loss_fns = {fn.__name__: jax.jit(partial(loss, fn=fn, init_state=init_state.copy(), **loss_kwargs)) for fn in fns}
    grad_fns = {k: jax.jit(jax.grad(v)) for k, v in loss_fns.items()}

    # Run all functions once to JIT-compile
    g_true = None
    w = jnp.float32(2)
    for k in loss_fns.keys():
        loss_fns[k](w).block_until_ready()
        g = grad_fns[k](w).block_until_ready()
        if g_true is None:
            g_true = g
            print(f"{g_true=}")
        elif not np.isclose(g, g_true):
            print(f"{k} is incorrect, {g=}")
        else:
            print(f"{k} is correct")

    # Run benchmarks
    n = 10
    for k in loss_fns.keys():
        print()
        print(k)
        print(f" loss: {timeit.timeit(lambda: loss_fns[k](w).block_until_ready(), number=n) / n : .5f}")
        print(f" grad: {timeit.timeit(lambda: grad_fns[k](w).block_until_ready(), number=n) / n : .5f}")
```
Running it produces the following output:
```
g_true=Array(0.00052476, dtype=float32)
bad_insert is incorrect, g=Array(0.0017259, dtype=float32)
slow_insert is correct
fast_insert is correct

simple_insert
  loss:  0.00029
  grad:  0.02732

bad_insert
  loss:  0.00029
  grad:  0.00030

slow_insert
  loss:  0.00029
  grad:  0.05066

fast_insert
  loss:  0.00030
  grad:  0.00044
```
Looking at the slow_insert example from the previous post, I believe this is not even an issue with grad, but just with scan + slice update.

Here's a (hopefully) simpler MWE to illustrate the issue without grad:
```python
import timeit
from functools import partial

import jax
import numpy as np
from jax import lax
from jax import numpy as jnp

N_STEPS = 1_000
N_ROWS, N_COLS = 100, 500


@partial(jax.jit, donate_argnames="carry")
def slow_pop(carry, inputs):
    row, col = inputs
    value = carry[row, col].copy()
    update = lax.dynamic_slice(carry, (row, col + 1), (1, N_COLS))
    carry = lax.dynamic_update_slice(carry, update, (row, col))
    return carry, value


@partial(jax.jit, donate_argnames="carry")
def fast_pop(carry, inputs):
    row, col = inputs
    carry = carry.at[row, -1].set(carry[row, col])
    update = lax.dynamic_slice(carry, (row, col + 1), (1, N_COLS))
    carry = lax.dynamic_update_slice(carry, update, (row, col))
    return carry, carry[row, -1]


def scan_fn(fn, init_state, rows, cols):
    state, out = lax.scan(fn, init_state.copy(), (rows, cols))
    state.block_until_ready()
    out.block_until_ready()
    return state, out


if __name__ == "__main__":
    np.random.seed(42)
    init_state = jnp.asarray(np.random.random((N_ROWS, N_COLS * 2 + 1)).astype(jnp.float32))
    rows = jnp.asarray(np.random.randint(0, N_ROWS, size=N_STEPS, dtype=np.int32))
    cols = jnp.asarray(np.random.randint(0, N_COLS, size=N_STEPS, dtype=np.int32))

    # Run both functions once to JIT-compile
    state1, out1 = scan_fn(slow_pop, init_state, rows, cols)
    state2, out2 = scan_fn(fast_pop, init_state, rows, cols)
    assert (out1 == out2).all()
    assert (state1[:, :-1] == state2[:, :-1]).all()

    # Run benchmarks
    n = 10
    for fn in slow_pop, fast_pop:
        print(f"{fn.__name__}: {timeit.timeit(partial(scan_fn, fn, init_state, rows, cols), number=n) / n : .5f}")
```
And the output:

```
slow_pop:  0.05314
fast_pop:  0.00086
```
slow_pop does the intuitive thing -- copy the element at (row, col) into a separate value, then shift the remaining elements in the row (from col + 1 onward) left by one, and return both the new carry and the value. fast_pop takes a more convoluted approach -- it first writes the value at (row, col) into (row, -1), where the last index is reserved for this operation, then does the shift left, and finally returns the new carry and the element at carry[row, -1] as the value. It seems like the key here is to only return elements of the last state of carry, instead of extracting some values from an earlier state of carry, which seems to disrupt the in-place optimizations.
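A distilled version of that trick, as I understand it -- my own sketch, not from the MWE: reserve a scratch slot in the carry, stage the value you want to emit into it before any other in-place updates, and emit it by reading only from the final carry:

```python
import jax.numpy as jnp
from jax import lax


def body(carry, i):
    carry = carry.at[-1].set(carry[i])  # stage the read into a reserved scratch slot
    carry = carry.at[i].set(0.0)        # then do whatever in-place updates are needed
    return carry, carry[-1]             # read only from the final state of the carry


state, out = lax.scan(body, jnp.arange(8.0), jnp.array([1, 3, 5]))
```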
Is this expected behaviour? Also, is it perhaps better to close this issue and open a separate one? The title of this one might be misleading.
Re: error with state._array
It looks like you're not setting eqxi.scan(..., buffers=...) to return this variable. (Although I think you did do so in https://github.com/google/jax/issues/17640#issuecomment-1725025108, so I think your code must have changed since then.)
Re: your use case
> I have unique index writes for each step of the scan, but across the scan there will be overlap.
That is, you're overwriting information. In that case there's nothing anyone can do: you're trying to delete information that's needed for the backward pass, so JAX/XLA is doing the correct thing by making a copy -- it's silently giving you correctness rather than speed.
This is likely the reason for your statement "but as soon as a row is repeated, it produces incorrect gradients".
Re: fast_pop vs slow_pop
Okay, this is very interesting! Indeed, doing the read after all of the writes seems to avoid the XLA bug. So copying the information you need into another part of the buffer is a nice trick.
Indeed I would suggest opening a new issue for this particular thing, with just the simpler MWE.
Thanks again for your reply!
Re: eqxi.scan
I think I am using it correctly, but perhaps trying to access ._array fails at trace time? As noted above, though, unfortunately this solution doesn't seem applicable to my use case anyway.
Re: my use case
I agree that JAX will need to save information that's overwritten. But each step here shifts a slice right by one index, and thus overwrites a single element immediately after the slice. I expected JAX to copy the tangents at that single index alone and shift the other tangents right, but it seems to be copying the entire carry (and/or perhaps the tangents for the entire carry). In fact, I'm pretty sure it is actually shifting the tangents under the hood, because my fast_insert implementation in the second example above does not actually save or shift any tangents in the forward pass, while it does shift the cotangents left in the backward pass in order to obtain the same gradients as simple_insert, although with much better performance. I used the same trick there as in fast_pop above, and in fact fast_pop is essentially the same as the backward pass of fast_insert in the previous example. Note that in my particular use case, I can get away with not saving any tangents in the forward pass because the element after the slice is always in an unused part of the carry buffer and it will always have zero gradients. Clearly, JAX cannot know that, so it needs to save them, but it seems like it should be possible to create a general implementation of fast_insert that also saves the tangents at that single index for use in the backward pass.
Re: fast_pop vs slow_pop

Created a new issue.
Appreciate the feedback, and thank you for your help!
Ach -- I suppose this is what comes of asking folks to try out private implementation details! eqxi.scan is assuming that your implementation is polymorphic with respect to whether it is run with buffers or not.

The body function is actually run twice. Once just to find out the shape of the extensive output (in this case None), so that we know how much memory we need to allocate to save the output; in this form it is called without buffers. Then the body function is run again "for real", with buffers.
Hmm. I suspect that solving this at the Equinox level means writing functions like

```python
def dynamic_slice(operand, start_indices, slice_sizes):
    if isinstance(operand, _Buffer):
        operand = operand._array
    return lax.dynamic_slice(operand, start_indices, slice_sizes)
```

and then using that instead. You could do something similar here (hasattr(operand, "_array") should suffice for the check).
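As a self-contained sketch of that hasattr variant (my illustration; it avoids importing Equinox's private _Buffer class): with a plain array it just defers to lax.dynamic_slice, and with an eqxi.scan buffer it would read through the private _array attribute.

```python
import jax.numpy as jnp
from jax import lax


def dynamic_slice(operand, start_indices, slice_sizes):
    # Buffer-or-array polymorphic read: unwrap the private _array if present.
    if hasattr(operand, "_array"):
        operand = operand._array
    return lax.dynamic_slice(operand, start_indices, slice_sizes)


x = jnp.arange(12.0).reshape(3, 4)
window = dynamic_slice(x, start_indices=(1, 1), slice_sizes=(1, 2))  # same as lax.dynamic_slice here
```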
I appreciate this is all fairly non-ideal. I'm afraid you're bumping against one of the sharpest edges in JAX/XLA. So Equinox is doing what it can here. 😅
As for your actual use case -- I find a lot of this surprising as well. For now, performing in-place updates inside loops is still a case where we need to hold the compiler's hand.
Description
Hi,
I am running a scan over several thousand steps with a 2D array as the carry, where each step consists of (among other operations):
I know the maximum number of columns (let's call it N_COLS) ahead of time, so we can do the above operation by allocating a 2D array with width twice N_COLS, and always shifting a slice of size N_COLS left and right for the delete and insert, respectively. This works well for the forward pass, but I observe very slow gradient computations (a factor of 100x slower). I think the gradient computation is triggering a copy of the entire carry instead of updating it in-place, and I came up with a fairly strange work-around (called "combo2" in the below MWE) that seems to update in-place and achieve reasonable performance, but only when using JVP. I would like to know if there is a better way of implementing this, and also whether there is a way to achieve the same using VJP.
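(The MWE and its outputs are omitted from this excerpt. For orientation, here is a rough sketch of my own, not the issue's code, of the slice-shift insert described above, using a carry that is twice N_COLS wide:)

```python
import jax.numpy as jnp
from jax import lax

N_COLS = 4  # illustrative size


def insert(carry, row, col, value):
    # Take the live window of N_COLS elements starting at (row, col) ...
    window = lax.dynamic_slice(carry, (row, col), (1, N_COLS))
    # ... prepend the new value, shifting the window right by one slot ...
    shifted = jnp.concatenate([jnp.array([[value]]), window], axis=-1)
    # ... and write it back in place over positions col .. col + N_COLS.
    return lax.dynamic_update_slice(carry, shifted, (row, col))


carry = jnp.zeros((2, 2 * N_COLS))
carry = insert(carry, 0, 0, 3.0)
```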
Here is a minimal working example:
Here is the output with N_ROW_REPEATS = 1:

And here is the output with N_ROW_REPEATS = 10:

I would have expected one of the two "slice" modes to be most efficient, and this does seem to be the case for the forward pass. However, given the 100x slower gradient computation (using both VJP and JVP), I attempted some other methods. The "scatter" mode uses advanced indexing with the .at[...].set(...) notation, which is slower on the forward pass (as expected), but the gradient computation is much faster (when using JVP) than the "slice" modes. The faster gradient computation relies on unique_indices=True, and without that it is again very slow -- it seems we might be benefitting from this code path: https://github.com/google/jax/blob/bcc545a69232e983ae31b0395f4972979f2789c0/jax/_src/lax/slicing.py#L2297

The "combo" modes utilize the above insight to combine a small scatter with a dynamic slice update, but this seems to only work well in the case ("combo2") where the scatter update contains more than one element -- I guess "combo1" reduces to a slice update instead of a scatter.

As seen in the case where N_ROW_REPEATS = 10, as we enlarge the carry, all modes scale terribly with VJP, and all modes except "scatter" and "combo2" also scale terribly with JVP, even though we're only slicing into a fraction of the rows in the carry (and the slices are identical to those with N_ROW_REPEATS = 1) -- this suggests that these modes are triggering a copy of the entire carry, rather than updating it in-place as intended.

Some questions:
I'm new to JAX, so apologies if I'm missing something obvious. Thank you for your help!
What jax/jaxlib version are you using?
jax v0.4.14, jaxlib v0.4.14
Which accelerator(s) are you using?
CPU
Additional system info
python v3.9.9, linux
NVIDIA GPU info
No response