albertomercurio opened this issue 2 weeks ago
There are a couple of compounding factors at play here. First, it's inconsistently defined: `f` is written expecting parameters of type `MyParameters`, whereas what is passed is a vector. There's no conversion happening either, because a vector is treated as a SciMLStructure by itself.
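To illustrate that first point, here is a minimal pure-Julia sketch; `MyParameters`, `f`, and the field names are hypothetical stand-ins for the MRE's types, and the SciMLStructures.jl interface definitions are elided. A function written against the struct simply has no method for the plain vector that gets passed through unchanged:

```julia
# Hypothetical stand-ins for the MRE's types; the real SciMLStructures.jl
# interface methods are omitted.
struct MyParameters{T}
    p1::T
    p2::T
end

# The user's function is written against the struct...
f(p::MyParameters) = p.p1 + 2 * p.p2

f(MyParameters(1.0, 2.0))   # works: 5.0

# ...but a plain Vector is already a valid SciMLStructure by itself, so no
# conversion back to MyParameters ever happens, and dispatch fails:
pvec = [1.0, 2.0]
applicable(f, pvec)         # false -> calling f(pvec) raises a MethodError
```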
Second, when we were working on #1135 we were missing https://github.com/JuliaArrays/ArrayInterface.jl/pull/456.
Third, the way `MyParameters` stores its parameters vs. the adjoint we get back is inconsistent: `MyParameters` stores them as a `Tuple`, whereas we get a vector back when calculating the parameter Jacobian. This should also be solved by #1147.
Hi @DhairyaLGandhi,
I tried with the branch of #1147, and I get a different error instead:
p = rand(T, 4)
Zygote.gradient(my_f, p)
ERROR: MethodError: no method matching recursive_copyto!(::Vector{ComplexF64}, ::NTuple{4, ComplexF64})
The function `recursive_copyto!` exists, but no method is defined for this combination of argument types.
Closest candidates are:
recursive_copyto!(::Tuple, ::Tuple)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/parameters_handling.jl:11
recursive_copyto!(::AbstractArray, ::AbstractArray)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/parameters_handling.jl:9
recursive_copyto!(::Any, ::Nothing)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/parameters_handling.jl:16
...
Stacktrace:
[1] vec_pjac!(out::Vector{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:492
[2] GaussIntegrand
@ ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:517 [inlined]
[3] (::SciMLSensitivity.var"#265#266"{…})(out::Vector{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:558
[4] (::DiffEqCallbacks.SavingIntegrandSumAffect{…})(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/00gNi/src/integrating_sum.jl:50
[5] apply_discrete_callback!
@ ~/.julia/packages/DiffEqBase/frOsk/src/callbacks.jl:615 [inlined]
[6] apply_discrete_callback!
@ ~/.julia/packages/DiffEqBase/frOsk/src/callbacks.jl:631 [inlined]
[7] handle_callbacks!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/integrators/integrator_utils.jl:355
[8] _loopfooter!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/integrators/integrator_utils.jl:243
[9] loopfooter!
@ ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/integrators/integrator_utils.jl:207 [inlined]
[10] solve!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/solve.jl:579
[11] #__solve#75
@ ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/solve.jl:7 [inlined]
[12] __solve
@ ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/solve.jl:1 [inlined]
[13] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:612
[14] solve_call
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:569 [inlined]
[15] #solve_up#53
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1092 [inlined]
[16] solve_up
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1078 [inlined]
[17] #solve#51
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1015 [inlined]
[18] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:578
[19] _adjoint_sensitivities
@ ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:531 [inlined]
[20] #adjoint_sensitivities#63
@ ~/.julia/packages/SciMLSensitivity/qX1o7/src/sensitivity_interface.jl:401 [inlined]
[21] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#315"{…})(Δ::ODESolution{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/concrete_solve.jl:627
[22] ZBack
@ ~/.julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:212 [inlined]
[23] (::Zygote.var"#kw_zpullback#56"{…})(dy::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:238
[24] #294
@ ~/.julia/packages/Zygote/nyzjS/src/lib/lib.jl:206 [inlined]
[25] (::Zygote.var"#2169#back#296"{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[26] #solve#51
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1015 [inlined]
[27] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[28] #294
@ ~/.julia/packages/Zygote/nyzjS/src/lib/lib.jl:206 [inlined]
[29] (::Zygote.var"#2169#back#296"{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[30] solve
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1005 [inlined]
[31] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[32] my_f
@ ~/GitHub/Research/Undef/Autodiff QuantumToolbox/autodiff.jl:158 [inlined]
[33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[34] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:91
[35] gradient(f::Function, args::Vector{ComplexF64})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:148
[36] top-level scope
@ ~/GitHub/Research/Undef/Autodiff QuantumToolbox/autodiff.jl:167
Some type information was truncated. Use `show(err)` to see complete types.
It's very strange, because `canonicalize` returns a `Vector` for the buffer.
Yes, that's the third point from my comment. The adjoint is a `Tuple`, since that's how the struct is stored in memory. We can add a dispatch to `recursive_copyto!`, but I worry that it's slightly ambiguous. I'll check whether there are any corner cases worth worrying about.
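For reference, a minimal standalone sketch of what such a dispatch might look like — this is not the actual SciMLSensitivity code, just the `Vector` ← `NTuple` combination from the error above, re-created in isolation:

```julia
# The stacktrace shows existing methods for (Tuple, Tuple) and
# (AbstractArray, AbstractArray); the failing call is (Vector, NTuple).
# One possible extra dispatch: copy a Tuple adjoint into an array buffer.
# Base.copyto! already accepts a general iterable source, so this suffices:
recursive_copyto!(out::AbstractArray, src::Tuple) = copyto!(out, src)

out = zeros(ComplexF64, 4)
recursive_copyto!(out, (1.0 + 0im, 2.0 + 0im, 3.0 + 0im, 4.0 + 0im))
out  # [1.0+0.0im, 2.0+0.0im, 3.0+0.0im, 4.0+0.0im]
```

The ambiguity concern above would matter for types that fall under more than one of these signatures, which is why this is only a sketch.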
Ok. But I still don't understand why the `Float64` case works.
I don't know if #1149 is also related to this, where I get a null gradient when using a complex `ComponentArray` rather than a float one.
Describe the bug 🐞

The calculation of the gradient on an ODE of `Float64` type works when using the params as either a `Vector` or a custom `struct` (using SciMLStructures.jl). However, it fails when I simply change the type of the ODE to `ComplexF64`. It seems that, in the `ComplexF64` case, the parameters are converted to a `Vector`. But they are a custom `struct`, so `p.p1` doesn't work. It works when using a `Vector` instead of a custom `struct`.

Expected behavior

Returning the correct gradient, as in the `Float64` case or as in the `ComplexF64` case with `Vector` parameters.

Minimal Reproducible Example 👇
Definition of the custom struct
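The original code block is not reproduced here. As a rough, hypothetical reconstruction consistent with the discussion above (parameters stored as a `Tuple`, accessed as `p.p1`, with the SciMLStructures.jl interface methods omitted), the struct might look like:

```julia
# Hypothetical sketch only -- the actual MRE also defines the
# SciMLStructures.jl interface (isscimlstructure, canonicalize, ...) on it.
struct MyParameters{T}
    params::NTuple{4, T}   # stored as a Tuple, per the discussion above
end

# Field-style access like p.p1, as used in the ODE right-hand side:
function Base.getproperty(p::MyParameters, s::Symbol)
    s === :p1 && return getfield(p, :params)[1]
    s === :p2 && return getfield(p, :params)[2]
    s === :p3 && return getfield(p, :params)[3]
    s === :p4 && return getfield(p, :params)[4]
    return getfield(p, s)
end

p = MyParameters((0.1 + 0im, 0.2 + 0im, 0.3 + 0im, 0.4 + 0im))
p.p1  # 0.1 + 0.0im
```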
ODE Problem
Gradient Calculation (fails)
Error & Stacktrace ⚠️
Environment (please complete the following information):
using Pkg; Pkg.status()
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
versioninfo()