patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

What is a good dummy value? #588

Closed pablo2909 closed 11 months ago

pablo2909 commented 11 months ago

Hi,

I'm asking a question following the recent blog post, where point 9 raises awareness of how lax.cond behaves under jax.vmap. One piece of advice there is to use a dummy value, which I understand as follows:

# Instead of

jax.vmap(lambda x: lax.cond(cond, true_f, false_f, x))(X)

# Do:
true_x = jax.vmap(lambda x: lax.cond(cond, true_f, dummy_f, x))(X)
false_x = jax.vmap(lambda x: lax.cond(cond, dummy_f, false_f, x))(X)
combine(true_x, false_x)

If we do that, then I believe we are safe from unwanted evaluation. However, what is a good dummy value? I would have imagined None would be, but if we store it in an array it is converted to jnp.nan. That could be fine, but then if a NaN arises during training it becomes hard to debug. And in my case true_f and false_f can return any real value, so using 0. or something similar is not an option.

Any advice on that?

Otherwise, thanks a lot Patrick for the blog post!

patrick-kidger commented 11 months ago

So typically one has to know the computation to be able to pick the dummy value -- in general you can't know what a safe value is. The typical pattern is loosely something like this:

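import jax
import jax.numpy as jnp
from jax import lax

def true_fn(x):
    # Match the shape and dtype of false_fn's output (assumes floating-point x)
    return jnp.zeros_like(x)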

def false_fn(x):
    return 1 / x  # something that explodes at zero

def f(x):
    pred = x == 0
    safe_x = jnp.where(pred, 1, x)
    return lax.cond(pred, true_fn, false_fn, safe_x)

In this case, we've set safe_x to a value that can be used on both branches. It is now safe to run jax.vmap(f).
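As a minimal sketch of the pattern above under jax.vmap: the input values are made up for illustration, and the gradient check is an addition that is not part of the original snippet. Because pred is batched here, both branches run on every element, so false_fn only ever sees safe_x and never the raw zero.

```python
import jax
import jax.numpy as jnp
from jax import lax


def true_fn(x):
    # Same shape and dtype as false_fn's output, so both branches agree.
    return jnp.zeros_like(x)


def false_fn(x):
    return 1.0 / x  # explodes at zero


def f(x):
    pred = x == 0
    # Replace the problematic input with a value both branches can handle.
    safe_x = jnp.where(pred, 1.0, x)
    return lax.cond(pred, true_fn, false_fn, safe_x)


xs = jnp.array([0.0, 2.0, -4.0])  # made-up inputs, including the bad point

# pred is batched here, so both branches are evaluated for every element.
values = jax.vmap(f)(xs)

# Differentiate through the batched function as well.
grads = jax.grad(lambda v: jnp.sum(jax.vmap(f)(v)))(xs)

print(values)
print(grads)
```

Both values and grads should come out finite, including at the element where x is zero.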

To add a bit more colour: this "safe dummy value" is most important in two cases:

pablo2909 commented 11 months ago

Hmm, let me describe my case in a bit more detail, and maybe rephrase my question.

No branch would produce an infinite loop or NaN values; I would just like to avoid doing some heavy computation twice.

# Instead of

jax.vmap(lambda x: lax.cond(cond, heavy_computation_1, heavy_computation_2, x))(X)

# I do:
X1 = jax.vmap(lambda x: lax.cond(cond, heavy_computation_1, no_computation, x))(X)
X2 = jax.vmap(lambda x: lax.cond(cond, no_computation, heavy_computation_2, x))(X)
combine(X1, X2)

Am I right to use the second option to avoid extra computation? Or am I misunderstanding something? It feels like the first case will do the computation twice while the second does it only once, but I'm a bit unsure given your explanation.

Edit: the no_computation branch returns a pytree of the same shape as heavy_computation_* but filled with NaN values, which is how I detect them and proceed. But I feel this is not good practice.

Thanks for the help :)

pablo2909 commented 11 months ago

Oh, I just realized that both cases are equivalent in terms of computation, right?

patrick-kidger commented 11 months ago

Actually, all of your examples will still perform conditional computation, i.e. be efficient. The reason is that you've only vmap'd the argument x. It is specifically when the predicate cond is batched that a lax.cond turns into a jnp.where.
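One way to check this for a given function is to inspect the traced computation. The sketch below is illustrative only: branch_a, branch_b and primitive_names are hypothetical names, and the branches are cheap stand-ins for the heavy computations. It lists the top-level primitives with the predicate unbatched and then batched.

```python
import jax
import jax.numpy as jnp
from jax import lax


def branch_a(x):  # hypothetical stand-in for one expensive branch
    return lax.sin(x)


def branch_b(x):  # hypothetical stand-in for the other branch
    return lax.cos(x)


def f(pred, x):
    return lax.cond(pred, branch_a, branch_b, x)


def primitive_names(fn, *args):
    # Names of the top-level primitives in the traced computation.
    return [eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).jaxpr.eqns]


X = jnp.arange(4.0)

# Only x is batched: the predicate is one scalar shared by the whole batch.
unbatched = primitive_names(jax.vmap(f, in_axes=(None, 0)), jnp.array(True), X)

# The predicate is batched too: one value per batch element.
batched = primitive_names(jax.vmap(f, in_axes=(0, 0)), X < 2, X)

print("cond" in unbatched)  # the conditional is still a real branch
print("cond" in batched)    # the conditional has been replaced by a select
```

In the batched case the list should contain both sin and cos, since both branches are evaluated before the selection.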

pablo2909 commented 11 months ago

Sorry this example is closer to my actual use case

jax.vmap(lambda x, y: lax.cond(y < 0, heavy_computation_1, heavy_computation_2, x))(X, Y)

In that case, if I understand correctly, it will be converted to jnp.where and both branches will be executed. And there's no way around it, right?

Now that you've pointed out that this happens only when cond is batched, it all makes more sense. I missed that important point all this time, thanks a lot for the clarification :)

patrick-kidger commented 11 months ago

Yup, pretty much.

There is one important exception that does come up occasionally: if you expect the entire predicate to be True across the whole batch, or False across the whole batch, then in principle you could run just one branch. (For example, this comes up in some differential equation solvers: the "expensive branch" is making a numerical step, the "cheap branch" is to do nothing, and you want to keep iterating until every batch element has finished making steps. At the end you'll get a False predicate across the whole batch, and from that point onwards only need to make the cheap evaluations until the end of your loop.)

In this case you can make use of a trick, namely eqx.internal.unvmap_{any, all}, which consumes a batch-of-predicates and applies any or all down the batch dimension to return an unbatched single predicate. Use that the wrong way, of course, and you end up getting the wrong output: each batch element now interacts with the others. But if you feel like you know what you're doing, and happen to have an example of the above use-case, then this can be a useful trick to have: lax.cond(eqxi.unvmap_any(pred), ...).

pablo2909 commented 11 months ago

Ohh, that looks super useful, thanks a lot for the tip! :)

pablo2909 commented 11 months ago

Pointing to this, which could be an alternative in some cases.