alleSini99 opened 2 weeks ago
I think you probably want to wrap your operations in a `jax.custom_jvp`. This kind of thing isn't really related to either Equinox or Quax. :)
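A minimal sketch of that suggestion. It assumes the composite `exp(log(x))` is written as one fused function, here called `exp_log` (a hypothetical name). Its `custom_jvp` rule uses the simplified derivative, so the tangent passes straight through and JAX never evaluates `0 * inf`:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def exp_log(x):
    # Hypothetical fused op: the primal value is exp(log(x)), i.e. x for x >= 0.
    return jnp.exp(jnp.log(x))

@exp_log.defjvp
def exp_log_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Simplified rule d/dx exp(log(x)) = 1: forward the tangent unchanged
    # instead of multiplying by exp(log(x)) = 0 and dividing by x = 0.
    return exp_log(x), x_dot

print(jax.grad(exp_log)(0.0))  # 1.0
```

This only helps when the user calls the fused function explicitly. A separate `exp` applied to a separate `log` still goes through the ordinary chain rule.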
Side note, I can see you're using `__jax_array__`. I'd recommend not using this: it was an early experiment, deliberately undocumented, and is not well-supported throughout JAX.
@patrick-kidger I'm not sure that's exactly what we want to do...
We want to have `exp` and `log` functions that cancel each other analytically (in general, but especially when doing AD).
Right now, if we do AD through $$g(x)=\exp(\log(f(x)))$$ we get $$\nabla g(x) = \exp(\log(f(x))) \cdot \nabla \log(f(x)) = \exp(\log(f(x))) \cdot \frac{\nabla f(x)}{f(x)}.$$
This is analytically equal to $$\nabla g(x)=\nabla f(x)$$ if $$f(x) \neq 0$$.
But if $$f(x)=0$$, JAX applies the chain rule step by step: it multiplies and divides by $$f(x)$$, so we get a `nan` because we divide by zero.
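A small reproduction of the failure described above, assuming the simplest case $$f(x) = x$$ evaluated at $$x = 0$$:

```python
import jax
import jax.numpy as jnp

def g(x):
    # With f(x) = x, g(x) = exp(log(f(x))) is analytically the identity.
    return jnp.exp(jnp.log(x))

print(g(0.0))            # 0.0, since log(0) = -inf and exp(-inf) = 0
print(jax.grad(g)(0.0))  # nan: the chain rule combines exp(log(0)) = 0 with 1/f(x) = 1/0
```

Away from zero the same code gives the expected gradient of (approximately) 1.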
My idea was that, by creating a custom "symbolic log" type, we could symbolically eliminate `exp(log(f(x)))` before we lower to a jaxpr, much like Quax's analytical-zero example propagates symbolic zeros before they are lowered into the jaxpr.
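To make the idea concrete, here is a toy sketch that does not use Quax and does not reflect its API. A hypothetical `SymLog` wrapper keeps `log(x)` unevaluated, and a hypothetical `sym_exp` cancels it in Python before any `log` or `exp` equation is recorded:

```python
import jax
import jax.numpy as jnp

class SymLog:
    """Hypothetical lazy representation of log(x): stores x, not log(x)."""

    def __init__(self, x):
        self.x = x

    def value(self):
        # Materialize log(x) only when a concrete number is actually needed.
        return jnp.log(self.x)

def sym_log(x):
    return SymLog(x)

def sym_exp(y):
    # Symbolic rule exp(log(x)) -> x, applied in Python during tracing,
    # so neither log nor exp ever appears in the traced computation.
    if isinstance(y, SymLog):
        return y.x
    return jnp.exp(y)

def g(x):
    return sym_exp(sym_log(x))

print(jax.grad(g)(0.0))  # 1.0
```

The catch is composition: every operation that does not know about `SymLog` has to call `.value()` and loses the symbolic form.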
We'd be happy to use `jax.custom_jvp`/`custom_vjp`, but I'm unsure how in this case: even if we define a special `exp` function with a custom AD rule, we cannot check whether the inner function is the logarithm and do the simplification analytically, can we?
Ah, okay! That makes sense.
So I think hoping to do things before you hit the jaxpr is likely to be tricky. That's just not going to compose well in all kinds of ways.
My first thought for how you might try and tackle this is to write a custom jaxpr-to-jaxpr function (directly, not using Quax) that checks for the appropriate pattern and cancels things as appropriate. In compiler language this is basically a peephole optimization (aka algebraic pattern matching).
That would need a bit of care to still do the cancellation when you have, for example, a `log` at the start of a loop body and an `exp` at the end, but I think it should still be essentially doable.
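A minimal sketch of such a pass, under loud assumptions: single-output functions, `lax.log`/`lax.exp` appearing as top-level equations, and no control flow (it does not look inside `scan`, `while_loop`, `cond` or nested `pjit` bodies, which is the extra care mentioned above). `cancel_exp_log` is a hypothetical name. For brevity it re-interprets the jaxpr equation by equation and applies the rewrite on the fly rather than emitting a new jaxpr object; wrapping the result in `jax.make_jaxpr` would show the rewritten program.

```python
import jax
from jax import lax

def cancel_exp_log(fun):
    """Hypothetical peephole pass: re-evaluate fun's jaxpr with exp(log(x)) -> x."""

    def wrapped(*args):
        closed = jax.make_jaxpr(fun)(*args)
        jaxpr = closed.jaxpr
        env = {}        # jaxpr Var -> runtime value (possibly a tracer)
        log_input = {}  # output Var of a `log` equation -> that log's input value

        def read(v):
            # Literals carry their value in `.val`; Vars are looked up in env.
            return v.val if hasattr(v, "val") else env[v]

        for var, val in zip(jaxpr.constvars, closed.consts):
            env[var] = val
        for var, val in zip(jaxpr.invars, args):
            env[var] = val

        for eqn in jaxpr.eqns:
            src = eqn.invars[0] if eqn.invars else None
            if (eqn.primitive.name == "exp" and src is not None
                    and not hasattr(src, "val") and src in log_input):
                # Pattern matched: forward the log's input, skipping both derivatives.
                env[eqn.outvars[0]] = log_input[src]
                continue
            invals = [read(v) for v in eqn.invars]
            outs = eqn.primitive.bind(*invals, **eqn.params)
            if not eqn.primitive.multiple_results:
                outs = [outs]
            for var, val in zip(eqn.outvars, outs):
                env[var] = val
            if eqn.primitive.name == "log":
                log_input[eqn.outvars[0]] = invals[0]

        (out,) = [read(v) for v in jaxpr.outvars]
        return out

    return wrapped

def f(x):
    return lax.exp(lax.log(x))

print(jax.grad(f)(0.0))                  # nan
print(jax.grad(cancel_exp_log(f))(0.0))  # 1.0
```

The `log` equation is still evaluated (its value might be used elsewhere), but once the `exp` no longer consumes it, its infinite derivative never reaches the output.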
I hope that helps!
Hello, I would like to overload some `jax` operations to work in cases where the derivative of a function `f` is undefined or divergent, but the derivative of a composite function of `f`, let's say `g(f)`, is defined. I report an example here. In this case `f` is `-inf`, so `exp(f)` is `0`. If I compute the derivative of `f` with `jax.grad`, it is `inf`, as expected, so the derivative of `exp(f)` is `nan`: it is computed via the chain rule as `d exp(f) = exp(f) df`, and I get `0` times `inf`. However, thanks to mathematical simplifications, the correct `d exp(f)` is finite, and I would like to obtain it as the result. We have partially solved the problem by defining special versions of `log` and `exp` based on `equinox` classes, but is there a general way to handle this issue for arbitrary `jax` operations within Quax? Example:
@philipvinc