jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

O(1) forward computation requires potentially unbounded time to compute gradient #8239

Closed patrick-kidger closed 2 years ago

patrick-kidger commented 2 years ago

So this is a fun one.

The context here is that I'm implementing #5642, which I'm thinking of as a reverse-mode autodifferentiable while loop subject to a maximum number of iterations.

The good news is that the forward pass works perfectly (correct asymptotics; handles issues to do with vmap and in-place updates, cf. #8192). However, the backward pass can be arbitrarily expensive. I've already tried staring at the jaxpr without anything jumping out at me as being obviously wrong. I'm not completely sure whether to regard this as a bug in JAX (maybe XLA), or if this is something I can work around on the user side of things.

The following is the most minimal MWE I've been able to put together (e.g. this version won't vmap efficiently; I've cut the special handling of that out). Even so it's more of a "moderately sized working example".

First of all, here is the (simplified) code for bounded_while_loop:

import jax
import jax.lax as lax

def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """API as `lax.while_loop`, except that it takes an integer `max_steps` argument."""
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    if max_steps == 0:
        return init_val
    if max_steps & (max_steps - 1) != 0:
        raise ValueError("max_steps must be a power of two")

    init_data = (cond_fun(init_val), init_val)
    _, val = _while_loop(cond_fun, body_fun, init_data, max_steps)
    return val

def _while_loop(cond_fun, body_fun, data, max_steps):
    if max_steps == 1:
        pred, val = data
        new_val = body_fun(val)
        keep = lambda a, b: lax.select(pred, a, b)
        new_val = jax.tree_map(keep, new_val, val)
        return cond_fun(new_val), new_val
    else:

        def _call(_data):
            return _while_loop(cond_fun, body_fun, _data, max_steps // 2)

        def _scan_fn(_data, _):
            _pred, _ = _data
            return lax.cond(_pred, _call, lambda x: x, _data), None

        return lax.scan(_scan_fn, data, xs=None, length=2)[0]

Then the test harness:

import functools as ft
import jax
import jax.experimental.stax as stax
import jax.numpy as jnp
import jax.random as jrandom
import time

_key = jrandom.PRNGKey(0)
_init, _apply = stax.serial(stax.Dense(1024),
                            stax.elementwise(jnp.tanh),
                            stax.Dense(1024),
                            stax.elementwise(jnp.tanh),
                            stax.Dense(1))
expensive_fn = ft.partial(_apply, _init(_key, (1,))[1])

def cond_fun(val):
    x, step = val
    return step < 8

def body_fun(val):
    x, step = val
    return (expensive_fn(x), step + 1)

def timed(fn):
    def timer(*a, **kw):
        start = time.time()
        fn(*a, **kw).block_until_ready()
        end = time.time()
        print(end - start)
    return timer

@timed
@ft.partial(jax.jit, static_argnums=1)
@jax.grad
def f(val, max_steps):
    return jnp.sum(bounded_while_loop(cond_fun, body_fun, (val, 0), max_steps)[0])

val = jnp.array([1.])
f(val, 8)
f(val, 8)  # 0.037744998931884766
f(val, 16)
f(val, 16)  # 0.05941605567932129
f(val, 32)
f(val, 32)  # 0.10239911079406738

As you can see, the runtime of this gradient operation increases in proportion to the max_steps bound. This is despite the forward operation running in constant time (just comment out the jax.grad and rerun). Indeed, the overall number of steps taken in the while loop is always exactly 8, by the choice of cond_fun.

(Incidentally compile times are exponential in the depth of nested scans because of #8193, #8184, but based on the commentary in #8184 it sounds like a fix for that issue might be on the horizon.)

These times were obtained on the CPU. I've verified that the same behaviour also occurs on the GPU (with a more expensive expensive_fn).

This kind of adaptive computation is the sort of thing for which I'm still usually reaching towards PyTorch, what with the static-graph requirement of XLA. I'd love to get the above working, as IMO it's the one arena in which JAX still hasn't caught up.

mattjj commented 2 years ago

Thanks for raising this! I haven't looked at it yet, but I wanted to ask you about this comment:

This kind of adaptive computation is the sort of thing for which I'm still usually reaching towards PyTorch, what with the static-graph requirement of XLA. I'd love to get the above working, as IMO it's the one arena in which JAX still hasn't caught up.

As you probably know, you can use JAX without any static graph requirements: just use jax.numpy and jax.grad. Optionally you can apply jax.jit to subroutines where you don't mind opting into some requirements, but you don't need to use it at all. Sometimes we explain this as "write your code without any jax.jit, then once it's working you can optimize performance by applying jax.jit to the biggest subroutines you can."

Is following that unconstrained approach not performant enough for your workload? If so, can you share some kind of benchmark?
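For readers following along, here is a minimal sketch of that "Python control flow outside, jax.jit inside" pattern. The names `step` and `run` and the doubling body are hypothetical stand-ins, not code from this thread; the point is only that jax.grad can differentiate through a data-dependent Python while loop when the outer function is not jitted.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(y):
    # Hypothetical loop body: stands in for any expensive, jit-friendly subroutine.
    return 2.0 * y

def run(x):
    # Plain Python while loop: the trip count depends on runtime values,
    # which is allowed because `run` itself is not wrapped in jax.jit.
    y = x
    while y < 100.0:
        y = step(y)
    return y

# Differentiate through the Python loop; only `step` is compiled.
value, grad = jax.value_and_grad(run)(3.0)
print(value, grad)
```

Each pass through the loop crosses the jax.jit boundary once, which is the per-iteration overhead discussed in the reply that follows.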

patrick-kidger commented 2 years ago

So what you describe is actually what I'm doing at the moment -- I basically have a Python while loop and jax.jit the interior. The various limitations of this I've been bumping up against are:

Interpreter overhead: The first problem is the one you indicated in your comment -- performance.

A normal pattern is to stack jit-grad-vmap. This incurs very little runtime overhead.

Switching things around to grad-vmap-jit means frequently passing through JAX internals. This gets really expensive. Here's a couple of examples.

First, the flame graph for an operation doing just vmap-jit:

[image: flame graph of the vmap-jit operation]

The regions in blue are the times that XLA is actually being executed. Everything else is just overhead. I've also highlighted a green region: this is the cost of crossing a JIT API boundary, in which a bunch of complicated objects are partitioned into trace/static. It's so large because this is happening repeatedly inside a while loop. [For this example I don't think the vmap changes anything -- it's just the cost of passing back-and-forth through jax.jit so many times.]

[I realise blue/green may be an issue if you're colour-blind -- I can look into how to re-colour things appropriately if so; let me know.]

Second, an operation doing grad-vmap-jit:

[image: flame graph of the grad-vmap-jit operation]

I'm a little less certain about this one (I'm not quite as familiar with the internals of jax.grad), but I think in this case the blue region is the forward pass, and the purple region is the backward pass. Everything else looks to be things like jaxpr manipulation, i.e. interpreter overhead that wouldn't be there if the operation could be jit'd.

Developer ergonomics: Without exception, every JAX operation has to be jitted. Op-by-op mode simply incurs too much overhead. (A fact arrived at by staring at flame graphs of the call stack, much like those above.) This means writing code like

@jax.jit
def _jit_lt(a, b):
    return a < b

while _jit_lt(a, b):
    ...

with lots of little mini-jit-functions every time a JAX array must be interacted with.

The above is an actual example from my code -- introducing this JIT produced a measurable improvement in performance.

User ergonomics: I'm developing a software library. It's pretty frustrating for a user to be told that they can't JIT mylibrary.myfunction, and basically subjects the user to the same thing as above: at minimum one has to write a "before mylibrary.myfunction JIT" and an "after mylibrary.myfunction JIT".

Fixing this is actually my primary concern, as this fundamentally breaks composability with the rest of the JAX ecosystem.

In-place updates: As I recall (it's been a while since I figured out how to work around this issue), it's not possible to make in-place updates to the same buffer in different jax.jit regions; passing back to Python-land forces a copy.

I'm aware of donate_argnums but this hasn't seemed to help -- possibly things are somehow too complex for the compiler to figure out? And either way donate_argnums isn't yet supported on the CPU (spitting out a bunch of warnings instead).

This one's pretty important for efficiency purposes.
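As a point of reference, this is roughly what buffer donation looks like in isolation. It is a minimal sketch with a hypothetical `update` function, not the library code under discussion, and whether the donated buffer is actually reused is up to the backend and JAX version (some versions only emit a warning on CPU).

```python
import jax
import jax.numpy as jnp

def update(buf, i, v):
    # Functional "in-place" update: returns a new array with buf[i] = v.
    return buf.at[i].set(v)

# donate_argnums=0 tells XLA it may reuse the input buffer for the output.
# Assumption: the caller never touches the donated `buf` again after the call.
update_donated = jax.jit(update, donate_argnums=0)

buf = jnp.zeros(4)
buf = update_donated(buf, 1, 3.0)  # rebind the name; the old buffer is donated
print(buf)
```

The name is rebound on every call because a donated input should be treated as invalid afterwards.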

What to trace/static: The main routine I'm writing takes an instance of jax.tree_util.Partial as an argument. This is some parameterised function specified by the user of the software library. The parameterisation may include a mixture of some things worth JIT-tracing and some things worth JIT-static'ing.

When the jax.jit call happens inside the software library, it falls on the developer to make an opinionated choice on what to JIT (e.g. "trace all JAX arrays"). A user may want something slightly different.

This one can obviously be worked around in various ways -- have extra arguments for the "JIT-trace args" and the "JIT-static args", or let a user specify some partition function -- but that's less elegant at the API level, not to mention the library internals now need to pipe multiple arguments around the place.

Overall it'd be preferable to avoid making this the library's problem at all. Just make it possible for the user to jax.jit their entire operation, mylibrary.myfunction and all.
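To make the trace/static split concrete, here is a small sketch of the user-side version. `scaled_tanh`, `library_fn` and `num_repeats` are hypothetical names chosen for illustration; the only real APIs used are jax.tree_util.Partial (whose bound arguments are pytree leaves and therefore traced) and the static_argnames option of jax.jit.

```python
import functools as ft
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

def scaled_tanh(scale, x):
    # Hypothetical user function; `scale` is the parameterisation.
    return jnp.tanh(scale * x)

# Hypothetical stand-in for mylibrary.myfunction. Because the *user* applies
# jax.jit, the user decides what is static (num_repeats) and what is traced.
@ft.partial(jax.jit, static_argnames=("num_repeats",))
def library_fn(fn, x, num_repeats):
    for _ in range(num_repeats):  # Python loop unrolled at trace time
        x = fn(x)
    return x

# The array bound into the Partial is a pytree leaf, so it is traced.
fn = Partial(scaled_tanh, jnp.array(2.0))
out = library_fn(fn, jnp.array(0.5), num_repeats=2)
print(out)
```

Changing `num_repeats` triggers a recompile, while changing the array bound into the Partial does not.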

/walloftext!

patrick-kidger commented 2 years ago

For the sake of anyone landing here in the future: I've managed to resolve this by adding a jax.checkpoint decorator to the _scan_fn.

I'm not 100% certain exactly what's going wrong in the original code but I think it's something to do with the way the forward pass is stored until the backward pass, so forcing a different strategy like this seems to help.

With this fix, computing the gradient seems to have a runtime of O(steps taken) + O(log(max_steps)). Ideally we'd only have the O(steps taken) part, of course, but this still seems to be good enough for practical purposes.

mattjj commented 2 years ago

Finally getting back to this!

The right solution is to use jax.checkpoint as you are.

The issue here is that:

  1. the gradient of a cond produces one cond for the forward pass and one cond for the backward pass, where as usual the forward pass computation (without jax.checkpoint) computes "residual" values during the forward pass which are consumed by the backward pass;
  2. jax.lax.cond, like the HLO Conditional, requires all branches to have the same output type.

In particular, as a consequence of these two facts, if we have one branch which when differentiated requires no residual outputs, and another branch which requires a lot of residual outputs, to form a valid cond/Conditional we need to join the types of these two branches by adding some dummy residuals (arrays of zeros) to the no-residuals branch. These are the broadcast_in_dim applications we see when we run this version of the code:

import jax
import jax.lax as lax

def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """API as `lax.while_loop`, except that it takes an integer `max_steps` argument."""
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    if max_steps == 0:
        return init_val
    if max_steps & (max_steps - 1) != 0:
        raise ValueError("max_steps must be a power of two")

    init_data = (cond_fun(init_val), init_val)
    _, val = _while_loop(cond_fun, body_fun, init_data, max_steps)
    return val

def _while_loop(cond_fun, body_fun, data, max_steps):
    if max_steps == 1:
        pred, val = data
        new_val = body_fun(val)
        keep = lambda a, b: lax.select(pred, a, b)
        new_val = jax.tree_map(keep, new_val, val)
        return cond_fun(new_val), new_val
    else:

        def _call(_data):
            return _while_loop(cond_fun, body_fun, _data, max_steps // 2)

        @jax.checkpoint
        def _scan_fn(_data, _):
            _pred, _ = _data
            return lax.cond(_pred, _call, lambda x: x, _data), None

        return lax.scan(_scan_fn, data, xs=None, length=2)[0]

import functools as ft
import jax
import jax.experimental.stax as stax
import jax.numpy as jnp
import jax.random as jrandom
import time

_key = jrandom.PRNGKey(0)
_init, _apply = stax.serial(stax.Dense(1024),
                            stax.elementwise(jnp.tanh),
                            stax.Dense(1024),
                            stax.elementwise(jnp.tanh),
                            stax.Dense(1))
expensive_fn = ft.partial(_apply, _init(_key, (1,))[1])

def cond_fun(val):
    x, step = val
    return step < 8

def body_fun(val):
    x, step = val
    return (expensive_fn(x), step + 1)

timings = []

def timed(fn):
    def timer(*a, **kw):
        start = time.time()
        fn(*a, **kw).block_until_ready()
        end = time.time()
        timings.append(end - start)
    return timer

# @timed
@ft.partial(jax.jit, static_argnums=1)
@jax.grad
def f(val, max_steps):
    return jnp.sum(bounded_while_loop(cond_fun, body_fun, (val, 0), max_steps)[0])

val = jnp.array([1.])

import sys
n = int(sys.argv[1])
print(jax.make_jaxpr(lambda x: f(x, n))(val))

Run it with:

python 8239.py 2  # or 4, 8, 16, whatever

These dummy-constructing broadcasts do enough extra work to explain the behavior you observed.
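A much smaller reproduction of the same effect is sketched below, assuming a single cond with one no-op branch; the branch functions are hypothetical and not taken from the script above. Printing the jaxpr of the gradient lets you check for dummy residuals on your own JAX version (the exact primitives shown may differ between releases).

```python
import jax
import jax.lax as lax
import jax.numpy as jnp

def f(x, pred):
    def expensive_branch(v):
        # Needs residuals (functions of v) for its backward pass.
        return jnp.sin(v) * v

    def noop_branch(v):
        # Needs no residuals at all.
        return v

    return jnp.sum(lax.cond(pred, expensive_branch, noop_branch, x))

x = jnp.arange(3.0)

# Inspect how the two branches' residual outputs are made to agree in type.
print(jax.make_jaxpr(jax.grad(f))(x, True))

g_true = jax.grad(f)(x, True)
g_false = jax.grad(f)(x, False)
```

Both branches of the differentiated cond must return residuals of the same type, so the no-op branch pays for arrays it never uses.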

Because jax.checkpoint means the forward pass doesn't need to produce any residuals, this extra work doesn't need to happen. (Actually, the broadcast_in_dims aren't completely removed by jax.checkpoint, but they're all moved to the backward pass to be right next to their consumers, so we believe it's easier for XLA to optimize away this extra work. We could perform an optimization like this at the JAX level, making sure that remat-of-cond turns into a cond-of-remat, which would make these values not be created in the first place.)
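As a sanity check that wrapping a cond in jax.checkpoint changes only how the backward pass is scheduled and not its result, here is a minimal sketch. `maybe_apply` is a hypothetical function, far simpler than the `_scan_fn` in this thread.

```python
import jax
import jax.lax as lax
import jax.numpy as jnp

def maybe_apply(pred, x):
    # One branch needs residuals for its backward pass; the other is a no-op.
    return lax.cond(pred, lambda v: jnp.tanh(v) * v, lambda v: v, x)

# Rematerialised variant: residuals are recomputed during the backward pass
# instead of being returned from the forward-pass cond.
maybe_apply_remat = jax.checkpoint(maybe_apply)

x = jnp.linspace(-1.0, 1.0, 5)
g_plain = jax.grad(lambda v: jnp.sum(maybe_apply(True, v)))(x)
g_remat = jax.grad(lambda v: jnp.sum(maybe_apply_remat(True, v)))(x)
print(g_plain, g_remat)
```

The two gradients should agree; any difference between the variants is in runtime and memory, which this sketch does not measure.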

The best fix actually involves dynamic shapes, or something similar: we'd like to create residuals which might have zero size when we take the no-op branch of the cond. Equivalently we'd like optionals or sum types which don't involve any runtime work. That's out of scope for JAX as it exists today though.

I'm going to close this issue because this is intended behavior for jax.lax.cond today, and jax.checkpoint is the right way to get the computation you want, though in a glorious dynamic-shape future we could make this smarter automatically.

Thanks to @axch for help pair-debugging this.

patrick-kidger commented 2 years ago

Sounds good! Thank you for the write-up. (Crossing my fingers for the glorious dynamic-shape future.)