patrick-kidger / quax

Multiple dispatch over abstract array types in JAX.
Apache License 2.0

Overloading `jax` operations to avoid undefined derivatives #35

Open alleSini99 opened 2 weeks ago

alleSini99 commented 2 weeks ago

Hello, I would like to overload some `jax` operations to handle cases where the derivative of a function f is undefined or diverging, but the derivative of a composite function g(f) is well defined. Here is an example. In this case f is -inf, so exp(f) is 0. If I compute the derivative of f with jax.grad, it is inf, as expected, so the derivative of exp(f) is nan: the chain rule computes it as d exp(f) = exp(f) df, which gives 0 times inf. However, after mathematical simplification the correct d exp(f) is finite, and I would like to obtain that as the result. We have partially solved the problem by defining special versions of log and exp based on Equinox classes, but is there a general way to handle this issue for arbitrary JAX operations within Quax? Example:

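import jax.numpy as jnp
import jax
import equinox as eqx

def psii(theta, x):
    return theta * x
thetas = jnp.array([1., 1.])  # equal thetas, so psii(theta[0], x) - psii(theta[1], x) == 0
def f(theta, x):
    return jnp.log(psii(theta[0], x) - psii(theta[1], x))  # log(0) = -inf
expf = lambda thetas, x: jnp.exp(f(thetas, x))  # analytically equal to psii(theta[0], x) - psii(theta[1], x)
x = 0.5
print("f:", f(thetas, x))  # -inf
print("expf:", expf(thetas, x))  # 0.0
print("df:", jax.grad(f)(thetas, x))  # [inf, -inf]
print("d expf:", jax.grad(expf)(thetas, x))  # [nan, nan]: the chain rule forms 0 * inf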

class MagicLog(eqx.Module):
    val: jax.Array  # the argument of the log, kept unevaluated
    def __init__(self, x):
        self.val = x
    def __jax_array__(self, dtype=None):  # called by JAX if the object is used as a plain array
        return jnp.log(self.val)
def magiclog(x):
    return MagicLog(x)
def magicexp(x):
    if isinstance(x, MagicLog):
        return x.val  # exp(log(v)) simplified to v before any AD takes place
    else:
        return jnp.exp(x)
def magiclogpsi(theta, x):
    return magiclog(psii(theta[0], x) - psii(theta[1], x))
magicexpf = lambda thetas, x: magicexp(magiclogpsi(thetas, x))
print("magic expf:", magicexpf(thetas, x))  # 0.0
print("magic d expf:", jax.grad(magicexpf)(thetas, x))  # [0.5, -0.5]

@PhilipVinc

patrick-kidger commented 2 weeks ago

I think you probably want to wrap your operations in a `jax.custom_jvp`. This kind of thing isn't really related to either Equinox or Quax. :)
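A minimal sketch of what wrapping in `jax.custom_jvp` could look like, assuming the composite exp(log(u)) is written out as one function. The name `exp_of_log` is hypothetical, and the sketch assumes u >= 0. Note that it only helps when the composite is defined explicitly as a single function; it does not detect an `exp` applied to the output of some earlier, separate `log`.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def exp_of_log(u):
    # Hypothetical fused op: evaluates exp(log(u)) numerically (u >= 0 assumed).
    return jnp.exp(jnp.log(u))


@exp_of_log.defjvp
def _exp_of_log_jvp(primals, tangents):
    (u,), (u_dot,) = primals, tangents
    # Analytically d/du exp(log(u)) = 1, so the tangent passes straight through
    # instead of forming exp(log(u)) * (u_dot / u), which is 0 * inf at u = 0.
    return jnp.exp(jnp.log(u)), u_dot


naive = lambda u: jnp.exp(jnp.log(u))
print(jax.grad(naive)(0.0))       # nan
print(jax.grad(exp_of_log)(0.0))  # 1.0
```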

Side note: I can see you're using `__jax_array__`. I'd recommend not using it; it was an early experiment, is deliberately undocumented, and is not well supported throughout JAX.

PhilipVinc commented 2 weeks ago

@patrick-kidger I'm not sure that's exactly what we want to do...

We want to have `exp` and `log` functions that analytically cancel each other (in general, but especially when doing AD).

Right now, if we do AD through $$g(x)=\exp(\log(f(x)))$$ we get something that is $$\nabla g(x) = \exp(\log(f(x))) \cdot \nabla \log(f(x)) = \exp(\log(f(x))) \cdot \frac{\nabla f(x)}{f(x)}$$.

This is analytically equivalent to $$\nabla g(x)=\nabla f(x)$$ if $$f(x) \neq 0$$.

But if $$f(x)=0$$, JAX applies the chain rule, so it multiplies and divides by $$f(x)$$ and we get a nan because we divide by zero. My idea was that by creating a custom 'symbolic-log' type we could symbolically eliminate exp(log(f(x))) before we lower to a jaxpr, much like the symbolic-zero example in Quax propagates zeros symbolically before they are lowered into the jaxpr.
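To see where the nan comes from, the forward-mode tangents can be inspected directly. This small sketch uses `jax.jvp` with `lax.log` and `lax.exp` at u = 0 (the scalar setup is an illustrative assumption, not the original example): by the time `exp` receives its input, the incoming tangent is already inf, so a rule of the form exp(x) · x_dot produces 0 · inf.

```python
import jax
from jax import lax

u, u_dot = 0.0, 1.0

# Tangent of log(u) at u = 0 is u_dot / u = 1 / 0, i.e. x = -inf, x_dot = inf.
x, x_dot = jax.jvp(lax.log, (u,), (u_dot,))

# Tangent of exp(log(u)) is exp(x) * x_dot = 0 * inf, i.e. y = 0.0, y_dot = nan.
y, y_dot = jax.jvp(lambda v: lax.exp(lax.log(v)), (u,), (u_dot,))

print(x, x_dot)
print(y, y_dot)
```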

We'd be happy to use `jax.custom_jvp`/`jax.custom_vjp`, but I'm unsure how in this case. Even if we define a special `exp` function with a custom AD rule, it cannot check whether its input came from a logarithm and perform the simplification analytically, can it?

patrick-kidger commented 2 weeks ago

Ah, okay! That makes sense.

So I think hoping to do things before you hit the jaxpr is likely to be tricky. That's just not going to compose well in all kinds of ways.

My first thought for how you might try and tackle this is to write a custom jaxpr-to-jaxpr function (directly, not using Quax) that checks for the appropriate pattern and cancels things as appropriate. In compiler language this is basically a peephole optimization (aka algebraic pattern matching).
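A minimal sketch of such a peephole pass, written in the style of a custom jaxpr interpreter: it traces the function with `jax.make_jaxpr`, then re-evaluates each equation, replacing any `exp` whose input was produced by a `log` with that `log`'s input. The helper name `cancel_exp_log` is hypothetical, and the sketch assumes the primitives are named `"log"` and `"exp"` at the top level of the jaxpr. It does not recurse into sub-jaxprs (nested `jit`, `scan`, `cond`, `custom_jvp` calls) and does not handle other ops sitting between the `log` and the `exp`. It evaluates rather than emitting a new jaxpr, but wrapping it in `jax.make_jaxpr` yields the rewritten jaxpr.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _is_literal(atom):
    # Jaxpr Literals carry their constant value in `.val`; Vars do not.
    return hasattr(atom, "val")


def cancel_exp_log(fn):
    """Hypothetical helper (not part of JAX or Quax): trace `fn` to a jaxpr and
    evaluate it equation by equation, replacing each exp(log(u)) with u."""

    def wrapped(*args):
        closed, out_shape = jax.make_jaxpr(fn, return_shape=True)(*args)
        jaxpr = closed.jaxpr
        env = {}
        log_input_of = {}  # log output Var -> the atom that log was applied to

        def read(atom):
            return atom.val if _is_literal(atom) else env[atom]

        flat_args = [jnp.asarray(a) for a in jax.tree_util.tree_leaves(args)]
        for var, val in zip(jaxpr.constvars, closed.consts):
            env[var] = val
        for var, val in zip(jaxpr.invars, flat_args):
            env[var] = val

        for eqn in jaxpr.eqns:
            name = eqn.primitive.name
            if name == "log":
                log_input_of[eqn.outvars[0]] = eqn.invars[0]
            elif (name == "exp" and not _is_literal(eqn.invars[0])
                  and eqn.invars[0] in log_input_of):
                # Peephole rewrite exp(log(u)) -> u. The log is still evaluated,
                # but its output no longer feeds exp, so no 1/u factor reaches u.
                env[eqn.outvars[0]] = read(log_input_of[eqn.invars[0]])
                continue
            # Any other equation is re-bound unchanged.
            subfuns, params = eqn.primitive.get_bind_params(eqn.params)
            outs = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **params)
            if not eqn.primitive.multiple_results:
                outs = [outs]
            for var, val in zip(eqn.outvars, outs):
                env[var] = val

        flat_out = [read(v) for v in jaxpr.outvars]
        out_tree = jax.tree_util.tree_structure(out_shape)
        return jax.tree_util.tree_unflatten(out_tree, flat_out)

    return wrapped


def g(u):
    return lax.exp(lax.log(u))


safe_g = cancel_exp_log(g)
print(jax.grad(g)(0.0))       # nan
print(jax.grad(safe_g)(0.0))  # 1.0
```

For the example in the issue, `jax.grad(cancel_exp_log(expf))(thetas, x)` should then give the finite gradient, assuming `jnp.log`/`jnp.exp` show up as bare `log`/`exp` equations rather than inside a nested sub-jaxpr.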

That would need a bit of care to still do the cancellation when you have, for example, a log at the start of a loop body and an exp at the end, but I think it should still be essentially doable.
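The loop case is where the extra care comes in: inside a `lax.scan`, the `log` and `exp` live in a sub-jaxpr stored in the `scan` equation's params, so a rewriter that only scans the top-level equations would never see them. A small sketch that makes this visible, assuming sub-jaxprs can be found by duck typing on equation params (`primitive_names` is a hypothetical helper):

```python
import jax
from jax import lax


def primitive_names(jaxpr):
    """Hypothetical helper: list primitive names in a jaxpr, recursing into
    sub-jaxprs stored in equation params (for example the body of a scan)."""
    names = []
    for eqn in jaxpr.eqns:
        names.append(eqn.primitive.name)
        for param in eqn.params.values():
            subs = param if isinstance(param, (tuple, list)) else (param,)
            for sub in subs:
                # A Jaxpr has .eqns; a ClosedJaxpr wraps a Jaxpr in .jaxpr.
                inner = getattr(sub, "jaxpr", sub)
                if hasattr(inner, "eqns"):
                    names += primitive_names(inner)
    return names


def loop(u):
    def body(carry, _):
        # The log/exp pair sits inside the loop body.
        return lax.exp(lax.log(carry)), None

    out, _ = lax.scan(body, u, None, length=3)
    return out


closed = jax.make_jaxpr(loop)(1.0)
top_level = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(top_level)                      # the pair is hidden inside "scan"
print(primitive_names(closed.jaxpr))  # includes "log" and "exp"
```

A full pass would apply the same recursion when rewriting, and would additionally need to track values through the loop carry to cancel a `log` and an `exp` that sit in different iterations.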

I hope that helps!