patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0
265 stars 12 forks

Switched from reverse mode to forward mode where possible. #61

Open patrick-kidger opened 1 month ago

patrick-kidger commented 1 month ago

This commit switches some functions that unnecessarily use reverse-mode autodiff to using forward-mode autodiff. In particular this is to fix https://github.com/patrick-kidger/optimistix/pull/51#issuecomment-2124401072.

Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the new test_least_squares.py::test_residual_jac that I've added actually fails! I've poked at this a bit but have not yet been able to resolve it. It seems something is still awry!

Here's some context:

Tagging @Randl -- do you have a clearer idea of what's going on here?

Randl commented 1 month ago

As for the complex conjugate, isn't that expected? The complex derivative is given by d/dz = d/dx - i*d/dy. See, e.g., https://pytorch.org/docs/stable/notes/autograd.html#complex-autograd-doc or https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#complex-numbers-and-differentiation

We can use grad to optimize functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x).

patrick-kidger commented 1 month ago

We can use grad to optimize functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x).

Indeed, gradient descent should be performed by performing steps in the direction of the conjugate Wirtinger derivative. That's an optimisation thing though, and that's not what we're doing here: we're just computing the gradient.

And when it comes to computing the gradient, then JAX and PyTorch do different things: AFAIK JAX computes the (unconjugated) Wirtinger derivative, whilst PyTorch computes the conjugate Wirtinger derivative.

So what we're seeing here is that we (Optimistix) are sometimes computing the conjugate Wirtinger derivative, and sometimes we're computing the (unconjugated) Wirtinger derivative. We should always compute the latter?
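
To make the distinction concrete, here is a minimal sketch in plain Python. It does not use JAX or Optimistix: central finite differences stand in for autodiff, and `loss` and `partials` are hypothetical names chosen purely for illustration. For a real-valued loss of a complex parameter, it compares a descent step taken against the d/dx - i*d/dy combination with a step taken against its conjugate.

```python
def loss(z: complex) -> float:
    # Hypothetical real-valued loss with its minimum at z = 1 + 2j.
    return abs(z - (1 + 2j)) ** 2


def partials(f, z: complex, h: float = 1e-6):
    # Central finite differences for df/dx and df/dy, where z = x + i*y.
    dfdx = (f(z + h) - f(z - h)) / (2 * h)
    dfdy = (f(z + 1j * h) - f(z - 1j * h)) / (2 * h)
    return dfdx, dfdy


z0 = 0j
dfdx, dfdy = partials(loss, z0)

unconjugated = dfdx - 1j * dfdy  # the d/dx - i*d/dy combination
conjugated = dfdx + 1j * dfdy    # its complex conjugate

step = 0.1
z_conj = z0 - step * conjugated      # step against the conjugate
z_unconj = z0 - step * unconjugated  # step against the unconjugated value

print(loss(z0), loss(z_conj), loss(z_unconj))
```

Under these assumptions only the conjugated step reduces the loss, which is the optimisation-side statement; which of the two values a gradient function should return is the separate convention question discussed here.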

Randl commented 1 month ago

When you calculate the derivative using real and imaginary parts it is up to you how to combine them. Jax returns d/dz=d/dx-i*d/dy and you have a tuple of d/dx, d/dy:

assert tree_allclose((grad1.real, grad1.imag), grad2)

So there is an extra sign here.

The same goes for calculating the gradient: of course it depends on how you plan to use it, but you sum two conjugate values, so you are basically taking the real part.

patrick-kidger commented 3 weeks ago

When you calculate the derivative using real and imaginary parts it is up to you how to combine them. Jax returns d/dz=d/dx-i*d/dy

Right, but I don't think that's the case:

> import jax
> jax.grad(lambda z: z**2, holomorphic=True)(jax.numpy.array(1+1j))
Array(2.+2.j, dtype=complex64, weak_type=True)

Contrast the same computation in PyTorch, which additionally conjugates:

> import torch
> x = torch.tensor(1+1j, requires_grad=True)
> y = x ** 2
> torch.autograd.backward(y, grad_tensors=(torch.tensor(1+0j),))
> x.grad
tensor(2.-2.j)

So we still have a conjugation bug somewhere.

Randl commented 3 weeks ago

Again, this is a convention rather than a bug. PyTorch returns df/dz* while JAX returns df/dz.

patrick-kidger commented 3 weeks ago

I agree it's a convention! And it's a convention we're breaking. I believe we're returning df/dz*, despite being in JAX, and that we should therefore be returning df/dz.

(Do you agree we're returning df/dz*? Or do you think we're returning df/dz?)

Randl commented 3 weeks ago

Return where, sorry? Your current compute_grad function computes the real part of the gradient, the real-valued version computes the conjugate gradient (more precisely, its real and imaginary parts), and JAX computes the regular gradient, as far as I can tell. If you want to follow the JAX conventions, compute_grad should return a regular gradient, and the tests against the real version should have an extra minus sign.

Maybe let's try to specify where exactly you think the bug is. What is the minimal failing test?

Randl commented 3 weeks ago

Oh, I think I understand where the confusion comes from. More precisely, PyTorch returns (df/dz)*, which equals df/dz* only for C->R functions. If your z^2 function is part of a larger computation with real output, this is ok (but not in general). In your example, f is C->C. In general, we can't represent the gradient as a single number in this case, only if f is holomorphic (i.e., depends only on z and not z*) or anti-holomorphic. Your f is holomorphic, and indeed d/dz z^2 = 2z (what JAX returns), while d/dz* z^2 = 0. What PyTorch returns is (d/dz z^2)* = 2z*, which matches the convention.
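
The numbers above can be checked with a short sketch in plain Python. It uses neither JAX nor PyTorch: central finite differences stand in for autodiff, `wirtinger` is a hypothetical helper name, and the standard definitions d/dz = (d/dx - i*d/dy)/2 and d/dz* = (d/dx + i*d/dy)/2 (including the factor of 1/2) are assumed.

```python
def wirtinger(f, z: complex, h: float = 1e-6):
    # Central finite differences along the real and imaginary axes.
    dfdx = (f(z + h) - f(z - h)) / (2 * h)
    dfdy = (f(z + 1j * h) - f(z - 1j * h)) / (2 * h)
    d_dz = 0.5 * (dfdx - 1j * dfdy)     # Wirtinger derivative d/dz
    d_dzbar = 0.5 * (dfdx + 1j * dfdy)  # conjugate Wirtinger derivative d/dz*
    return d_dz, d_dzbar


z = 1 + 1j
d_dz, d_dzbar = wirtinger(lambda w: w ** 2, z)

print(d_dz)              # approximately 2z = 2+2j
print(d_dzbar)           # approximately 0
print(d_dz.conjugate())  # approximately 2z* = 2-2j
```

At z = 1+1j the first value agrees with the JAX output quoted earlier in the thread and the conjugated value agrees with the PyTorch output, while d/dz* of z^2 vanishes because z^2 is holomorphic.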