patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Unroll in eqx's internal `scan` #726

Open neel04 opened 4 months ago

neel04 commented 4 months ago

Is there no way to unroll the scan in eqx.internal.scan? I'm not sure what constraints the "checkpointed" scan imposes, since I'm not familiar with the algorithms used there, but I assumed people would want to use `unroll` as in the lax-style scan?
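For reference, `lax.scan` exposes unrolling through its `unroll` argument. A minimal sketch with a toy recurrence (`step` is a hypothetical body function) showing that `unroll` changes how many steps are inlined per compiled loop iteration, not the result:

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # Hypothetical single step of a recurrence: carry <- tanh(0.5 * carry + x).
    new_carry = jnp.tanh(carry * 0.5 + x)
    return new_carry, new_carry  # also emit the carry as the per-step output

xs = jnp.arange(12.0).reshape(12, 1) / 12.0  # 12 steps, 1 feature each
init = jnp.zeros((1,))

# unroll=1: one step per iteration of the compiled while loop.
carry1, ys1 = lax.scan(step, init, xs, unroll=1)
# unroll=4: four steps are inlined into each loop iteration.
carry4, ys4 = lax.scan(step, init, xs, unroll=4)
```

Both calls compute the same values; only the loop structure handed to XLA differs.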

patrick-kidger commented 4 months ago

I think it'd be pretty tricky to get both checkpointing and unrolling together. I'm afraid this probably isn't something we'll support.

neel04 commented 4 months ago

Out of curiosity, how much work would this take? Surely (in theory) a checkpointed scan can be unrolled 🤔

I'm taking a lot of performance hits that I'd like to recover: XLA needs the loop unrolled to fuse effectively and to apply latency-hiding scheduling (LHS), so I'd prefer an unrolled loop...

patrick-kidger commented 4 months ago

This raises the question of whether you want to unroll a checkpointed scan (hard) or checkpoint an unrolled scan (easy). At least for the latter, you should be able to implement it by putting a small for loop inside the body function of your scan.

(I sympathise, XLA's implementation of loops and control flow has never been very strong.)
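A minimal sketch of the "easy" option, checkpoint-of-unrolled-scan. To keep it runnable with JAX alone, it uses plain `lax.scan` plus `jax.checkpoint` rather than `eqx.internal.scan`, which is an assumption of this sketch. The step function `step` and chunk size `K` are hypothetical. Each scan iteration runs `K` steps in a Python loop (unrolled at trace time), and `jax.checkpoint` on the chunk makes the backward pass recompute the intermediates inside each chunk instead of storing them:

```python
import jax
import jax.numpy as jnp
from jax import lax

K = 4  # hypothetical number of steps unrolled inside each checkpointed chunk

def step(carry, x):
    # Hypothetical single step of the recurrence.
    new_carry = jnp.sin(carry) + x
    return new_carry, new_carry

@jax.checkpoint  # backward pass recomputes intermediates inside the chunk
def chunk(carry, xs_chunk):
    # Python loop: traced K times, so these K steps are unrolled.
    ys = []
    for i in range(K):
        carry, y = step(carry, xs_chunk[i])
        ys.append(y)
    return carry, jnp.stack(ys)

def loss(init, xs):
    n = xs.shape[0]
    chunks = xs.reshape(n // K, K, *xs.shape[1:])  # assumes K divides n
    final, ys = lax.scan(chunk, init, chunks)
    return jnp.sum(ys.reshape(n, *xs.shape[1:])) + jnp.sum(final)

def loss_ref(init, xs):
    # Reference: one step per scan iteration, no checkpointing.
    final, ys = lax.scan(step, init, xs)
    return jnp.sum(ys) + jnp.sum(final)

xs = jnp.linspace(0.0, 1.0, 16).reshape(16, 1)
init = jnp.zeros((1,))
```

Values and gradients should match the reference; only the memory/recompute tradeoff and the loop structure change.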

neel04 commented 4 months ago

To avoid an XY problem: my use case is effectively the same as Universal Transformers, where I want to recursively apply a block of layers n times. n can be treated as a fixed constant for now.

So, AIUI, I'd need to unroll a checkpointed scan if I want to keep memory consumption low by trading it off against extra computation. I'm not sure whether there's a better way to accomplish this. eqx.internal.scan has worked wonderfully for me, but the overhead is substantial, and I want to optimize it.
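The Universal Transformer pattern described here (one block of layers with shared weights, applied n times) can be sketched as follows. `block` is a hypothetical residual MLP standing in for the real layers; the sketch contrasts a `lax.scan` with `length=n` and no per-step inputs against a plain Python for loop, which JAX unrolls completely when tracing:

```python
import jax
import jax.numpy as jnp
from jax import lax

def block(params, h):
    # Hypothetical stand-in for "a block of layers": a residual MLP.
    w1, w2 = params
    return h + jnp.tanh(h @ w1) @ w2

def apply_n_scan(params, h, n):
    # Same weights every iteration and no per-step inputs, so xs=None.
    def body(carry, _):
        return block(params, carry), None
    h, _ = lax.scan(body, h, xs=None, length=n)
    return h

def apply_n_loop(params, h, n):
    # Python loop: n copies of the block end up in the traced program.
    for _ in range(n):
        h = block(params, h)
    return h

key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
d = 8
params = (0.1 * jax.random.normal(key1, (d, d)),
          0.1 * jax.random.normal(key2, (d, d)))
h0 = jax.random.normal(key3, (4, d))  # 4 tokens, width 8

out_scan = apply_n_scan(params, h0, 3)
out_loop = apply_n_loop(params, h0, 3)
```

Under `jax.jit`, `n` has to be static in both versions; the scan keeps the program size constant in `n`, while the loop grows with it.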

patrick-kidger commented 4 months ago

Do you know how much of the overhead is coming from XLA's lack of fusion/etc (compare to a normal for loop), and how much is coming from the overhead of checkpointing (which does require extra recomputation; compare to a lax.scan)?

neel04 commented 4 months ago

I'm not sure. Using a vanilla lax.scan takes too much memory, so I have to divide my batch size by 4. I can try applying a remat policy to the body_fun of the scan to recover some of the memory.

A normal for loop seems to be slightly worse, but slightly faster if I use n=1 (i.e. a normal transformer forward pass, no UT). The checkpointed internal.scan seems to be a good tradeoff, but XLA does seem to have trouble with it. Specifically, I can't use unroll > 1 for the scan, which means I'm leaving a lot of optimization on the table...
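For the remat-policy idea mentioned above, a minimal sketch: `jax.checkpoint` accepts a `policy` that decides which intermediates are saved for the backward pass and which are recomputed. Here `body_fun` is a hypothetical single layer, and `dots_with_no_batch_dims_saveable` is just one of the built-in policies in `jax.checkpoint_policies`; which policy actually helps depends on the model:

```python
import jax
import jax.numpy as jnp
from jax import lax

def body_fun(carry, w):
    # Hypothetical layer: matmul followed by a nonlinearity.
    new_carry = jnp.tanh(carry @ w)
    return new_carry, None

# Save matmul outputs that have no batch dimensions; everything else
# (here, the tanh) is recomputed on the backward pass.
remat_body = jax.checkpoint(
    body_fun, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)

def loss(h, ws, body):
    h, _ = lax.scan(body, h, ws)
    return jnp.sum(h ** 2)

h0 = jnp.ones((2, 4))
ws = 0.1 * jnp.ones((3, 4, 4))  # 3 scan steps, one 4x4 weight each

g_plain = jax.grad(lambda h: loss(h, ws, body_fun))(h0)
g_remat = jax.grad(lambda h: loss(h, ws, remat_body))(h0)
```

The gradients are identical; the policy only changes what is stored versus recomputed.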

patrick-kidger commented 4 months ago

I'd suggest avoiding jax.checkpoint here to be sure we're comparing the things we want to. Running all the comparisons with a smaller batch size to fit in memory sounds good.

neel04 commented 4 months ago

Throughput in tokens/s, measured after 40 minutes:

| Method | Throughput (tokens/s) | n |
|---|---|---|
| lax.scan | 120k | 3 |
| lax.scan, unroll=2 | 110k | 3 |
| checkpointed scan | 155k | 3 |
| vanilla for-loop | 162k | 3 |

Hm, so the overhead isn't that much here...
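Putting the table's numbers in relative terms (a small worked calculation using only the reported figures): the checkpointed scan reaches about 96% of the for-loop's throughput, while plain `lax.scan` reaches about 74%.

```python
# Throughput figures (tokens/s) from the benchmark table, all at n = 3.
throughput = {
    "lax.scan": 120_000,
    "lax.scan, unroll=2": 110_000,
    "checkpointed scan": 155_000,
    "vanilla for-loop": 162_000,
}
baseline = throughput["vanilla for-loop"]

# Throughput of each method relative to the fully unrolled for-loop.
relative = {method: tps / baseline for method, tps in throughput.items()}
for method, ratio in relative.items():
    print(f"{method:20s} {ratio:.1%}")
```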

patrick-kidger commented 4 months ago

Oh this is super interesting! In particular the fact that checkpointed scans are faster than uncheckpointed scans. Probably something to do with recomputation being cheaper than moving things around in memory.

I think if you still want to try tweaking this then I'd suggest trying a little for-loop in the body function of the checkpointed scan; this may offer some naive opportunities for fusion without really getting in the way of anything else. But perhaps this is already "mission accomplished"... !

neel04 commented 4 months ago

Thanks! But what do you mean by "for-loop in the body function"? The body function doesn't have an iterative element to it; it's just passing the carry through a bunch of layers...

patrick-kidger commented 4 months ago

Something like this:

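import jax.numpy as jnp
import jax.tree_util as jtu
import equinox.internal as eqxi

def fun(carry, x):
    ...  # do stuff: one step of the original scan body
    return carry2, y

def fun2(carry, x):
    # x holds 5 adjacent steps; this Python loop is unrolled at trace time.
    y = []
    for xi in x:
        carry, yi = fun(carry, xi)
        y.append(yi)
    # Stack the per-step outputs along a new leading axis of size 5.
    return carry, jtu.tree_map(lambda *ys: jnp.stack(ys), *y)

# Scan over chunks of 5 steps (assumes xs.shape[0] is divisible by 5).
# Outputs come back with shape (xs.shape[0] // 5, 5, ...).
eqxi.scan(fun2, ..., xs=xs.reshape(xs.shape[0] // 5, 5, *xs.shape[1:]))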

in which you manually unroll a few adjacent steps.

(This is all that lax.scan(..., unroll=5) is doing for you under the hood.)
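A self-contained check of this claim, using `lax.scan` in place of `eqxi.scan` so that it runs with JAX alone. The step function and the unroll factor `U = 5` are illustrative assumptions; the chunked outputs are reshaped back to one row per step before being compared with `lax.scan(..., unroll=5)`:

```python
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import lax

U = 5  # unroll factor

def fun(carry, x):
    # Hypothetical single step with a pytree (dict) output.
    carry2 = 0.9 * carry + x
    return carry2, {"state": carry2, "sq": carry2 ** 2}

def fun2(carry, x):
    # x has a leading axis of size U: run U steps in a Python loop.
    ys = []
    for xi in x:
        carry, yi = fun(carry, xi)
        ys.append(yi)
    return carry, jtu.tree_map(lambda *leaves: jnp.stack(leaves), *ys)

xs = jnp.arange(20.0).reshape(10, 2)  # 10 steps (a multiple of U)
init = jnp.zeros((2,))

# Manual unroll: scan over chunks of U steps, then flatten the outputs back.
chunks = xs.reshape(xs.shape[0] // U, U, *xs.shape[1:])
c_manual, ys_chunked = lax.scan(fun2, init, chunks)
ys_manual = jtu.tree_map(lambda a: a.reshape(-1, *a.shape[2:]), ys_chunked)

# Built-in unroll for comparison.
c_builtin, ys_builtin = lax.scan(fun, init, xs, unroll=U)
```

One difference worth noting: the manual version needs the length to be divisible by 5, while `lax.scan` also accepts lengths that are not a multiple of `unroll`.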