aespielberg opened this issue 2 years ago
In fact, spamming the idempotent `v_out = jax.lax.cond(jnp.logical_and(j > n_grid - bound, v_out[1] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)` a few times at the end of `grid_op` adds GBs more memory usage :disappointed:
XLA may choose to evaluate several or all branches of `cond` or `select`, but indeed one intent of exposing `cond` separately from `select` is to allow you to indicate when that would be expensive. Using `cond` suggests that only one branch should be evaluated (should other XLA optimizations allow for it).
But XLA has no "batched" conditional construct. Today, when batching the predicate to a `cond` using `vmap`, JAX will transform the `cond` into a `select` instead. From a quick glance through your example, it seems like this might be happening.
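A minimal sketch of this conversion, assuming a recent JAX version: tracing the same `lax.cond` once with a scalar predicate and once under `vmap` with a batched predicate shows that the batched jaxpr no longer contains a `cond` primitive, only both branches merged by `select_n`. The function `f` and its inputs are illustrative, not from the original report.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(pred, x):
    # Cheap stand-ins for the two branches.
    return lax.cond(pred, jnp.sin, jnp.cos, x)

preds = jnp.array([True, False, True])
xs = jnp.arange(3.0)

# Scalar predicate: the jaxpr keeps a `cond` primitive.
unbatched = str(jax.make_jaxpr(f)(True, 1.0))
# Batched predicate: both branches are traced and merged with `select_n`.
batched = str(jax.make_jaxpr(jax.vmap(f))(preds, xs))

print(unbatched)
print(batched)
```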
Could `jax.checkpoint` serve you here, by working around having to store material arrays in the course of autodiff?
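A minimal sketch of the `jax.checkpoint` idea, assuming a `lax.scan`-based time-stepping loop: wrapping the per-step function means reverse-mode autodiff saves only each step's input and recomputes that step's intermediates on the backward pass, trading compute for memory. `step` and `simulate` are hypothetical stand-ins for `grid_op` and the simulation loop.

```python
import jax
import jax.numpy as jnp

def step(x):
    # Hypothetical stand-in for one `grid_op` update.
    return jnp.sin(x) * 1.01 + 0.1

# Same computation, but intermediates are recomputed during the backward pass.
step_remat = jax.checkpoint(step)

def simulate(x0, use_remat=True, n_steps=100):
    body = step_remat if use_remat else step

    def scan_body(x, _):
        return body(x), None

    x_final, _ = jax.lax.scan(scan_body, x0, None, length=n_steps)
    return jnp.sum(x_final)

x0 = jnp.ones(8)
g_plain = jax.grad(lambda x: simulate(x, use_remat=False))(x0)
g_remat = jax.grad(simulate)(x0)
```

Both gradients should agree; only the memory/compute profile of the backward pass changes.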
I see. Is there any reason why vmapped conds are not supported? What is the logic there? I think this would be very important for a lot of people (and it would be great if these nuances were documented). I don't know the internals of the JAX compiler, but I do know that other similar systems support some form of branching. Is there any way to add this capability to XLA, or somehow to JAX?
Actually, since I could manually batch everything, including conditionals, if I were careful with my code, I see no reason why this couldn't be supported with `vmap`.
I did try `jax.checkpoint`, but even having one checkpoint and splitting the loop in half balloons the runtime (by about 4x IIRC; I'll have to double check). That's surprising, since I would think that as long as checkpoints weren't nested inside each other, the maximum extra runtime this would incur is 2x. On that note, using `cond`s instead of `where`s also balloons the runtime by 2x... eventually these things start to really add up.
Often, one branch of `lax.cond` is pretty much a no-op, while the other branch involves a significant amount of computation. If the `lax.cond` indication is being ignored by `vmap`, then I guess it will lead to both memory and computation overhead. Using `vmap` might turn out to be more expensive than writing a for loop.
A loop may be possible in some cases, but not all, and I think that makes this problematic. If I may, I would like to make an argument for why I think somehow supporting `cond`s in `vmap`'d scenarios is so important.
The wonderful thing about `vmap`, as advertised, is that anything should be easily vmappable. It's essentially the beauty of abstraction and modularity, as described. Someone writes a module, `foo`, and if you want to parallelize it, it should be as easy as calling `vmap(foo)`. You don't need to understand `foo`; it should just work. In this case, it seems like `cond`, a very critical dataflow construct, behaves differently depending on whether it is `vmap`'d or not. The problem, of course, is that people who use a module might not even know whether `cond` is being used inside. This leads to very different performance profiles when vmapped or not, which might require deep investigation of `foo`, which could be quite a complex piece of code. I don't think this is good for the mission of `vmap`.
In some cases (as I would argue here), `vmap` appears to be the correct way to process such code rather than a for loop since, as far as I understand, indexing into arrays via `at` is runtime-expensive in JAX (correct me if I'm mistaken), and creating large, sparse one-hot masks would be prohibitively memory-expensive without support for sparse matrices.
We're in agreement here by and large. This is something that we've thought about improving before, whether at the JAX or XLA level. I can't find an open issue for it, so let's use this one for it.
Hi there! I was wondering what the JAX team's latest thinking is regarding the behavior of lax.cond when batched via vmap.
I find myself often running into the design pattern of conditionally branching into two subroutines, one expensive, and the other a "placeholder," for example, returning a dummy zero tensor.
@minqi, the thinking hasn't changed much since this issue was last active. Although there's a fundamental puzzle regarding whether/how to do better, for now we're still producing `select` when we batch `cond`'s predicate.
I find this one of the biggest practical issues with JAX: `vmap` + `jit` are great, but in a lot of code they also necessitate the use of `cond`, which results in compute/memory issues, as described above... and the time spent trying to work around them.
Btw, does `switch` suffer from this same problem when used with `vmap`?
`switch` is implemented in terms of `cond`, so yes, it has the same characteristics.
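A quick way to confirm this, as a sketch with three made-up branches: batching the index of `lax.switch` also yields a `select_n` over the outputs of every branch, so all branches are computed for every element.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branches; under a batched index all three are evaluated.
branches = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x]

def g(i, x):
    return lax.switch(i, branches, x)

idx = jnp.array([0, 1, 2, 1])
xs = jnp.array([1.0, 2.0, 3.0, 4.0])

batched_jaxpr = str(jax.make_jaxpr(jax.vmap(g))(idx, xs))
out = jax.vmap(g)(idx, xs)
print(batched_jaxpr)
print(out)
```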
That's what I thought; it might be helpful to document that for `switch` similarly to what's in the `cond` docs, e.g. "However, when transformed with `vmap` to operate over a batch of predicates/indices, `switch` is converted to `select`."
Hi,
I was wondering what the status is on this?
I face the following situation:

```python
jax.vmap(lambda x, y: lax.cond(y < 0, heavy_computation_1, heavy_computation_2, x))(X, Y)
```

IIUC, it'll execute both branches in this case, and I would rather it didn't :)
This is still the case, as mentioned in the docstring of `jax.lax.cond`.
I don't know of any active work to change this.
Correct me if I'm wrong, but I think you can get around this by splitting the vmapped dimension into a list of single pieces with `jnp.split(to_vmap, to_vmap.shape[0])`, `tree_map`ping over the list, then concatenating things back together however you'd like. It could get fancier with the various tree utils, but overall it avoids the `vmap` and essentially unrolls the batched application into a sort of list comprehension.
Hey, thanks for the suggestion. I created an MWE that I think illustrates your idea; let me know if that's not the case. This example batches the condition as well as the argument of the `lax.cond` function. It is 1000 matrix-vector products.
```python
import jax
from jax import lax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import random

X = jnp.arange(1000*4).reshape(1000, 4)
sigma = random.normal(random.PRNGKey(0), (1000,)) + 3.
A = random.normal(random.PRNGKey(1), (1000, 4))

def true_f(x):
    return A@x

def false_f(x):
    return -A@x

def f(sigma, x):
    return lax.cond(sigma < 3, true_f, false_f, x)

# Plain vmap: the batched predicate turns the cond into a select.
@jax.jit
def F(SIGMA, X):
    return jax.vmap(lambda sigma, x: f(sigma, x))(SIGMA, X)

# Split into single-element pieces and map over the list instead of vmapping.
@jax.jit
def splitF(SIGMA, X):
    split_sigma = jnp.split(SIGMA, SIGMA.shape[0])
    split_x = jnp.split(X, X.shape[0])
    return jnp.stack(jtu.tree_map(lambda sigma, x: f(sigma[0], x[0]), split_sigma, split_x))

print(F(sigma, X))
print("-----------")
print(splitF(sigma, X))
```
Compiling `splitF` takes significantly longer than compiling `F`. Intuitively I would have expected `splitF` to be slower to run, but I get the following timings:
```
%timeit F(sigma, X).block_until_ready()
>>> 3.01 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit splitF(sigma, X).block_until_ready()
>>> 2.81 ms ± 93.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
I'm not sure if we can draw meaningful conclusions from this small example but it's a starting point.
@pablo2909 Yeah, compilation would take longer, but it should let you take advantage of `cond` computing only one branch at run time, so it depends on what you'd like to optimize for.
Here's an example that shows the difference for true and false functions with highly skewed compute times, 20 QR decompositions vs. a single elementwise multiply:
```python
from time import time
import jax
from jax import numpy as jnp, jit

def true_fn(rng_key, x):
    # Expensive branch: 20 rounds of QR plus random noise.
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x + y
    return x

def false_fn(_, x):
    # Cheap branch: a single elementwise multiply.
    return x * 1.01

def main_fn(rng_key, x):
    rng_key, subkey = jax.random.split(rng_key)
    c = jax.random.choice(subkey, 2).astype(bool)
    return jax.lax.cond(
        c,
        true_fn,
        false_fn,
        rng_key,
        x,
    )

@jit
def regular_vmap(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    return jax.vmap(main_fn)(rng_keys, xs)

@jit
def unrolled(rng_key, xs):
    # Python loop over the batch: each cond keeps a scalar predicate.
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    x_out = [main_fn(k, x) for k, x in zip(rng_keys, xs)]
    x_out = jnp.concatenate([jnp.expand_dims(x, 0) for x in x_out])
    return x_out

rng_key = jax.random.PRNGKey(0)
x_in = jax.random.normal(jax.random.PRNGKey(1), (15, 200, 200))

# compile
t = time()
regular_vmap(rng_key, x_in).block_until_ready()
print("regular_vmap compile", time() - t)
t = time()
unrolled(rng_key, x_in).block_until_ready()
print("list_vmap compile", time() - t)

n = 5
x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = regular_vmap(subkey, x)
    x = jax.block_until_ready(x)
print("regular_vmap", (time() - t) / n)

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = unrolled(subkey, x)
    x = jax.block_until_ready(x)
print("list_vmap", (time() - t) / n)
```
The times on my machine are:
```
regular_vmap compile 3.4660181999206543
list_vmap compile 18.33980703353882
regular_vmap 0.79718918800354
list_vmap 0.22351880073547364
```
The unrolled version takes several times longer to compile but only about a quarter of the run time, so if it's going to be run for many iterations, the compile time would be worth it. In the case where the true and false functions are both light, or can be computed simultaneously on a GPU, it might be worth just using `vmap`; you'd have to experiment.
Thanks a lot; I'm a bit surprised by your results. Your `false_fn` is very cheap compared to `true_fn`, so I would have expected vmapping to be faster than what is essentially a compiled `for` loop. I might be missing something.
To adapt it a bit more to my original question, I made `true_fn` and `false_fn` equally costly and replaced the list comprehension with a `lax.scan`. The code is below, followed by the timings.
```python
from time import time
import jax
from jax import numpy as jnp, jit

# Both branches now do comparable work (20 QR rounds each).
def true_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x + y
    return x

def false_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x - y
    return x

# Written as a scan body: (carry, x) -> (carry, y).
def main_fn(carry, rng_key_x):
    rng_key, x = rng_key_x
    rng_key, subkey = jax.random.split(rng_key)
    c = jax.random.choice(subkey, 2).astype(bool)
    return carry, jax.lax.cond(
        c,
        true_fn,
        false_fn,
        rng_key,
        x,
    )

@jit
def regular_vmap(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    return jax.vmap(main_fn, in_axes=(None, 0))(None, (rng_keys, xs))[1]

@jit
def unrolled(rng_key, xs):
    # Scan over the batch: one element (and one cond branch) per iteration.
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    _, x_out = jax.lax.scan(main_fn, None, (rng_keys, xs))
    return x_out

rng_key = jax.random.PRNGKey(0)
x_in = jax.random.normal(jax.random.PRNGKey(1), (15, 200, 200))

# compile
t = time()
regular_vmap(rng_key, x_in).block_until_ready()
print("regular_vmap compile", time() - t)
t = time()
unrolled(rng_key, x_in).block_until_ready()
print("list_vmap compile", time() - t)

n = 5
x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = regular_vmap(subkey, x).block_until_ready()
print("regular_vmap", (time() - t) / n)

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = unrolled(subkey, x).block_until_ready()
print("list_vmap", (time() - t) / n)
```
```
>>> regular_vmap compile 7.062215089797974
>>> list_vmap compile 4.893687963485718
>>> regular_vmap 2.9101763725280763
>>> list_vmap 1.1428837776184082
```
Scanning over the inputs is a good idea; I think it was mentioned before in a similar issue. It sacrifices a bit of runtime for faster compilation (you could of course also do this for the QR for loop).
Edit: Actually, after testing this setup, the scan-over-batch version runs faster than the unrolled one!
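For the QR loop, a sketch of what that could look like with `jax.lax.fori_loop`. It assumes square inputs and uses `x.shape` instead of the hard-coded `(200, 200)` so the check can run on a small matrix. The rolled loop traces the body once, which typically shortens compile time, at the cost of XLA seeing a loop rather than straight-line code.

```python
import jax
import jax.numpy as jnp

def true_fn_unrolled(rng_key, x):
    # Python loop: the body is traced (and compiled) 20 times.
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        x = x + jax.random.uniform(subkey, x.shape)
    return x

def true_fn_rolled(rng_key, x):
    # fori_loop: the body is traced once; the key is threaded through the carry.
    def body(_, state):
        key, x = state
        x = jnp.linalg.qr(x)[0]
        key, subkey = jax.random.split(key)
        return key, x + jax.random.uniform(subkey, x.shape)

    _, x = jax.lax.fori_loop(0, 20, body, (rng_key, x))
    return x

k = jax.random.PRNGKey(0)
x0 = jax.random.normal(jax.random.PRNGKey(1), (8, 8))
a = jax.jit(true_fn_unrolled)(k, x0)
b = jax.jit(true_fn_rolled)(k, x0)
```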
Unfortunately, replacing `vmap` with a scan or a list comprehension only seems to consistently improve performance on CPU, not GPU or TPU, unless the batch size is very small and the branches are wildly unequal in compute. So I don't think there's a good solution to this problem without a batched cond in XLA :(
I've been observing the same on GPU :/
I'm not familiar with Triton at all, but maybe there's a way to batch cond through Pallas. I don't have enough of a need to look that deep into it, though 😁
Hello, I just want to voice my support for having a "batchable cond". It would be very useful to have for our models!
I have had the same issue as @minqi here. Just wanted to say that this is still an issue people deal with.
Because this hasn't been mentioned yet:
As far as I know, using `vmap` and `while_loop` can cause similar problems, since a batched while loop executes the loop body for every batch item until all batch items satisfy the termination condition (see https://github.com/google/jax/discussions/15954).
For computationally expensive loop body functions, it might be faster to use a for loop over all batch items instead.
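A small sketch of that behavior, with a made-up `steps_to_exceed` function: under `vmap`, the loop keeps iterating until every element is done, and elements that have already finished have their state frozen by a select. The per-element results are still correct, but the body runs for the whole batch on every iteration (here 7 iterations, even though three of the four elements need at most one). Printing `jax.make_jaxpr(jax.vmap(steps_to_exceed))(xs)` should show the predicate being reduced across the batch.

```python
import jax
import jax.numpy as jnp

def steps_to_exceed(x):
    # Count how many doublings it takes for x to reach at least 100.
    def cond_fn(state):
        x, n = state
        return x < 100.0

    def body_fn(state):
        x, n = state
        return x * 2.0, n + 1

    _, n = jax.lax.while_loop(cond_fn, body_fn, (x, 0))
    return n

xs = jnp.array([1.0, 50.0, 99.0, 200.0])
counts = jax.vmap(steps_to_exceed)(xs)
print(counts)
```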
Regarding the `cond` discussion:
It would be nice to have the option to disable this conversion completely, e.g. raise an exception when JAX tries to convert a vmapped `cond` into a `select`. This could make debugging a lot easier (e.g. "where does this huge memory usage come from?"). It seems to me that most users here agree that converting a vmapped `cond` into a `select` is neither wanted nor even expected.
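Until such an option exists, one way to spot the conversion while debugging is to count `cond` primitives in the jaxpr before and after `vmap`. The `count_primitive` helper below is a hypothetical sketch, not a JAX API; it walks nested jaxprs by duck typing (objects with `.jaxpr` or `.eqns`), which may need adjusting for primitives with unusual parameters.

```python
import jax
import jax.numpy as jnp
from jax import lax

def count_primitive(jaxpr, name):
    # Hypothetical helper: count equations named `name`, recursing into sub-jaxprs.
    count = 0
    for eqn in jaxpr.eqns:
        if eqn.primitive.name == name:
            count += 1
        for param in eqn.params.values():
            subs = param if isinstance(param, (tuple, list)) else (param,)
            for sub in subs:
                if hasattr(sub, "jaxpr"):  # ClosedJaxpr-like
                    count += count_primitive(sub.jaxpr, name)
                elif hasattr(sub, "eqns"):  # Jaxpr-like
                    count += count_primitive(sub, name)
    return count

def f(pred, x):
    return lax.cond(pred, jnp.sin, jnp.cos, x)

xs = jnp.ones(2)
plain = jax.make_jaxpr(f)(True, 1.0)
batched_pred = jax.make_jaxpr(jax.vmap(f))(jnp.array([True, False]), xs)
unbatched_pred = jax.make_jaxpr(jax.vmap(f, in_axes=(None, 0)))(True, xs)

n_plain = count_primitive(plain.jaxpr, "cond")
n_batched_pred = count_primitive(batched_pred.jaxpr, "cond")
n_unbatched_pred = count_primitive(unbatched_pred.jaxpr, "cond")
print(n_plain, n_batched_pred, n_unbatched_pred)
```

A drop in the count after `vmap` means some `cond` had its predicate batched and was lowered to a select; a `cond` whose predicate stays unbatched survives.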
It might be worth mentioning that it is possible to use `jax.custom_batching.sequential_vmap`.
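A minimal sketch of `jax.custom_batching.sequential_vmap`: when the decorated function is vmapped, it is mapped element by element with `jax.lax.map` instead of being batched, so each `lax.cond` keeps a scalar predicate and runs only one branch per element. `expensive` and `cheap` are hypothetical placeholders for the two branches.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import sequential_vmap

def expensive(x):
    # Hypothetical stand-in for a costly branch.
    for _ in range(5):
        x = jnp.sin(x) + 0.5 * x
    return x

def cheap(x):
    # "Placeholder" branch.
    return jnp.zeros_like(x)

@sequential_vmap
def f(pred, x):
    return lax.cond(pred, expensive, cheap, x)

preds = jnp.array([True, False, False, True])
xs = jnp.linspace(0.0, 1.0, 4)
# Under vmap, this becomes a loop over the 4 elements rather than a select.
out = jax.jit(jax.vmap(f))(preds, xs)
print(out)
```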
@inversecrime, would `sequential_vmap` be similar to unrolling the batch using a for loop?
It generates a `jax.lax.map`, which bottoms out in a (rolled) XLA loop.
I tried using `jax.lax.map` + `jax.lax.cond` in my program, but it appears to be significantly slower than `jax.vmap` + `jax.numpy.where` on an NVIDIA GeForce RTX 3090. I want to express my support for the "batched cond" feature mentioned by others, as it would be highly valuable for my current work!
Indeed, it all depends on the program (ignoring possible compiler rewrites), i.e. what's being computed within each branch. The two approaches trade off total compute vs. parallelism.
For similar reasons, there are many possible implementations of "batched cond," all along this tradeoff curve.
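One point on that curve, sketched with a hypothetical `chunked_vmap` helper: split the batch into chunks, `vmap` within each chunk (where `cond` still becomes a select), and loop over the chunks with `lax.map`. `chunk_size=1` is fully sequential, `chunk_size=n` is plain `vmap`. Recent JAX versions also accept a `batch_size` argument on `jax.lax.map` that does something similar; the manual version here avoids depending on it and assumes the batch divides evenly.

```python
import jax
import jax.numpy as jnp
from jax import lax

def expensive(x):
    return jnp.tanh(x) ** 2 + jnp.sin(x)

def cheap(x):
    return jnp.zeros_like(x)

def f(pred, x):
    return lax.cond(pred, expensive, cheap, x)

def chunked_vmap(fn, chunk_size):
    # Hypothetical helper: vmap inside each chunk, sequential loop across chunks.
    def run(preds, xs):
        n = preds.shape[0]
        assert n % chunk_size == 0, "sketch assumes the batch divides evenly"
        n_chunks = n // chunk_size
        preds_c = preds.reshape(n_chunks, chunk_size)
        xs_c = xs.reshape((n_chunks, chunk_size) + xs.shape[1:])
        out = lax.map(lambda args: jax.vmap(fn)(*args), (preds_c, xs_c))
        return out.reshape((n,) + out.shape[2:])
    return run

preds = jnp.array([True, False, False, True, True, False, True, False])
xs = jnp.linspace(-1.0, 1.0, 8)
reference = jax.vmap(f)(preds, xs)
results = {cs: chunked_vmap(f, cs)(preds, xs) for cs in (1, 2, 4, 8)}
```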
I have been playing around with converting diffmpm from the difftaichi package into a JAX version, and while the forward pass works wonderfully, the backward pass uses far too much GPU memory.
Today, I was able to track that memory usage down to the grid op. The grid op step is a series of nested if statements. At first, I was using `jnp.where`, which evaluates all branches. That is extremely inefficient and can lead to OOM errors. I simplified my code and switched to `jax.lax.cond`, but my only conclusion is that `cond` is also evaluating both branches; otherwise I cannot see why this would run into OOM issues.
Below is a modified version of the grid op, composed with itself 4,000 times, like a simulation. Even when run with the `XLA_PYTHON_CLIENT_PREALLOCATE=false` flag, this quickly uses the whole GPU, and more if the loop length is increased. This is not true if every line from `lin = ...` until right before the return of `grid_op` is commented out; in that case, memory usage is practically negligible. Note that because `bound = 0`, every line written `v_out = jax.lax.cond ...` evaluates to False by definition, so most of the expressions, including the `v_out_gate`s and their dependencies, shouldn't even need to be evaluated in the jitted function. Maybe I am misunderstanding `cond`; if so, what is the proper way to get this sparse branching behavior? I don't want to evaluate and hold onto a bunch of expensive tensors that are never actually needed and crash my GPU with OOM, especially in a backward pass. This is a core bottleneck to practical deployment of my code and a feature that I think should be supported. FWIW, I am using Version: 0.1.69+cuda101
Code to reproduce is below.