google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Wrapping the `init` function inside `jax.jit` #778

Open ksmdnl opened 1 month ago

ksmdnl commented 1 month ago

I'm currently doing a runtime analysis of the attention matrix of a transformer. Specifically, I'd like to know how the time complexity behaves with respect to the size of the attention matrix.

import time

import haiku as hk
import jax
import numpy as np

# `factory`, `hidden_dim`, `batch_size` and `args` are defined elsewhere.
def model(x):
    net = factory(hidden_dim, num_layers=1)
    return net(*x)

def inference(fn, rng, params, x, mode="runtime"):
    start = time.perf_counter()
    # Block until the result is ready so the timing covers the actual computation.
    _ = jax.block_until_ready(fn(params, x))
    end = time.perf_counter()
    print(f"{mode}: {end - start} s")
    return end - start

def main():
    if args.single == 0:
        nb_nodes = np.arange(args.max_nb_node) + 1
    else:
        nb_nodes = [args.max_nb_node]
    runtimes = []
    rng = jax.random.PRNGKey(42)
    for nb_node in nb_nodes:
        print(f"Number of node: {nb_node}")
        net = hk.without_apply_rng(hk.transform(model))
        node_fts = jax.random.normal(rng, (batch_size, nb_node, hidden_dim))
        edge_fts = jax.random.normal(rng, (batch_size, nb_node, nb_node, hidden_dim))
        x = (node_fts, edge_fts)
        params = net.init(rng, x)
        apply_fn = jax.jit(net.apply)

        # compile time
        _ = inference(apply_fn, rng, params, x, mode="compile time")

        # execution time
        runtime = inference(apply_fn, rng, params, x, mode="execution time")
        runtimes.append(runtime)

For nb_node = 414 I get an OOM error during initialization (when performing the Einstein summation), which looks as follows:

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 54101757696 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  249.86MiB
              constant allocation:         0B
        maybe_live_out allocation:   50.39GiB
     preallocated temp allocation:         0B
                 total allocation:   50.63GiB
              total fragmentation:         0B (0.00%)

This is quite strange, since I'm using an A100 with 80GB of memory. However, when I jit the init function, there is no OOM error even beyond nb_node = 500. My question is: would this be a correct workaround in this case?

tomhennigan commented 1 month ago

tl;dr - You should jit the init function to get a version of it that (1) uses as little memory as possible and (2) runs quickly.
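
Concretely, something like the following should work as a minimal sketch, reusing the `net`, `rng` and `x` variables from your snippet above and wrapping `net.init` the same way you already wrap `net.apply`:

# Minimal sketch, reusing `net`, `rng` and `x` from the snippet above.
init_fn = jax.jit(net.init)  # let XLA optimize the init program as well
params = init_fn(rng, x)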

To add a bit more detail, jax.jit (through XLA) applies a number of optimizations to your program. Some of these might reduce the overall peak memory footprint required by the init program. For example, one optimization that XLA does is limiting the live range of arrays.

Let's consider the following JAX program:

def fn():
  a = some_big_array()
  b = other_big_array()
  c = a + b
  d = yet_another_big_array()
  e = c + d
  return e

XLA would be able to notice that a and b can be safely deleted (and as such their GPU memory would be freed) before you compute d:

def fn():
  a = some_big_array()
  b = other_big_array()
  c = a + b

  # XLA knows a/b aren't used again so it can release GPU memory for them.
  free_gpu_memory(a)
  free_gpu_memory(b)

  d = yet_another_big_array()
  e = c + d
  return e

If you wanted to debug further and figure out which arrays were still hanging around and causing the 50GB allocation to fail, JAX allows you to see a traceback for where each live array was created, which might help you understand this in more depth. To help you get started, something like the following might work:

import traceback

def print_live_arrays():
  for array in jax.live_arrays():
    print(array.shape, array.dtype)
    traceback.print_tb(array.traceback.as_python_traceback())
    print()

try:
  benchmark()  # Stand-in for the benchmarking loop above.
except RuntimeError as e:
  if 'RESOURCE_EXHAUSTED' in str(e):
    print_live_arrays()
  raise
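
Relatedly, recent JAX versions let you inspect the memory requirements of a compiled program before running it, which can help you compare footprints as nb_node grows. A rough sketch, assuming the `net`, `rng` and `x` from your snippet above (and noting that memory_analysis() may return None on some backends/versions):

# Rough sketch, assuming `net`, `rng` and `x` from the snippet above.
# Lower and compile the init program without executing it, then ask the
# compiler for its estimated memory requirements (may be None on some backends).
compiled_init = jax.jit(net.init).lower(rng, x).compile()
print(compiled_init.memory_analysis())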

That said, even if you knew the root cause (e.g. which arrays were hanging around), the recommended fix is the same: use jax.jit and let XLA optimize this for you.