jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

vmap of cond's predicate results in select, leading to unexpected compute/memory use #8409

Open aespielberg opened 2 years ago

aespielberg commented 2 years ago

I have been playing around with converting diffmpm from the difftaichi package into a jax version, and while the forward pass has been working wonderfully, the backward pass has been using way too much GPU memory.

Today, I was able to track down that memory usage to the grid op. The grid op step is a series of nested if statements. At first, I was using jnp.where, which evaluates all branches. That is extremely inefficient and can lead to OOM errors. I simplified my code and switched to jax.lax.cond, but my only conclusion is that cond is also evaluating both branches; otherwise I cannot see why this would run into OOM issues.

Below is a modified version of the grid op, composed with itself 4,000 times, like a simulation. Even when run with the XLA_PYTHON_CLIENT_PREALLOCATE=false flag, this quickly leads to the whole GPU being used, and more if the loop length is increased. This is not true if every line from lin = .... until right before the return of grid_op is commented out; in that case, memory usage is practically negligible. Note that because bound = 0, the predicate of every v_out = jax.lax.cond ... line evaluates to False by definition, so most of the expressions, including the v_out_gate's and their dependencies, shouldn't even need to be evaluated in the jitted function.

Maybe I am misunderstanding cond; if so, what is the proper way to get this sparse branching behavior? I don't want to evaluate and hang onto a bunch of expensive tensors that are never actually needed and crash my GPU with OOM, especially in a backward pass. This is a core bottleneck to practical deployment of my code and a feature that I think should be supported. FWIW, I am using Version: 0.1.69+cuda101.

Code to reproduce is below.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import os
import math
import numpy as np
import matplotlib.pyplot as plt
import jax.nn as jnn
import jax.lax as jlax
import timeit
import jax

dim = 2
n_grid = 128
dt = 1e-3
gravity = 3.8

def allocate_arrays():
  global grid_m_in, grid_v_in, grid_v_out, loss, index_array
  grid_m_in = jnp.ones((n_grid, n_grid))
  grid_v_in = jnp.zeros((n_grid, n_grid, dim))
  grid_v_out = jnp.zeros((n_grid, n_grid, dim))

  index_array = np.zeros((n_grid, n_grid, dim))

  for i in range(n_grid):
    for j in range(n_grid):
      index_array[i, j] = np.array([i, j])

  index_array = jnp.array(index_array)

def grid_op(grid_v_in, grid_m_in, index_tuple):
  bound = 0
  coeff = 0.5

  i = index_tuple[0]
  j = index_tuple[1]

  normal = jnp.array([0., 1.])

  inv_m = 1 / (grid_m_in + 1e-10)
  v_out = jnp.expand_dims(inv_m, -1) * grid_v_in
  v_out -= dt * gravity * jnp.array([0., 1.])

  v_out = jax.lax.cond(jnp.logical_and(i < bound, v_out[0] < 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)

  v_out = jax.lax.cond(jnp.logical_and(i > n_grid - bound, v_out[0] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)

  lin = (v_out.transpose() @ normal)

  vit = v_out - lin * normal
  lit = jnp.linalg.norm(vit + 1e-10)  + 1e-10

  v_out_gate_2 = jax.lax.cond(lit + coeff * lin <= 0, lambda _: jnp.zeros_like(v_out), lambda _: (1 + coeff * lin / lit) * vit, operand=None)
  v_out_gate_1 = jax.lax.cond(lin < 0, lambda _: v_out_gate_2, lambda _: jnp.zeros_like(v_out), operand=None)
  v_out = jax.lax.cond(jnp.logical_and(j < bound, v_out[1] < 0), lambda _: v_out_gate_1, lambda _: v_out, operand=None)          
  v_out = jax.lax.cond(jnp.logical_and(j > n_grid - bound, v_out[1] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)

  return v_out

go_j = jit(vmap(vmap(grid_op)))

def advance2(t, args):
  grid_v_in = args[0]
  grid_m_in = args[1]
  index_array = args[2]
  grid_v_in = go_j(grid_v_in, grid_m_in, index_array)

  return grid_v_in, grid_m_in, index_array

def advance(t, args):
  x = args[0]
  v = args[1]
  C = args[2]
  F = args[3]
  x, v, C, F = p1_j(x, v, C, F, actuator_id)

  return x, v, C, F

a = jit(advance)

def forward2(grid_v_in, grid_m_in, index_array):
  grid_v_in, grid_m_in, index_array = jlax.fori_loop(0, 4000, advance2, (grid_v_in, grid_m_in, index_array))
  return jnp.mean(grid_v_in)

def main():
# initialization
  allocate_arrays()

  f2 = jit(forward2)
  forward_grad2 = jit(grad(forward2))

  number = 10

  print(timeit.timeit(lambda : f2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : f2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : forward_grad2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : forward_grad2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)

if __name__ == "__main__":
  main()
aespielberg commented 2 years ago

In fact, appending the idempotent v_out = jax.lax.cond(jnp.logical_and(j > n_grid - bound, v_out[1] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None) a few more times at the end of grid_op adds GBs more memory usage :disappointed:

froystig commented 2 years ago

XLA may choose to evaluate several or all branches of cond or select, but indeed one intent of exposing cond separately from select is to allow you to indicate when that would be expensive. Using cond suggests that only one branch should be evaluated (should other XLA optimizations allow for it).

But XLA has no "batched" conditional construct. Today, when batching the predicate to a cond using vmap, jax will transform the cond to a select instead. From a quick glance through your example, it seems like this might be happening.

Could jax.checkpoint serve you here, by working around having to store material arrays in the course of autodiff?
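
For illustration, a minimal standalone sketch that makes this conversion visible in the jaxpr (a toy function, not the grid op above):

import jax
import jax.numpy as jnp

def f(pred, x):
    # Two trivial branches; the point is only to inspect the jaxpr.
    return jax.lax.cond(pred, lambda x: x * 2.0, lambda x: x - 1.0, x)

preds = jnp.array([True, False, True])
xs = jnp.ones((3, 4))

# With a batched predicate, the cond primitive no longer appears: the batching
# rule emits a select (select_n in recent JAX versions), so both branches are
# computed for every batch element.
print(jax.make_jaxpr(jax.vmap(f))(preds, xs))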

aespielberg commented 2 years ago

I see. Is there any reason why vmap'd conds are not supported? What is the logic there? I think this would be very important for a lot of people (and it would be great if these nuances were documented). I don't know the internals of the jax compiler, but I do know that other similar systems support some form of branching. Is there any way to add this capability to XLA, or somehow to jax?

Actually - since I could, if I were careful about my code, manually batch everything, including conditionals, I see no reason why this couldn't be supported with vmap?

I did try jax.checkpoint, but even having one checkpoint and splitting the loop into halves balloons the runtime (I think by 4x IIRC; I'll have to double check, but that's surprising, since as long as checkpoints aren't encompassing each other, I would think the maximum extra runtime this would incur is 2x). On that note, adding cond's vs. where's also balloons the runtime by 2x... eventually these things start to really add up.
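
For concreteness, the per-step variant of that checkpointing pattern, sketched against the reproduction above (the split I actually benchmarked was coarser, one checkpoint over half the loop):

import jax
import jax.numpy as jnp

# Per-step remat: wrapping the loop body in jax.checkpoint means reverse-mode
# AD stores only each step's inputs and recomputes the step's internals (the
# cond/select intermediates) during the backward pass. go_j is the jitted
# vmap(vmap(grid_op)) from the reproduction script.
@jax.checkpoint
def advance2_remat(args):
    grid_v_in, grid_m_in, index_array = args
    return go_j(grid_v_in, grid_m_in, index_array), grid_m_in, index_array

def forward2_remat(grid_v_in, grid_m_in, index_array):
    args = jax.lax.fori_loop(0, 4000, lambda t, a: advance2_remat(a),
                             (grid_v_in, grid_m_in, index_array))
    return jnp.mean(args[0])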

shailesh1729 commented 2 years ago

Often, one branch of lax.cond is essentially a no-op, while the other branch involves a significant amount of computation. If the lax.cond hint is ignored under vmap, then I guess it will lead to both memory and computation overhead. Using vmap might turn out to be more expensive than writing a for loop.

aespielberg commented 2 years ago

A loop may be possible in some cases, but not all, and I think that makes this problematic. And, if I can, I would like to make an argument for why I think somehow supporting cond's in vmap'd scenarios is so important.

The wonderful thing about vmap, as advertised, is that anything should be easily vmappable. It's essentially the beauty of abstraction and modularity: someone writes a module, foo, and if you want to parallelize it, it should be as easy as just calling vmap(foo). You don't need to understand foo; it should just work. In this case, it seems like cond, a very critical dataflow construct, behaves differently depending on whether or not it is vmap'd. The problem, of course, is that people who use a module might not even know whether cond is used inside. This leads to functionally very different performance profiles under vmap that might require deep investigation of foo, which could be a quite complex piece of code. I don't think this is good for the mission of vmap.

In some cases (as I would argue here), vmap appears to be the correct way to process such code, rather than a for loop, since as far as I understand, indexing into arrays via at is runtime expensive in jax (correct me if I'm mistaken), and creating large, sparse one-hot masks would be prohibitively memory-expensive without support for sparse matrices.

froystig commented 2 years ago

We're in agreement here by and large. This is something that we've thought about improving before, whether at the JAX or XLA level. I can't find an open issue for it, so let's use this one for it.

minqi commented 1 year ago

Hi there! I was wondering what the JAX team's latest thinking is regarding the behavior of lax.cond when batched via vmap.

I find myself often running into the design pattern of conditionally branching into two subroutines, one expensive, and the other a "placeholder," for example, returning a dummy zero tensor.

froystig commented 1 year ago

@minqi – the thinking hasn't changed much since this issue was last active. Although there's a fundamental puzzle regarding whether/how to do better, for now we're still producing select when we batch cond's predicate.

HHalva commented 1 year ago

I find this one of the biggest practical issues with jax -- vmap+jit are great, but a lot of code also necessitates using cond with them, which results in the compute/memory issues described above... and in time spent trying to work around those.

HHalva commented 1 year ago

Btw does switch suffer from this same problem when used with vmap?

jakevdp commented 1 year ago

Switch is implemented in terms of cond, so yes it has the same characteristics.

HHalva commented 1 year ago

That's what I thought -- it might be helpful to document that for switch, similarly to what's in the cond doc, e.g.: "However, when transformed with vmap to operate over a batch of predicates/indices, switch is converted to select".

pablo2909 commented 10 months ago

Hi,

I was wondering what the status is on this?

I face the following situation:

jax.vmap(lambda x,y: lax.cond(y<0, heavy_computation_1, heavy_computation_2,x))(X, Y)

IIUC, this will execute both branches, and I would rather it didn't :)

jakevdp commented 10 months ago

This is still the case, as mentioned in the docstring of jax.lax.cond.

jakevdp commented 10 months ago

I don't know of any active work to change this.

evanatyourservice commented 10 months ago

Correct me if I'm wrong but I think you can get around this by splitting the vmap dimension into a list of single pieces, jnp.split(to_vmap, to_vmap.shape[0]), tree_map the list, then concat things back together how you'd like. It could get fancier with the various tree utils, but overall it avoids the vmap and essentially unrolls the batch apply into a sort of list comprehension.

pablo2909 commented 10 months ago

Hey, thanks for the suggestion. I created an MWE that I think illustrates your idea; let me know if that's not the case. This example batches the condition as well as the argument of the lax.cond function. It performs 1000 matrix-vector products.

import jax
from jax import lax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import random

X = jnp.arange(1000*4).reshape(1000, 4)
sigma = random.normal(random.PRNGKey(0), (1000,)) + 3.
A = random.normal(random.PRNGKey(1), (1000, 4))

def true_f(x):
    return A@x

def false_f(x):
    return -A@x

def f(sigma, x):
    return lax.cond(sigma < 3, true_f, false_f, x)

@jax.jit
def F(SIGMA, X):
    return jax.vmap(lambda sigma, x: f(sigma, x))(SIGMA,X)

@jax.jit
def splitF(SIGMA, X):
    split_sigma = jnp.split(SIGMA, SIGMA.shape[0])
    split_x = jnp.split(X, X.shape[0])
    return jnp.stack(jtu.tree_map(lambda sigma, x: f(sigma[0], x[0]), split_sigma, split_x))

print(F(sigma, X))
print("-----------")
print(splitF(sigma, X))

The compilation of splitF is significantly longer than that of F. Intuitively I would have said that splitF would also be slower to run, but I get the following performance results:

%timeit F(sigma, X).block_until_ready()
>>> 3.01 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit splitF(sigma, X).block_until_ready()
>>> 2.81 ms ± 93.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

I'm not sure if we can draw meaningful conclusions from this small example but it's a starting point.

evanatyourservice commented 10 months ago

@pablo2909 Yeah, the compilation time is longer, but it should let you take advantage of cond only computing one branch at run time, so it depends on what you'd like to optimize for.

evanatyourservice commented 10 months ago

Here's an example that shows the differences for true and false functions with highly skewed compute times, 20 qr decomps vs a single elementwise multiply:

from time import time
import jax
from jax import numpy as jnp, jit

def true_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x + y
    return x

def false_fn(_, x):
    return x * 1.01

def main_fn(rng_key, x):
    rng_key, subkey = jax.random.split(rng_key)
    c = jax.random.choice(subkey, 2).astype(bool)
    return jax.lax.cond(
        c,
        true_fn,
        false_fn,
        rng_key,
        x,
    )

@jit
def regular_vmap(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    return jax.vmap(main_fn)(rng_keys, xs)

@jit
def unrolled(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    x_out = [main_fn(k, x) for k, x in zip(rng_keys, xs)]
    x_out = jnp.concatenate([jnp.expand_dims(x, 0) for x in x_out])
    return x_out

rng_key = jax.random.PRNGKey(0)
x_in = jax.random.normal(jax.random.PRNGKey(1), (15, 200, 200))

# compile
t = time()
regular_vmap(rng_key, x_in).block_until_ready()
print("regular_vmap compile", time() - t)

t = time()
unrolled(rng_key, x_in).block_until_ready()
print("list_vmap compile", time() - t)

n = 5

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = regular_vmap(subkey, x)
x = jax.block_until_ready(x)
print("regular_vmap", (time() - t) / n)

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = unrolled(subkey, x)
x = jax.block_until_ready(x)
print("list_vmap", (time() - t) / n)

The times on my machine are:

regular_vmap compile 3.4660181999206543
list_vmap compile 18.33980703353882
regular_vmap 0.79718918800354
list_vmap 0.22351880073547364

The unroll has a several-times-longer compile time but only about a quarter of the run time, so if it's going to be run for many iterations the compile time would be worth it. In cases where the true and false functions are both light, or can be computed simultaneously on a GPU, it might be worth it to just use vmap; you'd have to experiment.

pablo2909 commented 10 months ago

Thanks a lot. I'm a bit surprised by your results: your false_fn is very cheap compared to true_fn, so I would have expected that vmapping would be faster than essentially a compiled for loop. I might be missing something.

To adapt it a bit more to my original question, I made true_fn and false_fn equally costly and replaced the list comprehension with a lax.scan. The code looks like this, and below it are the performance numbers.

from time import time
import jax
from jax import numpy as jnp, jit

def true_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x + y
    return x

def false_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x - y
    return x

def main_fn(carry, rng_key_x):
    rng_key, x = rng_key_x
    rng_key, subkey = jax.random.split(rng_key)
    c = jax.random.choice(subkey, 2).astype(bool)
    return carry, jax.lax.cond(
        c,
        true_fn,
        false_fn,
        rng_key,
        x,
    )

@jit
def regular_vmap(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    return jax.vmap(main_fn, in_axes=(None, 0))(None, (rng_keys, xs))[1]

@jit
def unrolled(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    _, x_out = jax.lax.scan( main_fn, None ,(rng_keys, xs))
    return x_out

rng_key = jax.random.PRNGKey(0)
x_in = jax.random.normal(jax.random.PRNGKey(1), (15, 200, 200))

# compile
t = time()
regular_vmap(rng_key, x_in).block_until_ready()
print("regular_vmap compile", time() - t)

t = time()
unrolled(rng_key, x_in).block_until_ready()
print("list_vmap compile", time() - t)

n = 5

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = regular_vmap(subkey, x).block_until_ready()
print("regular_vmap", (time() - t) / n)

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = unrolled(subkey, x).block_until_ready()
print("list_vmap", (time() - t) / n)
>>> regular_vmap compile 7.062215089797974
>>> list_vmap compile 4.893687963485718
>>> regular_vmap 2.9101763725280763
>>> list_vmap 1.1428837776184082
evanatyourservice commented 10 months ago

Scanning over the inputs is a good idea; I think that was mentioned before in a similar issue. It sacrifices a bit of runtime for a faster compile (you could of course also do this for the qr for loop).

Edit: Actually after testing this setup the scan-over-batch version runs faster than the unroll!

evanatyourservice commented 9 months ago

Unfortunately unrolling using scan or list comprehension only seems to consistently improve performance on cpu, not gpu or tpu, unless the batch size is very small and the branches are wildly unequal in compute. So I don't think there's a good solution to this problem without batched cond through XLA :(

pablo2909 commented 9 months ago

I've been observing the same on gpu :/

evanatyourservice commented 9 months ago

I'm not familiar with triton at all but maybe there's a way to batch cond through pallas. I don't have enough of a need to look that deep into it though 😁

kerupp commented 7 months ago

Hello, I just want to voice my support for having a "batchable cond". It would be very useful to have for our models!

nmonette commented 2 months ago

I have had the same issue as @minqi here. Just wanted to express that this is still an issue that people deal with.

inversecrime commented 3 weeks ago

Because this hasn't been mentioned yet:

As far as I know, using vmap and while_loop can cause similar problems, since a batched while loop executes the loop body for every batch item until all batch items satisfy the termination condition (see https://github.com/google/jax/discussions/15954). For computationally expensive loop body functions, it might be faster to use a for loop over all batch items instead.
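
A tiny sketch of that behavior (an illustrative toy, not from the linked discussion):

import jax
import jax.numpy as jnp

def count_to(limit):
    # Increment a counter until it reaches `limit`.
    return jax.lax.while_loop(lambda i: i < limit, lambda i: i + 1, 0)

limits = jnp.array([3, 1_000_000])
# The single batched loop keeps stepping until every element has finished, so
# the element with limit 3 still pays for all 1,000,000 body iterations (its
# own state is simply masked once its condition becomes false).
print(jax.vmap(count_to)(limits))  # [3 1000000]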

Regarding the cond discussion:

It would be nice to have the option to disable this conversion completely, i.e. raise an exception when jax tries to convert a vmapped cond to a select. This could make debugging a lot easier (i.e. "where does this huge memory usage come from?" etc.). It seems to me that most users here agree that the conversion of a vmapped cond to select is not wanted or even expected.

inversecrime commented 3 weeks ago

It might be worth mentioning that it is possible to use jax.custom_batching.sequential_vmap.
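
A minimal sketch of how it could be applied to the cond pattern discussed here (toy branches standing in for the heavy computations):

import jax
import jax.numpy as jnp
from jax.custom_batching import sequential_vmap

# sequential_vmap replaces the usual batching rule: instead of turning the
# cond into a select, vmap maps the function over the batch (via jax.lax.map),
# so each element runs only its own branch, at the cost of parallelism.
@sequential_vmap
def f(pred, x):
    return jax.lax.cond(pred, lambda x: x * 2.0, lambda x: -x, x)

preds = jnp.array([True, False, True])
xs = jnp.arange(3.0)
print(jax.vmap(f)(preds, xs))  # [0. -1. 4.]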

evanatyourservice commented 3 weeks ago

@inversecrime would sequential_vmap be similar to unrolling the batch using a for loop?

froystig commented 3 weeks ago

@inversecrime would sequential_vmap be similar to unrolling the batch using a for loop?

It generates a jax.lax.map, which bottoms out in a (rolled) XLA loop.

guyuntian commented 2 weeks ago

I tried using jax.lax.map + jax.lax.cond in my program, but it appears to be significantly slower than jax.vmap + jax.numpy.where on an NVIDIA GeForce RTX 3090. I want to express my support for the "batched cond" feature mentioned by others, as it would be highly valuable for my current work!

froystig commented 2 weeks ago

Indeed, it all depends on the program (ignoring possible compiler rewrites), i.e. what's being computed within each branch. The two approaches trade off total compute vs. parallelism.

For similar reasons, there are many possible implementations of "batched cond," all along this tradeoff curve.