patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.37k stars 125 forks

Feature Request: Complex-Valued Integration With ZVODE - CVODE in Jax (autodiff) #477

Open onurdanaci opened 1 month ago

onurdanaci commented 1 month ago

Hi,

Unfortunately, all the available (S)ODE integration routines in auto-differentiable Python frameworks (RK45, Dopri, etc.) behave very poorly with complex-valued functions [*]. In the Python ecosystem, only SciPy's Fortran-backed interfaces, `ode` with the ZVODE integrator and `complex_ode` (a wrapper around `ode` for complex systems), seem to work well, but they are not differentiable and so cannot be used in the modern applications we love.

I was wondering if anybody wants to adapt these features to Diffrax, and make them auto-differentiable.

[*] https://arxiv.org/abs/2406.06361
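For reference, here is a minimal sketch of the SciPy baseline mentioned above: `scipy.integrate.ode` with the `zvode` integrator (in Adams mode) on a small linear Schrödinger-type system. The 2×2 Hamiltonian `H` is a hypothetical example, not taken from the paper. The result is checked against the matrix exponential. None of this is differentiable with JAX.

```python
import numpy as np
from scipy.integrate import ode
from scipy.linalg import expm

# Hypothetical 2x2 Hermitian Hamiltonian; dz/dt = -i H z preserves the norm of z.
H = np.array([[1.0, 0.5], [0.5, -1.0]], dtype=complex)

def rhs(t, z):
    # scipy.integrate.ode calls f(t, y), i.e. time comes first.
    return -1j * H @ z

z0 = np.array([1.0 + 0j, 0.0 + 0j])
t1 = 2.0

# ZVODE works directly on complex-valued y; method="adams" for non-stiff problems.
r = ode(rhs).set_integrator("zvode", method="adams", rtol=1e-8, atol=1e-10, nsteps=5000)
r.set_initial_value(z0, 0.0)
z1 = r.integrate(t1)

z_exact = expm(-1j * H * t1) @ z0
print(r.successful(), np.abs(z1 - z_exact).max())
```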

patrick-kidger commented 1 month ago

We have some limited support for complex numbers in Diffrax. In particular I think all of the explicit solvers (Tsit5, Dopri etc.) should behave correctly. Glancing at the paper I can see they briefly mention Diffrax, but apparently indicate they had some difficulty getting reverse-mode working. I've not seen a bug report from them though so there's not much I can do until then. 🤷

More importantly though, I believe this whole thing is essentially a non-issue. It's trivial to make any real integrator work with complex numbers: just split into real and imaginary parts before passing your initial condition into the solver, and then combine them back together inside your vector field. Job done.

sriharikrishna commented 3 weeks ago

Hi. Thanks to @onurdanaci for asking the question and to @lockwo for pointing this question out to me. Apologies to @patrick-kidger for not posting the issue earlier (I am the author of the document mentioned above).

I have an MWE below. I would be happy to be told that this issue is minor or that I am using Diffrax incorrectly.

```python
import diffrax
from diffrax import diffeqsolve, ODETerm, Tsit5
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

def solver(y0, t_f, A, use_direct):
    def ode_fn(t, y, B):
        return jnp.matmul(B[0], y)
    term = ODETerm(ode_fn)
    ODEsolver = Tsit5()
    solver_args = dict(t0=0.0, t1=t_f.real, dt0=0.2, y0=y0, args=(A,))
    # Required for forward mode only
    if use_direct:
        solver_args |= dict(adjoint=diffrax.DirectAdjoint())
    solution = diffrax.diffeqsolve(term, ODEsolver, **solver_args)
    return solution.ys[0]

def driver(params, use_direct):
    # Create y0 from params
    time = params[0] * jnp.pi
    cos = jnp.cos(time / 2)
    sin = jnp.sin(time / 2)
    axis_angle = params[1] * jnp.pi
    KET_0 = jnp.array([1, 0], dtype=jnp.complex128)  # |0>, spin up
    KET_1 = jnp.array([0, 1], dtype=jnp.complex128)  # |1>, spin down
    y0 = cos * KET_1 - 1j * jnp.exp(-1j * axis_angle) * sin * KET_0

    A = jnp.array([[0 - 1j, 1.0 + 2j],
                   [-100.0 + 3j, 0 + 4j]], dtype=jnp.complex128)

    # Evolve y0. Time is influenced by params
    y = solver(y0, time, A, use_direct)
    return y

params_f = jnp.array([0.5, 0.4], dtype=jnp.float64)
jacfwd_fun = jax.jacfwd(driver, argnums=0)
jac_f = jacfwd_fun(params_f, True)

# Must be complex for reverse mode
params_b = jnp.array([0.5 + 0j, 0.4 + 0j], dtype=jnp.complex128)
# Must set holomorphic=True for reverse mode
jacrev_fun = jax.jacrev(driver, argnums=0, holomorphic=True)
jac_b = jacrev_fun(params_b, False)

print(jac_f - jac_b)
```

This prints:

```text
/usr/local/lib/python3.10/dist-packages/equinox/_jit.py:51: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
  out = fun(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/equinox/_jit.py:51: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
  out = fun(*args, **kwargs)
[[7.91624188e-09+1.36585934e+06j 1.16415322e-09-2.06637196e-09j]
 [4.65661287e-08+6.69479738e+06j 1.54832378e-08-1.45519152e-11j]]
```

The problem might be that `params` influences both the initial state `y0` of the solver and the end time `t1`.
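One way to probe this without Diffrax is a scalar toy function. These functions are hypothetical and chosen only for illustration. Like the MWE, one of them takes `.real` of a complex quantity. `jax.jacfwd` on a real input and `jax.jacrev(..., holomorphic=True)` on the same point embedded in the complex plane agree for a holomorphic function. They need not agree once the non-holomorphic `.real` is involved. This sketch does not establish that `t1=t_f.real` explains the MWE's mismatch. Note, though, that in the output above only the first column disagrees, and that disagreement is essentially purely imaginary. The first column corresponds to the parameter that also sets `t1`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def holo(p):
    # Holomorphic in p.
    return jnp.exp(1j * p)

def non_holo(p):
    # Identical values for real p, but .real makes it non-holomorphic in p.
    return jnp.exp(1j * p.real)

p_real = jnp.float64(0.5)
p_cplx = jnp.complex128(0.5 + 0j)

results = {}
for name, f in [("holo", holo), ("non_holo", non_holo)]:
    fwd = jax.jacfwd(f)(p_real)                    # real input: ordinary derivative
    rev = jax.jacrev(f, holomorphic=True)(p_cplx)  # complex input, assumed holomorphic
    results[name] = (fwd, rev)
    print(name, fwd, rev, fwd - rev)
```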

Thanks for your attention and help!

onurdanaci commented 2 weeks ago

Dear Patrick @patrick-kidger ,

Thank you for your answer. Indeed, I can transform my complex-valued system of equations into:

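$$\frac{dv}{dt} = M v, \qquad v = v_{\mathrm{real}} + i\,v_{\mathrm{imag}}, \qquad M = M_{\mathrm{real}} + i\,M_{\mathrm{imag}}$$

$$\frac{d}{dt}\begin{bmatrix} v_{\mathrm{real}} \\ v_{\mathrm{imag}} \end{bmatrix} = \begin{bmatrix} M_{\mathrm{real}} & -M_{\mathrm{imag}} \\ M_{\mathrm{imag}} & M_{\mathrm{real}} \end{bmatrix} \begin{bmatrix} v_{\mathrm{real}} \\ v_{\mathrm{imag}} \end{bmatrix}$$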
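A quick numerical check of this block form, using nothing beyond JAX, follows. Applying the real block matrix to the stacked vector reproduces the real and imaginary parts of `M @ v`. The same holds for the flows `expm(t M)`. `M` is the matrix `A` from the MWE above, and `v` and `t` are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)

def realify(M):
    # [[Re M, -Im M], [Im M, Re M]]: the real 2n x 2n form of a complex n x n matrix.
    return jnp.block([[M.real, -M.imag], [M.imag, M.real]])

def stack(v):
    # [Re v; Im v]
    return jnp.concatenate([v.real, v.imag])

M = jnp.array([[0 - 1j, 1.0 + 2j], [-100.0 + 3j, 0 + 4j]], dtype=jnp.complex128)
v = jnp.array([0.3 + 0.1j, -0.2 + 0.5j], dtype=jnp.complex128)
t = 0.1

# Vector fields agree: realify(M) @ [Re v; Im v] == [Re(M v); Im(M v)].
vector_field_ok = jnp.allclose(realify(M) @ stack(v), stack(M @ v))
# Flows agree too: expm(t * realify(M)) acts like expm(t * M) on the stacked vector.
flow_ok = jnp.allclose(expm(t * realify(M)) @ stack(v), stack(expm(t * M) @ v))
print(vector_field_ok, flow_ok)
```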

Then I recombine the real and imaginary parts in post-processing. Of course, it would be much more convenient for the quantum technologies community to have these features predefined in libraries. But I agree that this part is a non-issue. However, I am still suspicious.

My reason is that SciPy's VODE routine, which was inherited from Fortran libraries, uses implicit multistep methods: implicit Adams (Adams-Moulton) methods for non-stiff problems and BDF for stiff problems. I couldn't parse all of the archaic Fortran code, but my suspicion is that SciPy's ZVODE simply uses this VODE library together with your vector-field trick.

I have doubts, based on some small numerical experiments (not systematic, elaborate, or conclusive by any metric) and on the paper I shared before, that the crème de la crème explicit Runge-Kutta methods y'all provide, such as Tsit5 and Dopri5, would be as good as implicit Adams for these non-stiff quantum problems, or that KenCarp4 would be as good as BDF for stiff problems. Maybe I am wrong. I will need to run them on some important unit tests to make sure that I do not get non-physical results. I will get back to you.

patrick-kidger commented 2 weeks ago

Thank you @sriharikrishna for the MWE! That's really useful. I'm going to tag @randl as our resident complex autodiff expert. Any thoughts?

Other than that, thank you @onurdanaci for your write-up above! I'd like it if Diffrax could be useful to you regardless :)

Randl commented 2 weeks ago

@sriharikrishna Isn't the mismatch because, in the first case, you calculate the gradient with respect to a real parameter, which is automatically real, while in the second case the gradient is with respect to a complex parameter, so it also has an imaginary part? I've tried running `check_grads` for a function equivalent to yours:

```python
import diffrax
import jax
import jax.numpy as jnp
import pytest
from diffrax import ODETerm
from jax.test_util import check_grads

jax.config.update("jax_enable_x64", True)


@pytest.mark.parametrize(
    "solver",
    [
        diffrax.Tsit5(),
    ],
)
def test_grad_complex(solver):

    def ode_fn(t, y, B):
        return jnp.matmul(B[0], y)

    term = ODETerm(ode_fn)

    @jax.jit
    def driver(pt, ang):
        # Create y0 from params
        time = pt * jnp.pi
        cos = jnp.cos(time / 2)
        sin = jnp.sin(time / 2)
        axis_angle = ang * jnp.pi
        KET_0 = jnp.array([1, 0], dtype=jnp.complex128)  # |0>, spin up
        KET_1 = jnp.array([0, 1], dtype=jnp.complex128)  # |1>, spin down
        y0 = cos * KET_1 - 1j * jnp.exp(-1j * axis_angle) * sin * KET_0
        jax.debug.print("{y0}", y0=y0)

        A = jnp.array([[0 - 1j, 1.0 + 2j],
                       [-100.0 + 3j, 0 + 4j]], dtype=jnp.complex128)

        solver_args = dict(t0=0.0, t1=time.real, dt0=0.2, y0=y0, args=(A,))
        # # Required for forward mode only
        # if use_direct == True:
        solver_args |= dict(adjoint=diffrax.DirectAdjoint())
        # Evolve y0. Time is influenced by params
        solution = diffrax.diffeqsolve(term, solver, **solver_args)
        return solution.ys[0]

    # check_grads(driver, (0.5, 0.4), order=2, modes=["fwd"])
    # atol=1e15 is extremely loose: the absolute differences in rev mode are huge.
    check_grads(driver, (0.5 + 0.j, 0.4 + 0.j), order=2, modes=["rev"], atol=1e15)
```

Apart from the absolute differences being huge in the rev case, I couldn't see a failure. If you could point out the mismatch against numerical gradients, that'd be helpful. Alternatively, there may be a bug in the solver itself, which would make both the analytic and the numerical gradients wrong.