google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Complex gradient #22948

Open · inversecrime opened this issue 1 month ago

inversecrime commented 1 month ago

Hi, this is a question and/or a feature request. How can I use jax to calculate the gradient of a function R -> C? Example below:

import jax
import jax.numpy as jnp

def f(x):
    return x**2 + 1j * jnp.sin(x)

print(jax.grad(f)(1.0))  # Does not work (because of complex output)
print(jax.grad(f, holomorphic=True)(1.0))  # Does not work either (because of real output)

In this particular example, using "holomorphic=True" and 1.0+0.0j as input works, but I would rather avoid this, since the function does not necessarily have to be holomorphic for the gradient to be well defined.
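
For reference, that workaround looks roughly like this (note the complex input):

print(jax.grad(f, holomorphic=True)(1.0 + 0.0j))  # works for this (holomorphic) f: (2+0.5403023j)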

mattjj commented 1 month ago

Thanks for the question!

We have some documentation on this in the autodiff cookbook.

The most concrete way to think about this is to remember that jax.grad is just a convenience wrapper around jax.vjp. If you think directly in terms of jax.jvp and jax.vjp, the confusion usually vanishes. The error message for your first "Does not work" line tries to suggest that, saying "For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly."

For example, you can call jax.jvp on your R -> C function with a real input tangent and get a complex output tangent, or you can call jax.vjp on your function with a complex output cotangent and get a real input cotangent. For an R -> C (scalar-to-scalar) function, the jvp reveals all the derivative information about the function as a single complex number, so it makes sense to use that in this toy example:

import jax
import jax.numpy as jnp

def f(x):
    return x**2 + 1j * jnp.sin(x)

y, y_dot = jax.jvp(f, (1.0,), (1.0,))
print(y_dot)  # (2+0.5403023j)

If you wanted to reveal all the derivative information about an R->C function using jax.vjp, you'd have to apply it twice in general, just like for an R -> R^2 function.
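
For instance, here is a minimal sketch (not in the original comment) of those two jax.vjp applications on the f above, pulling back through the real and imaginary parts separately and reassembling the complex derivative:

# Two reverse-mode passes: one through Re(f), one through Im(f).
# Each wrapped function is R -> R, so the cotangent is just 1.0.
_, vjp_re = jax.vjp(lambda x: jnp.real(f(x)), 1.0)
_, vjp_im = jax.vjp(lambda x: jnp.imag(f(x)), 1.0)
print(vjp_re(1.0)[0] + 1j * vjp_im(1.0)[0])  # (2+0.5403023j), same as the jvp above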

More generally, if you had something like R^n -> C, you may want to use reverse-mode for efficiency. You can again apply jax.vjp just twice to get all the derivative information (whereas using jax.jvp would require n calls).
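
As a sketch of that R^n -> C case (g and x below are made-up examples, not from this discussion), it is still only two jax.vjp calls no matter how large n is:

def g(x):
    # hypothetical g: R^n -> C
    return jnp.sum(x**2) + 1j * jnp.sum(jnp.sin(x))

x = jnp.arange(3.0)
_, vjp_re = jax.vjp(lambda x: jnp.real(g(x)), x)
_, vjp_im = jax.vjp(lambda x: jnp.imag(g(x)), x)
print(vjp_re(1.0)[0] + 1j * vjp_im(1.0)[0])  # 2*x + 1j*cos(x), shape (3,)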

What do you think?

inversecrime commented 1 month ago

Thanks for the answer! Let's represent a function R->C as a function R->R^2 and use jax transformations accordingly - what would that look like in a real-world example? Something like this?

(I don't want to use jvp since all functions I'm dealing with have exactly one complex-valued output)

# f as defined above; x is a real scalar, e.g. x = jnp.asarray(1.0)
(_, f_vjp) = jax.vjp(f, x)
print(f_vjp(1.0 + 0.0j))
print(f_vjp(0.0 + 1.0j))

Is this what jax.grad does under the hood for holomorphic functions?

And how can I generalize this to functions R^n->C? Should I just write my own "grad" function?

How about this?

x = jnp.asarray(1.0)

def my_grad(f):
    def wrapper(x):
        assert jnp.issubdtype(x.dtype, jnp.floating)
        (y, f_vjp) = jax.vjp(f, x)
        assert y.shape == ()
        assert jnp.issubdtype(y.dtype, jnp.complexfloating)
        return f_vjp(1.0 + 0.0j) + f_vjp(0.0 + 1.0j)
    return wrapper

print(my_grad(f)(x))

inversecrime commented 1 month ago

I just realized that f_vjp returns a tuple, so adding the two results is wrong (it concatenates the tuples). My best guess would be something like this:

        return f_vjp(1.0 + 0.0j)[0] + 1.0j * f_vjp(0.0 + 1.0j)[0]

But this also gives a wrong result (conjugated, that is). Somehow, this is a lot more complicated than I thought...
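
If the result really is the exact conjugate, then (since both f_vjp outputs are real-valued) flipping the sign of the imaginary term would give the intended value, i.e. something like:

        return f_vjp(1.0 + 0.0j)[0] - 1.0j * f_vjp(0.0 + 1.0j)[0]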