Open inversecrime opened 1 month ago
Thanks for the question!
We have some documentation on this in the autodiff cookbook.
The most concrete way to think about this is to remember that `jax.grad` is just a convenience wrapper around `jax.vjp`. If you think directly in terms of `jax.jvp` and `jax.vjp`, all confusion usually vanishes. The error message for your first "Does not work" line attempts to suggest that, saying "For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly." For example, you can call `jax.jvp` on your R -> C function with a real input tangent and get a complex output tangent, or you can call `jax.vjp` on your function with a complex output cotangent and get a real input cotangent. For an R -> C scalar-to-scalar function, the jvp reveals all the derivative information about the function as a single complex number, so it makes sense to use that in this toy example:
```python
import jax
import jax.numpy as jnp

def f(x):
  return x**2 + 1j * jnp.sin(x)

y, y_dot = jax.jvp(f, (1.0,), (1.0,))
print(y_dot)  # (2+0.5403023j)
```
If you wanted to reveal all the derivative information about an R -> C function using `jax.vjp`, you'd have to apply it twice in general, just like for an R -> R^2 function.
More generally, if you had something like R^n -> C, you may want to use reverse mode for efficiency. You can again apply `jax.vjp` just twice to get all the derivative information (whereas using `jax.jvp` would require n calls).
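To make the two-pullback recipe concrete, here is a sketch for an R^n -> C function (the function `g` is just an illustrative stand-in; the minus sign accounts for the conjugation built into JAX's cotangent convention, and it's worth cross-checking the result against forward mode):

```python
import jax
import jax.numpy as jnp

def g(x):
  # Illustrative R^3 -> C function: real vector in, complex scalar out.
  return jnp.sum(x**2) + 1j * jnp.prod(jnp.sin(x))

x = jnp.array([0.3, 0.7, 1.1])
_, g_vjp = jax.vjp(g, x)

# Two pullback calls recover all the derivative information:
du_dx, = g_vjp(1.0 + 0.0j)      # gradient of the real part
neg_dv_dx, = g_vjp(0.0 + 1.0j)  # negated gradient of the imaginary part
complex_grad = du_dx - 1j * neg_dv_dx

# Cross-check with forward mode, which would need n jvp calls in general:
jac = jax.jacfwd(g)(x)
print(jnp.allclose(complex_grad, jac))
```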
What do you think?
Thanks for the answer! Let's represent a function R -> C as a function R -> R^2 and use jax transformations accordingly - what would that look like in a real-world example? Something like this?
(I don't want to use `jvp`, since all the functions I'm dealing with have exactly one complex-valued output.)
```python
(_, f_vjp) = jax.vjp(f, x)
print(f_vjp(1.0 + 0.0j))
print(f_vjp(0.0 + 1.0j))
```
Is this what `jax.grad` does under the hood for holomorphic functions?
And how can I generalize this to functions R^n -> C? Should I just write my own "grad" function?
How about this?
```python
x = jnp.asarray(1.0)

def my_grad(f):
  def wrapper(x):
    assert jnp.issubdtype(x.dtype, jnp.floating)
    (y, f_vjp) = jax.vjp(f, x)
    assert y.shape == ()
    assert jnp.issubdtype(y.dtype, jnp.complexfloating)
    return f_vjp(1.0 + 0.0j) + f_vjp(0.0 + 1.0j)
  return wrapper

print(my_grad(f)(x))
```
I just realized that this returns a tuple. Adding the values seems to be wrong. My best guess would be something like this:
```python
return f_vjp(1.0 + 0.0j)[0] + 1.0j * f_vjp(0.0 + 1.0j)[0]
```
But this also gives a wrong result (it comes out conjugated). Somehow, this is a lot more complicated than I thought...
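For what it's worth, here is a sketch that flips that sign, on the assumption (consistent with the conjugated result above) that JAX's vjp convention effectively conjugates the output cotangent; comparing against `jax.jvp` is a good sanity check:

```python
import jax
import jax.numpy as jnp

def f(x):
  return x**2 + 1j * jnp.sin(x)

def my_grad(f):
  def wrapper(x):
    y, f_vjp = jax.vjp(f, x)
    du_dx, = f_vjp(1.0 + 0.0j)      # d(Re f)/dx
    neg_dv_dx, = f_vjp(0.0 + 1.0j)  # -d(Im f)/dx under JAX's convention
    return du_dx - 1.0j * neg_dv_dx
  return wrapper

x = jnp.asarray(1.0)
print(my_grad(f)(x))                # (2+0.5403023j)
print(jax.jvp(f, (x,), (1.0,))[1])  # same value via forward mode
```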
Hi, this is a question and/or a feature request. How can I use jax to calculate the gradient of functions R -> C? Example below:
In this particular example, using `holomorphic=True` and `1.0 + 0.0j` as the input works, but I would rather avoid this, since the function does not necessarily have to be holomorphic for the gradient to be well defined.
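For reference, the `holomorphic=True` workaround mentioned here might look like the following sketch (using the same toy `f` from earlier in the thread, which happens to be holomorphic when extended to complex inputs):

```python
import jax
import jax.numpy as jnp

def f(x):
  return x**2 + 1j * jnp.sin(x)

# holomorphic=True requires a complex input, hence 1.0 + 0.0j:
g = jax.grad(f, holomorphic=True)(1.0 + 0.0j)
print(g)  # the complex derivative f'(1) = 2 + i*cos(1)
```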