albertz opened this issue 1 year ago
This touches on lots of interesting points!
First of all, note that JAX doesn't support dynamic shapes. It's not actually possible to implement "naive variant 1" under JAX's model of computation.
Next, as Matt has touched on, scan basically does "naive variant 2", and also does the correct thing under autodiff. The reason for this is that the autodiff happens before the lowering to "naive variant 2".
Let's unpack that. Here's a simple program:
```python
import jax
import jax.lax as lax
import jax.numpy as jnp

def f(x):
    def body(carry, _):
        return carry + 1, carry + 1
    _, out = lax.scan(body, x, xs=None, length=5)
    return jnp.sum(out)

forward_jaxpr = jax.make_jaxpr(f)(1.)
print("forward computation", forward_jaxpr)
```
Now, here's the result:
```
forward computation { lambda ; a:f32[]. let
    _:f32[] b:f32[5] = scan[
      jaxpr={ lambda ; c:f32[]. let
          d:f32[] = add c 1.0
          e:f32[] = add c 1.0
        in (d, e) }
      length=5
      linear=(False,)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] a
    f:f32[5] = convert_element_type[new_dtype=float32 weak_type=False] b
    g:f32[] = reduce_sum[axes=(0,)] f
  in (g,) }
```
This shows how JAX interprets our Python program. We see that JAX has recorded the scan as a single primitive bind, and hasn't yet lowered it to a while-with-DUS (a while loop that writes each output with dynamic_update_slice).
Now let's look at the backward pass:
```python
backward_jaxpr = jax.make_jaxpr(jax.grad(f))(1.)
print("forward+backward computation", backward_jaxpr)
```
This gives:
```
forward+backward computation { lambda ; a:f32[]. let
    _:f32[] b:f32[5] = scan[
      jaxpr={ lambda ; c:f32[]. let
          d:f32[] = add c 1.0
          e:f32[] = add c 1.0
        in (d, e) }
      length=5
      linear=(False,)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] a
    f:f32[5] = convert_element_type[new_dtype=float32 weak_type=False] b
    _:f32[] = reduce_sum[axes=(0,)] f
    g:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
    h:f32[5] = convert_element_type[new_dtype=float32 weak_type=True] g
    i:f32[] = scan[
      jaxpr={ lambda ; j:f32[] k:f32[]. let l:f32[] = add_any j k in (l,) }
      length=5
      linear=(True, True)
      num_carry=1
      num_consts=0
      reverse=True
      unroll=1
    ] 0.0 h
  in (i,) }
```
We've differentiated our scan. This produces two scans, one for the forward iteration and one for the backward iteration. The important bit is that we still haven't specified how we're implementing our scan, i.e. there still isn't a while-with-DUS.
It's only once you've done all the autodiff that this actually gets lowered to a while-with-DUS. We can inspect that (using some private internals) like so:
```python
import equinox.internal as eqxi

# Register JAX's private while-loop-based implementation as the lowering rule for scan_p.
eqxi.primitive_finalisations[lax.scan_p] = jax._src.lax.control_flow.loops._scan_impl

forward_lowered_jaxpr = eqxi.finalise_jaxpr(forward_jaxpr)
print("forward lowered computation", forward_lowered_jaxpr)
print("")
backward_lowered_jaxpr = eqxi.finalise_jaxpr(backward_jaxpr)
print("forward+backward lowered computation", backward_lowered_jaxpr)
```
This produces:
```
forward lowered computation { lambda ; a:f32[]. let
    b:f32[] = empty[dtype=float32]
    c:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] b
    _:i32[] _:f32[] d:f32[5] = while[
      body_jaxpr={ lambda ; e:i32[] f:f32[] g:f32[5]. let
          h:f32[] = add f 1.0
          i:f32[] = add f 1.0
          j:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] i
          k:bool[] = lt e 0
          l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
          m:i32[] = add l 5
          n:i32[] = select_n k e m
          o:f32[5] = dynamic_update_slice g j n
          p:i32[] = add e 1
        in (p, h, o) }
      body_nconsts=0
      cond_jaxpr={ lambda ; q:i32[] r:f32[] s:f32[5]. let
          t:bool[] = lt q 5
        in (t,) }
      cond_nconsts=0
    ] 0 a c
    u:f32[] = reduce_sum[axes=(0,)] d
  in (u,) }

forward+backward lowered computation { lambda ; a:f32[]. let
    b:f32[] = empty[dtype=float32]
    c:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] b
    _:i32[] _:f32[] d:f32[5] = while[
      body_jaxpr={ lambda ; e:i32[] f:f32[] g:f32[5]. let
          h:f32[] = add f 1.0
          i:f32[] = add f 1.0
          j:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] i
          k:bool[] = lt e 0
          l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
          m:i32[] = add l 5
          n:i32[] = select_n k e m
          o:f32[5] = dynamic_update_slice g j n
          p:i32[] = add e 1
        in (p, h, o) }
      body_nconsts=0
      cond_jaxpr={ lambda ; q:i32[] r:f32[] s:f32[5]. let
          t:bool[] = lt q 5
        in (t,) }
      cond_nconsts=0
    ] 0 a c
    _:f32[] = reduce_sum[axes=(0,)] d
    u:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
    v:f32[5] = convert_element_type[new_dtype=float32 weak_type=True] u
    _:i32[] w:f32[] = while[
      body_jaxpr={ lambda ; x:f32[5] y:i32[] z:f32[]. let
          ba:i32[] = sub 5 y
          bb:i32[] = sub ba 1
          bc:bool[] = lt bb 0
          bd:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bb
          be:i32[] = add bd 5
          bf:i32[] = select_n bc bb be
          bg:f32[1] = dynamic_slice[slice_sizes=(1,)] x bf
          bh:f32[] = squeeze[dimensions=(0,)] bg
          bi:f32[] = add_any z bh
          bj:i32[] = add y 1
        in (bj, bi) }
      body_nconsts=1
      cond_jaxpr={ lambda ; bk:i32[] bl:f32[]. let
          bm:bool[] = lt bk 5
        in (bm,) }
      cond_nconsts=0
    ] v 0 0.0
  in (w,) }
```
In these jaxprs, we can now see the while, dynamic_update_slice, and dynamic_slice operations appearing: this is the actual implementation that gets compiled.
So, JAX avoids inefficient memory consumption by doing autodiff at a higher-level representation (the scan primitive), rather than doing it on the lower-level while-slice representation.
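To make the while-with-DUS pattern concrete, here is a minimal forward-only sketch of it written by hand: a `lax.while_loop` that preallocates the stacked output and fills slot `i` with `lax.dynamic_update_index_in_dim`, checked against `lax.scan` with the same body as `f`. The helper `scan_via_while` is hypothetical (it is not JAX's internal `_scan_impl`), and it deliberately ignores autodiff.

```python
import jax.numpy as jnp
from jax import lax


def body(carry, _):
    # Same body as `f` above: new carry and per-step output are both carry + 1.
    return carry + 1.0, carry + 1.0


def scan_via_while(x, length=5):
    """Hypothetical helper: forward-only 'naive variant 2' (while + dynamic_update_slice)."""
    out0 = jnp.zeros((length,), dtype=jnp.float32)  # preallocated extensive output

    def cond(state):
        i, _, _ = state
        return i < length

    def step(state):
        i, carry, out = state
        carry, y = body(carry, None)
        out = lax.dynamic_update_index_in_dim(out, y, i, axis=0)  # write slot i
        return i + 1, carry, out

    _, carry, out = lax.while_loop(cond, step, (0, x, out0))
    return carry, out


x = jnp.float32(1.0)
carry_ref, out_ref = lax.scan(body, x, xs=None, length=5)
carry, out = scan_via_while(x)
print(out)
```

Calling `jax.grad` on anything built from `scan_via_while` would fail, because it uses `lax.while_loop`. `lax.scan` avoids this by being differentiated before it is lowered to this form.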
As you've noticed, you can't do this with a lax.while_loop and doing the update-slicing yourself. JAX doesn't actually support reverse-mode autodifferentiation of lax.while_loop. The reason for this is that a while_loop may run for some unknown number of iterations, so we'd have to store an arbitrary amount of forward information for our backward pass, and as noted above, JAX doesn't support dynamically-sized arrays! (More precisely, I believe XLA doesn't support dynamic memory allocation.)
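A small sketch of this limitation, assuming a toy recurrence `c <- sin(c) + x` (the names `loop_while`, `loop_fori` and `N_STEPS` are hypothetical): forward mode through `lax.while_loop` works, reverse mode raises an error, and the same loop written with `lax.fori_loop` and Python-int bounds is differentiable because, per the JAX documentation, a static trip count makes `fori_loop` lower to `scan`.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 5


def loop_while(x):
    # Hypothetical toy recurrence c <- sin(c) + x, written with lax.while_loop.
    def cond(state):
        i, _ = state
        return i < N_STEPS

    def body(state):
        i, c = state
        return i + 1, jnp.sin(c) + x

    _, c = lax.while_loop(cond, body, (0, x))
    return c


def loop_fori(x):
    # Same recurrence via fori_loop with Python-int bounds (a static trip count).
    return lax.fori_loop(0, N_STEPS, lambda i, c: jnp.sin(c) + x, x)


x = jnp.float32(0.3)

# Forward mode through while_loop is supported.
_, tangent = jax.jvp(loop_while, (x,), (jnp.float32(1.0),))

# Reverse mode through while_loop raises an error.
try:
    jax.grad(loop_while)(x)
    reverse_ok = True
except Exception as err:
    reverse_ok = False
    print(type(err).__name__, str(err)[:80])

# With a static trip count, fori_loop is lowered via scan, so reverse mode works.
grad_fori = jax.grad(loop_fori)(x)
print(tangent, grad_fori)
```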
You could do this by composing lax.scan and update-slices (at least for a fixed number of iterations). Unfortunately, I think JAX will actually exhibit the O(T^2) scaling that you expect from the naive analysis. The extensive output of a scan is handled as we saw in part 1, but the carry gets no special treatment. (In general you really do want to record a copy, after all.) XLA isn't usually very smart about avoiding copies from update-slices and scatters, and I really recommend avoiding explicit usage of these unless you have a good mental model for what it knows how to optimise and what it doesn't.
(The most common example of this footgun is when vmap'ing a while, with a body function that uses a scatter. Even without autodiff this will actually incur a needless copy, blech.)
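We've seen that JAX is smart enough to handle extensive outputs of scans, but that:
1. it doesn't support reverse-mode autodiff of lax.while_loop;
2. doing the update-slicing into the carry yourself can incur needless copies, since XLA isn't always smart about optimising them away.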
The good news is that JAX actually has enough tools that we can fix both of these issues ourselves, as a user! (It's unfortunate that these footguns exist in the first place of course -- it'd be nice if JAX fixed these issues itself at some point.)
Take a look at equinox.internal.while_loop, from the Equinox library. This is basically a better version of lax.while_loop!
This supports reverse-mode autodiff by using a fixed number of checkpoints, and then rematerialising intermediate computations as required. (If you know of the "treeverse" algorithm, then this is an extension to the unbounded-number-of-steps case.) This fixes issue 1.
It also supports a buffers= argument. This is used to explicitly list all of the elements of the carry which behave in this update-slice manner. These are then handled correctly with respect to (a) autodiff (in the same way as lax.scan, but for the carry rather than just the extensive output) and (b) in-place updates (which fixes issue 2).
There's also a convenience equinox.internal.scan which wraps this so as to provide a checkpointed scan implementation, compatible with the jax.lax.scan API. (You'll occasionally see folks do a lax.scan(jax.checkpoint(f), ...) to emulate this. equinox.internal.scan is similar, but much smarter: it uses treeverse-style checkpointing rather than checkpointing every step.)
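A sketch of the `lax.scan(jax.checkpoint(f), ...)` emulation mentioned above, using a hypothetical toy step function. This checkpoints every step (it is not the treeverse-style scheme that equinox.internal.scan uses), so the gradient is unchanged and only the memory/recompute trade-off differs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(c, x):
    # Hypothetical toy step: new carry and per-step output are the same value.
    c = jnp.tanh(c * x + 0.5)
    return c, c


def loss(c0, xs):
    _, ys = lax.scan(step, c0, xs)
    return jnp.sum(ys ** 2)


def loss_remat(c0, xs):
    # jax.checkpoint on the body: the forward pass keeps only each step's inputs,
    # and the backward pass recomputes the body's internals.
    _, ys = lax.scan(jax.checkpoint(step), c0, xs)
    return jnp.sum(ys ** 2)


c0 = jnp.float32(0.2)
xs = jnp.linspace(0.1, 1.0, 8)
g_plain = jax.grad(loss)(c0, xs)
g_remat = jax.grad(loss_remat)(c0, xs)
print(g_plain, g_remat)  # same gradient, different memory/compute trade-off
```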
Phew, that was rather a lot! Does that all make sense / do you have any follow-up questions?
Thanks a lot for the very detailed write-up!
JAX doesn't support dynamically-sized arrays! (More precisely, I believe XLA doesn't support dynamic memory allocation.)
I assume this is for efficiency reasons, and maybe also for simplicity on the XLA side? Are there any plans to loosen this restriction on the XLA side, and then also on the JAX side?
So, if I need this, what are my options?
One obvious use case: implementing beam search for some encoder-decoder-like model, or just a language model. You don't know in advance how long the output is going to be. For this, you need dynamically-sized arrays. Or what are the possible workarounds? Setting some arbitrary upper limit for the target length? But if you set e.g. length=10_000, would it mean that it always iterates 10k steps, even when the real length might only be 10 labels?
I wonder a bit about this. Isn't this quite a common case?
Apart from the dynamic size issue: wouldn't something like TensorArray also be useful for JAX? That way, it would be simple for the user to reimplement something like scan in an efficient way, with O(T) memory and runtime.
I understand there's ongoing work (in both JAX and XLA) to implement "bounded dynamism"/"dynamic shapes", for which the size of the array is dynamic but its size has some known upper limit. AFAICT that's not landing any time soon though.
For working around this: indeed, typically we just set some upper limit. (This is usually not too big a deal; our available memory is already setting an upper limit anyway.) If you use something like scan(..., length=10_000), then indeed this will always iterate for 10k steps. But if you use eqxi.while_loop(..., max_steps=10_000), then it will exit early. Either way you'll use O(10k) memory, but in the latter case you'll only take as many steps as you actually need.
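As a forward-only (non-differentiated) sketch of the upper-limit workaround, here is a `lax.while_loop` that decodes into a buffer of fixed size `MAX_STEPS` but stops as soon as an end-of-sequence token appears. `next_token`, `EOS` and `MAX_STEPS` are hypothetical stand-ins for a real decoder step. If you also need reverse-mode autodiff, `lax.while_loop` won't do; that is where something like eqxi.while_loop comes in.

```python
import jax
import jax.numpy as jnp
from jax import lax

EOS = 0          # hypothetical end-of-sequence id
MAX_STEPS = 16   # static upper bound: fixes the buffer size (and memory)


def next_token(tok):
    # Hypothetical stand-in for one decoder step.
    return (3 * tok + 1) % 7


@jax.jit
def decode(start_tok):
    buf = jnp.full((MAX_STEPS,), -1, dtype=jnp.int32)  # preallocated output

    def cond(state):
        i, tok, _ = state
        # Early exit: stop on EOS or when the buffer is full.
        return (i < MAX_STEPS) & (tok != EOS)

    def body(state):
        i, tok, buf = state
        tok = next_token(tok)
        buf = lax.dynamic_update_index_in_dim(buf, tok, i, axis=0)
        return i + 1, tok, buf

    n, _, buf = lax.while_loop(cond, body, (0, start_tok, buf))
    return n, buf


n, buf = decode(jnp.int32(1))
print(int(n), buf[: int(n)])  # number of steps actually taken, and the tokens
```

The buffer always costs `MAX_STEPS` slots, but only `n` iterations are executed.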
Another possibility is to do the dynamic part of the iteration in Python, rather than within the confines of JIT. This often makes sense if you're doing it at the top level (above the level at which you're doing autodiff etc.)
I don't think this is that common as a use case. In many respects I think the static-shape restriction is quite useful, as it ensures that we write our programs in ways we can efficiently transform/lower to all backends, etc. For example, I know the Julia folks have had some recent issues in which CUDA.jl only supports static memory allocation, but many of their programs were written assuming dynamic memory allocation is possible. This made it hard to compose these two parts of their ecosystem.
I think TensorArray is basically equivalent to the ongoing dynamic shapes work.
In the past, I implemented beam search in TensorFlow in a differentiable way, and we used it for max expected BLEU or min expected WER training. So I don't really see any reasonable way to do this in JAX now. Currently, the only (ugly) way I see is:
Do the search with an outer loop in pure Python, to get the search space lattice, and to know the real max seq length, or maybe even directly the best N sequences (or a graph/lattice representing best sequences). Then a second pass using the given sequences, so now the seq length is known.
This already sounds complicated, but when thinking more about it, you realize that it is even more problematic: you don't want to compute the encoder twice, but you still want backprop to go through the encoder for the second pass. How would you even do that?
All of this is not that uncommon for tasks operating on sequences, like speech recognition or machine translation.
In any case, any preallocation based on some given upper limit for the seq length would be problematic. In practice, you would want it to be able to use as much memory as is available, and only get an OOM if it needs more. Also, for the case of short sequences, it should still be fast: even if the number of computation steps is small, allocating such a big tensor in each iteration would make it slow.
In our case, also our batch size is dynamic, because we have sequences of very different lengths. So we have some very small batch sizes with very long sequences, but also very big batch sizes with very short sequences.
Btw, I tried an implementation using scan, where I do the gather and the dynamic index update only via the carry, then do backprop and measure the memory consumption. The x-axis is the seq length, and the y-axis is memory consumption (RSS).
It actually looks linear?
This is my code. Maybe it is wrong?
```python
import os

# Must be set before JAX initialises its backend (only relevant on GPU).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import psutil

jax.config.update("jax_platform_name", "cpu")

batch_dim = 10
feature_dim = 5


def test_scan(time_dim: int):
    def func(xs):
        def body(state, _):
            i, ys = state
            x = xs[i]  # dynamic gather from the closed-over input
            x_ = x * (i ** 0.5)
            # Write the result into the carry at position i (dynamic update).
            ys = jax.lax.dynamic_update_index_in_dim(ys, x_, i, axis=0)
            return (i + 1, ys), None

        i = 0
        ys = jnp.zeros((time_dim, batch_dim, feature_dim))
        (_, ys), _ = jax.lax.scan(body, init=(i, ys), xs=None, length=time_dim)
        return jnp.sum(ys)

    rnd_key = jax.random.PRNGKey(42)
    xs = jax.random.uniform(rnd_key, (time_dim, batch_dim, feature_dim))
    return jax.grad(func)(xs)


ns = list(range(5, 10_000, 100))
mems = []
for n in ns:
    grad_xs = test_scan(n)
    grad_xs.block_until_ready()
    # RSS of the whole process, measured after the backward pass has finished.
    mem = psutil.Process().memory_info().rss
    print(f"n={n}, mem={mem}")
    mems.append(mem)

fig, ax = plt.subplots()
ax.plot(ns, mems)
plt.show()
```
Edit: Ah yes, I think this is wrong. I'm not really measuring the memory consumption during backprop. I would need to measure at the end of the forward pass, just before backprop starts. I'm not sure how I can hook into that point.
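One way to see what the forward pass keeps for backprop, instead of measuring process RSS, is `jax.ad_checkpoint.print_saved_residuals` (also used later in this thread). This sketch runs it on the same computation as `func` above, rebuilt by a hypothetical `make_func` helper, and captures the printed report. With `time_dim = 7`, nothing should be shaped like `[7, 7, ...]`.

```python
import contextlib
import io

import jax
import jax.ad_checkpoint
import jax.numpy as jnp

batch_dim, feature_dim = 10, 5


def make_func(time_dim):
    # Same computation as `func` inside `test_scan` above.
    def func(xs):
        def body(state, _):
            i, ys = state
            x_ = xs[i] * (i ** 0.5)
            ys = jax.lax.dynamic_update_index_in_dim(ys, x_, i, axis=0)
            return (i + 1, ys), None

        ys = jnp.zeros((time_dim, batch_dim, feature_dim))
        (_, ys), _ = jax.lax.scan(body, (0, ys), xs=None, length=time_dim)
        return jnp.sum(ys)

    return func


time_dim = 7
xs = jnp.ones((time_dim, batch_dim, feature_dim))
buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    jax.ad_checkpoint.print_saved_residuals(make_func(time_dim), xs)
report = buf.getvalue()
print(report)  # one line per saved residual: shape/dtype and where it came from
```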
If you need to differentiate through a variably-sized computation (which is what I get from what you're saying), then eqxi.while_loop is probably the appropriate tool.
Here's an example of the O(N^2) behaviour: #8192
You could do this by composing lax.scan and update-slices (at least for a fixed number of iterations). Unfortunately, I think JAX will actually exhibit the O(T^2) scaling that you expect from the naive analysis. The extensive output of a scan is handled as we saw in part 1, but the carry gets no special treatment. (In general you really do want to record a copy, after all.)
Actually, no; you still get O(T) memory scaling even if you don't use the extensive inputs and outputs. Here's an example:
```python
import jax
import jax.numpy as jnp
import jax.ad_checkpoint

xs = 1 + jnp.arange(3.)

def f(z):
    # Loop that only slices into / updates the carry (no scanned-over inputs or outputs).
    xs_ = z * xs
    ys = jnp.zeros_like(xs)
    def body(carry, _):
        i, c, xs, ys = carry
        new_c = jnp.sin(jnp.sin(xs[i]) * c)
        new_ys = jax.lax.dynamic_update_index_in_dim(ys, new_c, i, 0)
        new_carry = (i + 1, new_c, xs, new_ys)
        return new_carry, None
    (_, y, _, ys), _ = jax.lax.scan(body, (0, z, xs_, ys), None, length=3)
    return jnp.sum(y * ys)

def f2(z):
    # Same computation using scan's extensive inputs/outputs.
    xs_ = z * xs
    def body(c, x):
        new_c = jnp.sin(jnp.sin(x) * c)
        return new_c, new_c
    y, ys = jax.lax.scan(body, z, xs_)
    return jnp.sum(y * ys)

print('=== first-order AD ===')
print('carry-only loop')
jax.ad_checkpoint.print_saved_residuals(f, 3.)
print()
print('extensive inputs/outputs loop')
jax.ad_checkpoint.print_saved_residuals(f2, 3.)
print()
```
```
=== first-order AD ===
carry-only loop
f32[3] from a constant
f32[3] output of broadcast_in_dim from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)
f32[] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)
i32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)
i32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:108 (f)

extensive inputs/outputs loop
f32[3] from a constant
f32[] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:117 (f2)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:117 (f2)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:117 (f2)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:117 (f2)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:117 (f2)
f32[3] output of scan from /usr/local/google/home/mattjj/packages/jax/ckpt.py:117 (f2)
```
Notice how nothing is of size T^2, i.e. nothing is shaped like [3, 3, ...]. No need for fancy optimizations or anything! This is all at the jaxpr level, before XLA. We just don't save extra stuff to begin with.
JAX only saves residuals when primal values interact with tangents in a first-order primitive application, i.e. it only saves them when they're really needed. That means we don't need to save anything to do with indexing into the big thing to get the small thing. We only save residuals when we hit e.g. multiplies, and those only involve small things.
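A minimal illustration of that rule, using `jax.ad_checkpoint.print_saved_residuals` on two hypothetical toy functions: pure indexing is linear in `x` and saves no copy of it, while `x * x` needs the primal `x` during the backward pass.

```python
import contextlib
import io

import jax
import jax.ad_checkpoint
import jax.numpy as jnp


def residual_report(f, *args):
    # Capture what print_saved_residuals prints for f at args.
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        jax.ad_checkpoint.print_saved_residuals(f, *args)
    return buf.getvalue()


def index_only(x):
    return x[2]            # linear in x: no primal value is needed for the backward pass


def square_sum(x):
    return jnp.sum(x * x)  # the primal x multiplies the tangent, so x is saved


x = jnp.arange(5.0)
print("index_only:", residual_report(index_only, x))
print("square_sum:", residual_report(square_sum, x))
```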
So scan's scanned-over inputs and outputs don't exist to reduce the memory needed for residuals. Instead they exist because otherwise, within each backward-pass body iteration, when we transpose the dynamic_slice we have to create a big array of zeros, just to write one nonzero slice into it, and then add that sparse-but-represented-as-dense result into one of the carry elements. XLA can sometimes optimize away simple versions of that pattern, but it can get complex with e.g. nesting. So to keep things sparse and not require this optimization, we encourage using scan's scanned-over inputs/outputs.
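A sketch of the transpose problem in isolation: differentiating a single `lax.dynamic_slice` read with `jax.vjp` gives a cotangent with the full shape of `x`, zeros everywhere except the slice that was read. The names `take_one`, `n` and `i` are hypothetical. Inside a loop that indexes into the carry, a dense cotangent like this is produced in every backward iteration.

```python
import jax
import jax.numpy as jnp
from jax import lax

n, i = 6, 3
x = jnp.arange(float(n))


def take_one(x):
    # Read a single element with dynamic_slice (what indexing into the carry does).
    return lax.dynamic_slice(x, (i,), (1,))[0]


y, vjp_fn = jax.vjp(take_one, x)
(ct_x,) = vjp_fn(jnp.float32(1.0))
print(ct_x)  # dense, full-size cotangent: zeros everywhere except position i
```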
There's much more that can be said about loops, but I think that covers this set of questions from the OP:
So, my question is: Is JAX scan really inefficient like I described, esp for the case of backprop? Or if not, why not? How does it avoid the O(T^2) memory consumption in backprop? Is this some automatic optimization? How does it work?
I think the other questions may have been covered by @patrick-kidger already, though I didn't actually check! I just wanted to talk about scan :P
In simple cases like this, yes. In practice there are programs which do spuriously get O(n^2) scaling, due to the particular way their computation is expressed.
For example, adding new_ys = jnp.where(i > -10, new_ys, ys) after the DUS will introduce an O(n^2) residual. This case is very common when working with a batch of data, each element of which iterates for a variable length of time.
Similarly, instead constructing new_c via new_c = jnp.dot(ys, xs) will introduce an O(n^2) residual. We don't actually need to keep multiple copies of ys around, as ys is initialised to zero in those regions which have not yet been written to. This case is common with "triangular" computations, in which each step depends on all previous steps.
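A sketch of the `jnp.dot(ys, xs)` variant, adapted from the carry-only loop `f` earlier in the thread (the name `f_dot` and the length `T = 4` are our own choices). Because each step's tangent computation needs that step's primal `ys`, a full copy of `ys` is saved per step, so the report should contain an `f32[4,4]` residual.

```python
import contextlib
import io

import jax
import jax.ad_checkpoint
import jax.numpy as jnp

T = 4
xs = 1 + jnp.arange(float(T))


def f_dot(z):
    # Carry-only loop where each step reads *all* of ys via a dot product.
    xs_ = z * xs
    ys = jnp.zeros_like(xs)

    def body(carry, _):
        i, c, xs, ys = carry
        new_c = jnp.dot(ys, xs)
        new_ys = jax.lax.dynamic_update_index_in_dim(ys, new_c, i, 0)
        return (i + 1, new_c, xs, new_ys), None

    (_, y, _, ys), _ = jax.lax.scan(body, (0, z, xs_, ys), None, length=T)
    return jnp.sum(y * ys)


buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    jax.ad_checkpoint.print_saved_residuals(f_dot, 3.0)
report = buf.getvalue()
print(report)
```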
(Honourable mention: nested loops as you mention sometimes mean that transpose-of-DS really does create an array of zeros -- this introduces a spurious O(n^2) runtime.)
I can see that we're talking about slightly different things. You're thinking specifically about the reimplementation of extensive outputs. Whilst I didn't originally discuss specifics, I'm mostly trying to emphasise that real-world uses of in-place updates are silently dangerous!
Thanks again for the answers. I think the examples from @patrick-kidger are clear to me, that they introduce O(n^2) memory consumption. What @mattjj writes though was unexpected to me. In this line:
new_ys = jax.lax.dynamic_update_index_in_dim(ys, new_c, i, 0)
You only get O(n) if it can be sure that it does not need to store the old ys for backprop. I think I see now that this is maybe not actually so difficult to determine. I still wonder a bit about whether there are problems when you do unexpected things with the index i here. E.g. what happens when you write the same index twice?
ys = jax.lax.dynamic_update_index_in_dim(ys, c1, 0, 0)
ys = jax.lax.dynamic_update_index_in_dim(ys, c2, 0, 0)
So, when backpropagating from the final sum(ys), you would have an error signal for the whole ys, i.e. grad_ys of shape (T,B,F), and then you get a gradient for c2 via grad_ys[0]. But what is the error signal to c1? It should be zero, not grad_ys[0]. So the gradient of dynamic_update_index_in_dim would also need to update grad_ys, like so:
grad_c2 = grad_ys[0]
grad_ys = jax.lax.dynamic_update_index_in_dim(grad_ys, jnp.zeros_like(grad_ys[0]), 0, 0)
grad_c1 = grad_ys[0] # now zero
Actually, c1 is really totally disconnected here, but I assume it cannot really see that.
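A quick check of this case, assuming scalar updates into a length-4 `ys` (the function name `write_twice` is hypothetical). The gradients come out as described: the transpose of the second update zeroes the overwritten slice of the cotangent, so `c1` gets a zero gradient and `c2` gets the full one.

```python
import jax
import jax.numpy as jnp
from jax import lax


def write_twice(c1, c2):
    ys = jnp.zeros(4)
    ys = lax.dynamic_update_index_in_dim(ys, c1, 0, 0)  # first write to index 0
    ys = lax.dynamic_update_index_in_dim(ys, c2, 0, 0)  # overwrites index 0
    return jnp.sum(ys)


g_c1, g_c2 = jax.grad(write_twice, argnums=(0, 1))(1.0, 2.0)
print(g_c1, g_c2)
```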
But now back to the normal case: does this mean you have those dynamic_update_index_in_dim ops on grad_ys in backprop in any case? I assume it will not really be too problematic, but it feels a bit wasteful. But I'm not sure if you really could avoid that.
when we transpose the dynamic_slice we have to create a big array of zeros, just to write one nonzero slice into it, and then add that sparse-but-represented-as-dense result into one of the carry elements.
In TensorFlow, the gradient of tf.gather will be an IndexedSlices, so it is not sparse-but-represented-as-dense, but really sparse. That is an easy way to solve this problem. I actually would have expected that this is what you would also do here. So this is also somewhat unexpected that you don't do that. If you don't do that, then you still have O(n^2) runtime behavior here.
I was just checking that by extending my script, and yes, this seems to be the case, the runtime seems to be O(n^2):
In simple cases like this, yes. In practice there are programs which do spuriously get O(n^2) scaling, due to the particular way their computation is expressed.
Certainly you can write programs where you need to save a full copy of the carry for each step, but that wasn't the question. My point is simply that if you wrote a loop which only slices into the carry you don't get the quadratic memory scaling.
You only get O(n) if it can be sure that it does not need to store the old ys for backprop. I think I see now that this is maybe not actually so difficult to determine.
The old ys are never needed there. We only save residuals when primals interact (multiplicatively) with tangents.
In TensorFlow, the gradient of tf.gather will be an IndexedSlices, so it is not sparse-but-represented-as-dense, but really sparse. That is an easy way to solve this problem. I actually would have expected that this is what you would also do here. So this is also somewhat unexpected that you don't do that. If you don't do that, then you still have O(n^2) runtime behavior here.
That's what we did in the original Autograd too. But in JAX we ultimately lower to HLO, and there aren't runtime-sparse data types in HLO.
Just to summarize, my points are:
1. Whether you use scan's extensive inputs/outputs or instead do the slicing into the carry yourself, the memory cost of saved residuals is linear, not quadratic, contra the OP ("But for the standard case, it would need to store a copy of the full ys tensor in every iteration, for the use of backprop. So it means you get T times a copy of ys. This means O(T^2) memory consumption.").
2. If you use scan's extensive inputs/outputs, the jaxpr-level runtime cost of transposition (i.e. if you define a cost model directly on jaxprs) is linear, not quadratic, but when we lower to HLO as efficiently as possible, i.e. using DynamicUpdateSlice, we rely on XLA to generate sparse in-place updates to avoid quadratic runtime behavior. Those optimizations may not always work and are also backend-dependent. (XLA:CPU hasn't been developed much for years, and was never a top priority the way the other backends were.)
On the latter point, preserving efficiency while not relying on compiler optimizations (like JAX does) or sparse runtime data structures (like Autograd/TF do) was one of the points of the Dex paper.
Like Dex, the secret (i.e. undocumented) next-gen loop construct in JAX uses in-place effectful indexed writes/addupdate operations to ensure efficient transposition (at the jaxpr cost model level, i.e. at the moment we'd still ultimately lower to HLO) without requiring the indexing to be built into the loop itself like scan. Having effectful updates like that is the only good way to solve the problem that I know of.
I can't find a public version of the slides @dougalm and others have made to explain the design options, but here are a few inline:
Thanks for the summary. That clarifies some of the main points from my original post. I think most other points were also addressed already here.
Now I wonder though, why is this not efficient in the case of TF? I don't really remember anymore. I think I have seen quadratic memory consumption for similar code. Or maybe just quadratic runtime? Or maybe just a very slow constant overhead? As we discussed, it actually has IndexedSlices for the grads of the gather, so it should not have this problem. When using TensorArray, there was a huge difference. Or maybe that is only true for some earlier TF versions? I always had the mental model that TensorArray solves all the problems of backprop efficiently, both for the loop inputs and loop outputs, and that you would want to have something like that. But now it seems it might not really be necessary? Or would it make sense for JAX to also have something like TensorArray?
I also still wonder about how to efficiently implement beam search decoding in a differentiable way with unknown sequence lengths. Or rather, it looks like this is just not possible currently in JAX? This would need dynamic shapes.
I was curious, and tried this PyTorch code:
```python
import torch

batch_dim = 10
feature_dim = 5


def test_scan(time_dim: int):
    xs = torch.rand((time_dim, batch_dim, feature_dim), requires_grad=True)
    ys = torch.zeros((time_dim, batch_dim, feature_dim))
    for i in range(time_dim):
        x = xs[i]
        x_ = x * (i ** 0.5)
        ys[i] = x_  # in-place write into the preallocated buffer
    y = ys.sum()
    y.backward()
    return xs.grad
```
It should be equivalent to my JAX code above. It seems I get linear runtime and memory consumption for this.
I also still wonder about how to efficiently implement beam search decoding in a differentiable way with unknown sequence lengths. Or rather, it looks like this is just not possible currently in JAX? This would need dynamic shapes.
Can you just write it the same way you would in PyTorch, without using any control flow primitive? JAX can differentiate through Python loops. (The performance may not be great but then it's just a performance optimization problem.)
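A minimal sketch of that suggestion, assuming a hypothetical scalar "decoder step" `x <- tanh(w * x)` and a hypothetical threshold rule standing in for "all labels are EOS": a plain Python `while` loop with a data-dependent `break` that grows a Python list and stacks it at the end. This runs op by op without `jit`, so the length can differ from call to call; per the JAX docs, Python control flow works under `grad` (but not under `jit`).

```python
import jax
import jax.numpy as jnp

EOS_THRESHOLD = 0.05  # hypothetical stopping rule standing in for "predicted EOS"
MAX_LEN = 100         # safety cap only; nothing is preallocated


def decode_loss(w, x0):
    # Plain Python loop (no jit, no lax control flow): the list grows freely.
    x = x0
    outputs = []
    while True:
        x = jnp.tanh(w * x)          # one hypothetical "decoder step"
        outputs.append(x)
        # Data-dependent exit; works under jax.grad because primal values are concrete.
        if jnp.abs(x) < EOS_THRESHOLD or len(outputs) >= MAX_LEN:
            break
    ys = jnp.stack(outputs)          # length only known at run time
    return jnp.sum(ys ** 2), len(outputs)


(loss, n_steps), grad_w = jax.value_and_grad(decode_loss, has_aux=True)(0.5, 1.0)
print(n_steps, loss, grad_w)
```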
It would be a dynamic ending condition, something like if all(predicted_label == EOS): break. And I would need to dynamically grow the tensors, like my naive variant 1. That's actually maybe another reason for TensorArray: to support dynamic growing without knowing the final sequence length in advance.
So, that would work fine with JAX? I thought JAX does not support dynamic shapes?
TF provides the TensorArray to make automatic iteration and stacking efficient in scan or while_loop.
The naive variant with gathering and concatenating or dynamic updates would be inefficient with backprop, because backprop would keep copies of the full array in each iteration. E.g. assume you are collecting ys of shape [T,B,F], and iterating over t in [0,...,T-1]. Now two possible naive variants:
1. You don't know T in advance. You allocate the initial ys tensor of shape [0,B,F], and in each iteration you concatenate a new vector [B,F], extended as [1,B,F], to it, so in each step t, the current ys is of shape [t,B,F].
2. You know T in advance. You can allocate the initial ys tensor of shape [T,B,F]. In each iteration, you update ys[t] (e.g. tensor_scatter_nd_update).
I have seen the concat variant being used for self-attention implementations.
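A sketch of the two naive variants in JAX, with a hypothetical per-step function `step` and toy sizes `T, B, F`. Variant 1 only works outside `jit` (eager mode), since every concatenation produces a new shape; variant 2 has a static shape and can be jitted, with each step writing slice `t` via `lax.dynamic_update_index_in_dim`.

```python
import jax
import jax.numpy as jnp
from jax import lax

T, B, F = 6, 2, 3
xs = jax.random.normal(jax.random.PRNGKey(0), (T, B, F))


def step(t, x_t):
    # Hypothetical per-step computation producing a [B, F] slice.
    return jnp.sin(x_t) * (t + 1)


# Variant 1: T not needed up front; grow ys by concatenation (eager Python loop).
ys1 = jnp.zeros((0, B, F))
for t in range(T):
    ys1 = jnp.concatenate([ys1, step(t, xs[t])[None]], axis=0)  # shape [t+1, B, F]


# Variant 2: T known; preallocate [T, B, F] and write slice t each step.
def variant2(xs):
    def body(t, ys):
        return lax.dynamic_update_index_in_dim(ys, step(t, xs[t]), t, axis=0)
    return lax.fori_loop(0, T, body, jnp.zeros((T, B, F)))


ys2 = jax.jit(variant2)(xs)
print(bool(jnp.allclose(ys1, ys2)))
```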
I was checking JAX while_loop and scan, and it seems JAX does not have TensorArray but instead uses dynamic_index_in_dim/dynamic_slice and dynamic_update_index_in_dim/dynamic_update_slice (like tensor_scatter_nd_update), which is like variant 2.
Without considering backprop, variant 2 can be efficient if it updates in place, actually more so than TensorArray. But if it does not update in place for some reason, you get O(T^2) runtime. Variant 1 would also likely lead to O(T^2) runtime, unless it is very clever and preallocates a bigger tensor. Then it might get away with amortized O(T) total runtime, similar to C++ std::vector with geometric growth. But I very much doubt that.
When considering backprop, it is much worse, unless there are some clever optimizations happening. But for the standard case, it would need to store a copy of the full ys tensor in every iteration, for the use of backprop. So it means you get T times a copy of ys. This means O(T^2) memory consumption.
With TF TensorArray, this is not the case, as each tensor ys[t] is treated separately. It is efficient and only has O(T) runtime and memory consumption, even with backprop.
So, my question is: Is JAX scan really inefficient like I described, esp for the case of backprop? Or if not, why not? How does it avoid the O(T^2) memory consumption in backprop? Is this some automatic optimization? How does it work?
I already got some preliminary answer by @mattjj here: As I understood, it is efficient because scan has specific code for autodiff, which is implemented in terms of other scans.
If you implement scan naively using while_loop and slicing and dynamic slice updates, it would have the problems I described though, right?
I actually tried to implement that, but I get:
So I cannot easily implement it?
If you need a dynamic ending condition, so there are no xs and length is unknown, how would you do that? Would that need such a custom scan implementation?
Somewhat related issue: #3106 on a TensorArray equivalent, but it doesn't really answer my question here on efficiency.
The original question was also asked on StackOverflow.