Closed pablo2909 closed 11 months ago
So typically one has to know the computation to be able to pick the dummy value -- in general you can't know what a safe value is. The typical pattern is loosely something like this:
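import jax
import jax.numpy as jnp
from jax import lax

def true_fn(x):
    return jnp.zeros_like(x)  # zero, with the same dtype as the other branch

def false_fn(x):
    return 1 / x  # something that explodes at zero

def f(x):
    pred = x == 0
    safe_x = jnp.where(pred, 1, x)  # dummy input wherever false_fn is unsafe
    return lax.cond(pred, true_fn, false_fn, safe_x)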
In this case, we've set safe_x to a value that can be used on both branches. It is now safe to run jax.vmap(f).
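As a quick check of this pattern, here is a minimal sketch that batches f over an input containing a zero. The batch values and the use of jax.grad are illustrative assumptions, not part of the example above; true_fn returns jnp.zeros_like(x) so that both branches produce the same dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax

def true_fn(x):
    return jnp.zeros_like(x)  # same shape/dtype as the other branch

def false_fn(x):
    return 1 / x  # explodes at zero

def f(x):
    pred = x == 0
    # Replace the dangerous input with a harmless one wherever pred holds.
    safe_x = jnp.where(pred, 1, x)
    return lax.cond(pred, true_fn, false_fn, safe_x)

xs = jnp.array([0.0, 2.0, -4.0])  # illustrative batch, includes the bad input

# The predicate is batched, so both branches run on every element;
# safe_x keeps false_fn finite even where its result is discarded.
ys = jax.vmap(f)(xs)

# Differentiating through the batched function also stays finite.
grads = jax.grad(lambda v: jax.vmap(f)(v).sum())(xs)

print(ys)
print(grads)
```

The point of the sketch is that the element with x == 0 never feeds a zero into 1 / x, even though false_fn is still evaluated for it.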
To add a bit more colour: this "safe dummy value" is most important in two cases:
Hmm, let me describe my case in a bit more detail then, and maybe rephrase my question.
Neither branch would produce an infinite loop or NaN values; I would just like to avoid doing some heavy computation twice.
# Instead of
jax.vmap(lambda x: lax.cond(cond, heavy_computation_1, heavy_computation_2, x))(X)
# I do:
X1 = jax.vmap(lambda x: lax.cond(cond, heavy_computation_1, no_computation, x))(X)
X2 = jax.vmap(lambda x: lax.cond(cond, no_computation, heavy_computation_2, x))(X)
combine(X1, X2)
Am I right to use the second option to avoid extra computation? Or am I misunderstanding something? It feels like the first case will do the computation twice while the second case does it only once, but I'm a bit unsure given your explanation.
Edit: And the no_computation branch returns a pytree of the same shape as heavy_computation_* but filled with NaN values, which is how I detect them and proceed. But I feel this is not good practice.
Thanks for the help :)
Oh, I just realized that both cases are equivalent in terms of computation, right?
Actually, all of your examples will still perform conditional computation, i.e. be efficient. The reason is that you've only vmap'd the argument x. It is specifically when the predicate cond is batched that a lax.cond turns into a jnp.where.
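One way to see the difference is to inspect the traced program in each case. The sketch below is only illustrative: the two heavy_computation_* functions are hypothetical stand-ins, and primitive_names is a helper introduced here, not something from JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for the two expensive branches.
def heavy_computation_1(x):
    return jnp.sin(x)

def heavy_computation_2(x):
    return jnp.cos(x)

def unbatched_pred(flag, X):
    # flag is one scalar shared by the whole batch; only x is vmap'd.
    return jax.vmap(
        lambda x: lax.cond(flag, heavy_computation_1, heavy_computation_2, x)
    )(X)

def batched_pred(X, Y):
    # The predicate y < 0 differs per batch element.
    return jax.vmap(
        lambda x, y: lax.cond(y < 0, heavy_computation_1, heavy_computation_2, x)
    )(X, Y)

def primitive_names(fn, *args):
    # Names of the top-level primitives in the traced program.
    return {eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).jaxpr.eqns}

X = jnp.array([0.5, 1.0, 1.5])
Y = jnp.array([-1.0, 2.0, -3.0])

names_unbatched = primitive_names(unbatched_pred, jnp.array(True), X)
names_batched = primitive_names(batched_pred, X, Y)

print("cond" in names_unbatched)  # is a real conditional still present?
print("cond" in names_batched)    # or has it been replaced by a select?
```

With an unbatched predicate the cond primitive should survive, so only one branch runs; with a batched predicate both branches are evaluated and a select picks the result per element.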
Sorry, this example is closer to my actual use case:
jax.vmap(lambda x, y: lax.cond(y < 0, heavy_computation_1, heavy_computation_2, x))(X, Y)
In that case, if I understand correctly, it will be converted to jnp.where and both branches will be executed. And there's no way around it, right?
Now that you've pointed out that this happens only when cond is batched, it all makes more sense. I had missed that important point all this time, thanks a lot for the clarification :)
Yup, pretty much.
There is one important exception that does come up occasionally: if you expect the entire predicate to be True across the whole batch, or False across the whole batch, then in principle you could run just one branch. (For example, this comes up in some differential equation solvers: the "expensive branch" is making a numerical step, the "cheap branch" is to do nothing, and you want to keep iterating until every batch element has finished making steps. At the end you'll get a False predicate across the whole batch, and from that point onwards you only need to make the cheap evaluations until the end of your loop.)
In this case you can make use of a trick, namely eqx.internal.unvmap_{any, all}, which consumes a batch-of-predicates and applies any or all down the batch dimension to return an unbatched single predicate.
Use that the wrong way, of course, and you end up getting the wrong output: each batch element now interacts with the others. But if you feel like you know what you're doing, and happen to have an example of the above use-case, then this can be a useful trick to have: lax.cond(eqxi.unvmap_any(pred), ...).
Ohh, that looks super useful, thanks a lot for the tip! :)
Hi,
I'm asking a question following the recent blog post, where point 9 raises awareness of combining lax.cond with jax.vmap. One piece of advice is to use a dummy value, which I understand as such:
If we do that, then we are safe from doing unwanted evaluation, I believe. However, what is a good dummy value? I would have imagined None would be, but if we store it in an array it is converted to a jnp.nan. That could be fine, but then if a NaN arises during training it becomes hard to debug. And in my case true_f and false_f can return any real value, so it's not an option to use 0. or something similar.
Any advice on that?
Otherwise thanks a lot Patrick for the blog post!