patrick-kidger opened 1 month ago
As for the complex conjugate, isn't that expected? The complex derivative is given by d/dz = d/dx - i*d/dy.
See, e.g., https://pytorch.org/docs/stable/notes/autograd.html#complex-autograd-doc or https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#complex-numbers-and-differentiation
We can use grad to optimize functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x).
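A minimal sketch of what that recommendation looks like in practice (the loss function below is made up purely for illustration): for a real-valued loss of a complex parameter, stepping along the conjugate of jax.grad performs gradient descent.

```python
import jax
import jax.numpy as jnp

def loss(z):
    # Real-valued loss of a single complex parameter (illustrative only):
    # squared distance from z to the target 1 + 2j.
    diff = z - (1.0 + 2.0j)
    return jnp.real(diff * jnp.conj(diff))

z = jnp.array(0.0 + 0.0j)
for _ in range(100):
    g = jax.grad(loss)(z)
    # Per the JAX docs quoted above: step in the direction of the conjugate.
    z = z - 0.1 * jnp.conj(g)

print(z)  # converges towards 1+2j
```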
> We can use grad to optimize functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x).
Indeed, gradient descent should be performed by taking steps in the direction of the conjugate Wirtinger derivative. That's an optimisation thing though, and it's not what we're doing here: we're just computing the gradient.
And when it comes to computing the gradient, JAX and PyTorch do different things: AFAIK JAX computes the (unconjugated) Wirtinger derivative, whilst PyTorch computes the conjugate Wirtinger derivative.
So what we're seeing here is that we (Optimistix) are sometimes computing the conjugate Wirtinger derivative, and sometimes the (unconjugated) Wirtinger derivative. Presumably we should always compute the latter?
When you calculate the derivative using real and imaginary parts it is up to you how to combine them. JAX returns d/dz = d/dx - i*d/dy, and you have a tuple of (d/dx, d/dy):

assert tree_allclose((grad1.real, grad1.imag), grad2)

So there is an extra sign here. Same in calculating the gradient: of course it depends on how you plan to use it, but you sum two conjugate values, so basically take the real part.
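For concreteness, here is a self-contained sketch of that sign relationship (grad1 and grad2 below are illustrative stand-ins computed with jax.grad, not the actual Optimistix test values):

```python
import jax
import jax.numpy as jnp

def f_complex(z):
    # A C -> R function of a complex argument z = x + i*y.
    x, y = jnp.real(z), jnp.imag(z)
    return x**2 + y**2

def f_real(x, y):
    # The same function viewed as a function of the real pair (x, y).
    return x**2 + y**2

grad1 = jax.grad(f_complex)(jnp.array(3.0 + 4.0j))   # complex gradient
grad2 = jax.grad(f_real, argnums=(0, 1))(3.0, 4.0)   # (d/dx, d/dy)

print(grad1)  # 6 - 8j under the d/dx - i*d/dy convention
print(grad2)  # (6.0, 8.0)

# The imaginary part carries the extra sign relative to d/dy:
assert jnp.allclose(grad1.real, grad2[0])
assert jnp.allclose(grad1.imag, -grad2[1])
```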
> When you calculate the derivative using real and imaginary parts it is up to you how to combine them. JAX returns d/dz = d/dx - i*d/dy
Right, but I don't think that's the case:
> jax.grad(lambda z: z**2, holomorphic=True)(jax.numpy.array(1+1j))
Array(2.+2.j, dtype=complex64, weak_type=True)
Contrast the same computation in PyTorch, which additionally conjugates:
> import torch
> x = torch.tensor(1+1j, requires_grad=True)
> y = x ** 2
> torch.autograd.backward(y, grad_tensors=(torch.tensor(1+0j),))
> x.grad
tensor(2.-2.j)
So we still have a conjugation bug somewhere.
Again, this is convention rather than a bug. PyTorch returns df/dz*, while JAX returns df/dz.
I agree it's a convention! And it's a convention we're breaking. I believe we're returning df/dz*, despite being in JAX, and that therefore we should be returning df/dz.
(Do you agree we're returning df/dz*? Or do you think we're returning df/dz?)
Return where, sorry? Your current compute_grad function computes the real part of the gradient, the real-valued version computes the conjugate gradient (more precisely, its real and imaginary parts), and JAX computes the regular gradient, as far as I can tell. If you want to follow the JAX conventions, compute_grad should return a regular gradient, and the tests against the real version should have an extra minus sign.
Maybe let's try to specify where exactly you think the bug is. What is the minimal failing test?
Oh, I think I understand where the confusion comes from. More precisely, PyTorch returns (df/dz)*, which equals df/dz* only for C -> R functions. If your z^2 function is part of a larger computation with real output, this is OK (but not in general).
In your example, f is C -> C. In general, we can't represent the gradient as a single number in this case, only if f is holomorphic (i.e., depends only on z and not z*) or anti-holomorphic. Your f is holomorphic, and indeed d/dz z^2 = 2z (what JAX returns), while d/dz* z^2 = 0. What PyTorch returns is (d/dz z^2)* = 2z*, which matches the convention.
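Those two Wirtinger derivatives can be checked directly, by differentiating the real and imaginary parts of z**2 with respect to x and y and recombining them. A minimal sketch (the helper f_parts below is made up for illustration):

```python
import jax
import jax.numpy as jnp

def f_parts(x, y):
    # Real and imaginary parts of f(z) = z**2 with z = x + i*y.
    out = (x + 1j * y) ** 2
    return jnp.real(out), jnp.imag(out)

x, y = 1.0, 1.0  # i.e. z = 1 + 1j
du_dx, dv_dx = jax.jacfwd(f_parts, argnums=0)(x, y)
du_dy, dv_dy = jax.jacfwd(f_parts, argnums=1)(x, y)

df_dx = du_dx + 1j * dv_dx
df_dy = du_dy + 1j * dv_dy

d_dz = 0.5 * (df_dx - 1j * df_dy)     # Wirtinger d/dz
d_dzbar = 0.5 * (df_dx + 1j * df_dy)  # Wirtinger d/dz*

print(d_dz)     # 2+2j, i.e. 2z: what jax.grad(..., holomorphic=True) returned above
print(d_dzbar)  # 0, since z**2 is holomorphic
```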
This commit switches some functions that unnecessarily use reverse-mode autodiff over to forward-mode autodiff. In particular this is to fix https://github.com/patrick-kidger/optimistix/pull/51#issuecomment-2124401072.
Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the new test_least_squares.py::test_residual_jac that I've added actually fails! I've poked at this a bit but not yet been able to resolve it. It seems something is still awry! Here's some context:
- We have a residual function r(y). The goal is then to optimise the minimisation problem 0.5*||r(y)||^2, where ||.|| denotes the 2-norm. Whilst you could do this via optx.minimise, in practice there are specialised least-squares algorithms that exploit having access to r. (Indeed computing this is what occurs in FunctionInfo.ResidualJac.as_min, in cases where directly working with the residuals r is not necessary.)
- Writing that out for complex residuals, the objective is f(y) = 0.5 * r(y)^T conj(r(y)).
- FunctionInfo.ResidualJac.compute_grad computes the derivative df/dy = 0.5 * (r^T dconj(r)/dy + dr^T/dy conj(r)). The Jacobian dr/dy is available as FunctionInfo.ResidualJac.jac.
- The results of compute_grad and of calling jax.grad directly on FunctionInfo.ResidualJac.as_min differ by a conjugate! This is the problem :(
- (There is also a Function.ResidualJac.compute_grad_dot that computes the dot-product df/dy . z against some vector z, i.e. df/dy^T conj(z). I haven't debugged/looked at this at all yet due to the earlier error.)

Tagging @Randl -- do you have a clearer idea of what's going on here?
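For what it's worth, here is a hypothetical stand-alone sketch of the comparison at stake, with no Optimistix internals (the residual r, the objective as_min, and the Jacobian-based expression below are all made up for illustration). For a holomorphic residual function the Jacobian-based gradient is J^T conj(r); under JAX's convention this should match jax.grad of the scalar objective, while its conjugate is the PyTorch-style answer:

```python
import jax
import jax.numpy as jnp

def r(y):
    # An arbitrary holomorphic residual function C^2 -> C^2 (illustrative only).
    return jnp.stack([y[0] ** 2 - y[1], y[0] * y[1] + 3.0])

def as_min(y):
    # The scalar least-squares objective 0.5 * r(y)^T conj(r(y)), made explicitly real.
    res = r(y)
    return 0.5 * jnp.real(jnp.sum(res * jnp.conj(res)))

y = jnp.array([1.0 + 2.0j, 0.5 - 1.0j])

grad_autodiff = jax.grad(as_min)(y)        # gradient via jax.grad

jac = jax.jacfwd(r, holomorphic=True)(y)   # Jacobian dr/dy
grad_from_jac = jac.T @ jnp.conj(r(y))     # Jacobian-based gradient

# Under the (unconjugated) d/dz convention discussed above these should agree,
# whereas the conjugate (PyTorch-style) version should not:
print(jnp.allclose(grad_autodiff, grad_from_jac))
print(jnp.allclose(grad_autodiff, jnp.conj(grad_from_jac)))
```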