SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Zero gradient when using `ComplexF64` ComponentArrays as Parameters #1149


albertomercurio commented 1 week ago

Describe the bug 🐞

Computing the gradient returns a zero vector when using a ComponentArray as the parameters together with a ComplexF64 initial state.

Expected behavior

A nonzero gradient result.

Minimal Reproducible Example 👇

```julia
using OrdinaryDiffEq
using Zygote
using SciMLSensitivity
using ComponentArrays

##

function lotka_volterra(u, p, t)
  dx = p[1] * u[1] - p[2] * u[1] * u[2]
  dy = -p[3] * u[2] + p[4] * u[1] * u[2]

  return [dx, dy]
end

function my_f(p)
  u0 = ComplexF64[1.0, 1.0]

  param = p
  # param = ComponentArray((p1 = p[1], p2 = p[2], p3 = p[3], p4 = p[4],))

  prob = ODEProblem{false}(lotka_volterra, u0, (0.0, 10.0), param)
  sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6)
  return sum(real, sol.u[end])
end

p = [0.5, 0.6, 0.7, 0.8]
my_f(p)     # 1.9740765358083163

##

Zygote.gradient(my_f, p)   # ([0.34232539073366874, 0.4509928030672796, 1.3335914824686044, -0.7728475749882253],)
```

But if I instead use `param = ComponentArray((p1 = p[1], p2 = p[2], p3 = p[3], p4 = p[4],))`, I get a zero vector:

```julia
Zygote.gradient(my_f, p)   # ([0.0, 0.0, 0.0, 0.0],)
```

The same doesn't happen for a Float64 initial vector.
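For completeness, here is a sketch of the control case (my code, adapted from the MRE above; `my_f_real` is a hypothetical name): switching only `u0` to `Float64` while keeping the ComponentArray parameters gives nonzero gradients.

```julia
# Control case: real-valued initial state, same ComponentArray parameters.
# Assumes `lotka_volterra` and `p` from the MRE above are already defined.
function my_f_real(p)
    u0 = Float64[1.0, 1.0]
    param = ComponentArray((p1 = p[1], p2 = p[2], p3 = p[3], p4 = p[4],))
    prob = ODEProblem{false}(lotka_volterra, u0, (0.0, 10.0), param)
    sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6)
    return sum(real, sol.u[end])
end

Zygote.gradient(my_f_real, p)  # nonzero, per the report: only the ComplexF64
                               # state + ComponentArray combination fails
```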

Environment (please complete the following information):

```
Status `~/GitHub/Research/Undef/Autodiff QuantumToolbox/Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [b0b7db55] ComponentArrays v0.15.18
  [7da242da] Enzyme v0.13.15
  [1dea7af3] OrdinaryDiffEq v6.90.1
  [33c8b6b6] ProgressLogging v0.1.4
  [6c2fb7c5] QuantumToolbox v0.21.5 `~/.julia/dev/QuantumToolbox`
  [1ed8b502] SciMLSensitivity v7.71.1
  [53ae85a6] SciMLStructures v1.5.0
  [5d786b92] TerminalLoggers v0.1.7
  [e88e6eb3] Zygote v0.6.73
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900KF
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 16 default, 0 interactive, 8 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 16
```
ChrisRackauckas commented 4 days ago

This is worrisome. @DhairyaLGandhi do you know what could cause this one?

DhairyaLGandhi commented 4 days ago

I'm taking a look

DhairyaLGandhi commented 3 days ago

It looks like, when a ComponentVector is passed in, the DiffEqCallbacks machinery doesn't actually run any of the internal `recursive_*` methods properly, and therefore never updates the `accumulation_cache` in https://github.com/SciML/DiffEqCallbacks.jl/blob/735ea17a5dc9618e7309b30db4fc3e85c416f0dc/src/integrating_sum.jl#L52. This is likely due to all structs being assumed to be functors in Functors@0.5.

It is also important to note that https://github.com/SciML/DiffEqCallbacks.jl/blob/735ea17a5dc9618e7309b30db4fc3e85c416f0dc/src/integrating_sum.jl#L39 returns `(p1 = nothing, p2 = nothing, p3 = nothing, p4 = nothing)`.

Pinning Functors back to v0.4 returns the gradients as expected, but we still need to check the values here:

```julia
julia> Zygote.gradient(my_f, p2)
((p1 = -8.15610781018363, p2 = -1.7862859982677435, p3 = 7.645628735586522, p4 = -12.55380170354163),)
```
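For anyone hit by this in the meantime, a minimal sketch of that workaround, pinning Functors to the 0.4 series in the active environment (a stopgap, not a fix):

```julia
using Pkg

# Resolve to the Functors 0.4 series and pin it so a later
# `Pkg.update()` doesn't move the environment back to 0.5.
Pkg.add(name = "Functors", version = "0.4")
Pkg.pin("Functors")
```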

This change in behaviour is introduced in Functors@0.5.
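To illustrate the behaviour change (my own sketch with a generic struct, not the actual DiffEqCallbacks internals):

```julia
using Functors

struct Cache
    x::Vector{Float64}
end

c = Cache([1.0, 2.0])

# On Functors@0.4 a struct with no @functor definition is a leaf:
# Functors.children(c) == (), and fmap treats `c` as an opaque value.
# On Functors@0.5 every struct is a functor by default, so recursion
# descends into the raw fields instead:
Functors.children(c)  # (x = [1.0, 2.0],) on Functors@0.5
```

If that is what happens here, methods that expect to dispatch on the whole container would be skipped in favour of recursion into its raw fields, which would explain the stale `accumulation_cache`.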

This is likely also related to https://github.com/SciML/DiffEqCallbacks.jl/issues/239.

ChrisRackauckas commented 3 days ago

@avik-pal did you end up looking into this?