I think that the Zygote problem is related to the update of the `ScalarOperator` here: the struct contains a `Float64`, but during differentiation it tries to convert its field `val` to a `Dual` number, which would change the type of the structure.
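A minimal sketch of this failure mode, using a hypothetical `ToyScalar` type that only mirrors the shape of a mutable scalar operator (it is not the actual SciMLOperators definition). It assumes ForwardDiff is installed and uses a `ForwardDiff.Dual` as a stand-in for the value produced during differentiation:

```julia
using ForwardDiff

# Hypothetical toy type (NOT the SciMLOperators definition): a mutable struct
# whose field type T is fixed when the object is constructed.
mutable struct ToyScalar{T<:Number, F}
    val::T
    update_func::F
end

# Same update-function signature as in this thread: the coefficient is -p[1]
coef(a, u, p, t) = -p[1]

L = ToyScalar(1.0, coef)   # ToyScalar{Float64, typeof(coef)}

# During differentiation the parameter carries a Dual, so the update returns a Dual
p_dual = [ForwardDiff.Dual(2.0, 1.0)]
newval = L.update_func(L.val, nothing, p_dual, 0.0)

# In-place update: storing the Dual requires converting it to Float64,
# which throws, so the new value cannot be written into `val`
stored = try
    L.val = newval
    true
catch
    false
end
```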
I found a possible partial fix for the out-of-place case. If I consider the out-of-place `ODEProblem` `prob = ODEProblem{false}(U, u0, tspan, [γ])` and change how the `ScalarOperator` is updated in the out-of-place case to
```julia
function SciMLOperators.update_coefficients(L::ScalarOperator, u, p, t; kwargs...)
    # Build a new operator instead of mutating L, so the type of `val` can change
    return ScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func)
end
```
instead of the current implementation:

```julia
function update_coefficients!(L::ScalarOperator, u, p, t; kwargs...)
    # Mutates the field: the new value is converted to the existing type of `val`
    L.val = L.update_func(L.val, u, p, t; kwargs...)
    nothing
end

function update_coefficients(L::ScalarOperator, u, p, t; kwargs...)
    update_coefficients!(L, u, p, t; kwargs...)
    L
end
```
Then it works:

```julia
Zygote.gradient(my_f, 1.9) # (-0.17161488226273966,)
```
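Under the same assumptions (the hypothetical `ToyScalar` type, a hypothetical `toy_update` helper, ForwardDiff installed), the out-of-place idea can be sketched as follows: constructing a new object lets the type parameter follow the updated value, so a derivative can flow through the coefficient `-p[1]`:

```julia
using ForwardDiff

# Hypothetical toy type (NOT the SciMLOperators definition)
mutable struct ToyScalar{T<:Number, F}
    val::T
    update_func::F
end

coef(a, u, p, t) = -p[1]

# Hypothetical out-of-place update in the spirit of the proposed fix: build a new
# object so that T follows the type of the updated value (Float64 or Dual)
toy_update(L::ToyScalar, u, p, t) = ToyScalar(L.update_func(L.val, u, p, t), L.update_func)

L0 = ToyScalar(1.0, coef)

# Derivative of the updated coefficient -γ with respect to γ, at γ = 1.9
d = ForwardDiff.derivative(γ -> toy_update(L0, nothing, [γ], 0.0).val, 1.9)
```

The trade-off is that every update constructs a new object instead of mutating the existing one.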
However, it still fails for the Enzyme case and for the in-place version with Zygote (which I guess would be much more efficient).
I will make a PR to SciMLOperators.jl to fix at least this case.
I found that the in-place version works when using `ComplexF64`:

```julia
using LinearAlgebra
using SparseArrays
using OrdinaryDiffEq
using SciMLOperators
using Zygote
using Enzyme
using SciMLSensitivity

T = ComplexF64
const N = 10
const u0 = ones(T, N)
# H_tmp = rand(T, N, N)
H_tmp = sprand(T, N, N, 0.5)
const H = H_tmp + H_tmp'

# Update function of the ScalarOperator; it must be defined before it is used in U
coef(a, u, p, t) = - p[1]

# one(T) instead of one(params[1]): `params` is not defined yet at this point
const U = ScalarOperator(one(T), coef) * MatrixOperator(Diagonal(H)) + MatrixOperator(Diagonal(H))

function my_f(params)
    tspan = (0.0, 1.0)
    # prob = ODEProblem{true}(U, u0, tspan, [γ], sensealg = InterpolatingAdjoint(autojacvec=false))
    prob = ODEProblem{true}(U, u0, tspan, params)
    sol = solve(prob, Tsit5())
    return real(sol.u[end][end])
end

params = T[1]
my_f(params) # 0.25621142049273665
Zygote.gradient(my_f, params)
```
But I get the following warnings during the differentiation:

```
┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/ME3jV/src/concrete_solve.jl:67
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/ME3jV/src/concrete_solve.jl:207
(ComplexF64[-1.9743163253371472 + 0.0im],)
```
First, how can I find out which `sensealg` (and which options) is chosen automatically?
Then, although I think the problem is related to ReverseDiff.jl, I followed this page and ran
```julia
tspan = (0.0, 1.0)
# prob = ODEProblem{true}(U, u0, tspan, [γ], sensealg = InterpolatingAdjoint(autojacvec=false))
prob = ODEProblem{true}(U, u0, tspan, params)

u0 = prob.u0
p = prob.p
tmp2 = Enzyme.make_zero(p)
t = prob.tspan[1]
du = zero(u0)

if DiffEqBase.isinplace(prob)
    _f = prob.f
else
    # Wrap an out-of-place RHS into an in-place one
    _f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
end

_tmp6 = Enzyme.make_zero(_f)
tmp3 = zero(u0)
tmp4 = zero(u0)
ytmp = zero(u0)
tmp1 = zero(u0)

Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
    Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
    Enzyme.Duplicated(ytmp, tmp1),
    Enzyme.Duplicated(p, tmp2),
    Enzyme.Const(t))
```
which returns `nothing` for every variable:

```julia
((nothing, nothing, nothing, nothing),)
```
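As a minimal sketch of why the result is all `nothing` (with a toy `square!` function, not the ODE right-hand side): in reverse mode, Enzyme returns derivatives only for `Active` arguments, while for `Duplicated` arguments the derivative is accumulated into the shadow. So in the call above, any gradient with respect to `p` would end up in `tmp2` rather than in the returned tuple:

```julia
using Enzyme

# Toy in-place function: writes x .* x into y and returns nothing
function square!(y, x)
    y .= x .* x
    return nothing
end

x  = [1.0, 2.0]
dx = zeros(2)    # shadow of x: receives the gradient
y  = zeros(2)
dy = ones(2)     # shadow of y: seed (adjoint) for the output

res = Enzyme.autodiff(Enzyme.Reverse, square!, Enzyme.Const,
                      Enzyme.Duplicated(y, dy), Enzyme.Duplicated(x, dx))

# `res` only holds `nothing` entries; the gradient 2 .* x was written into dx
```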
Now that I better understand how Enzyme.jl works, I see that this is not a bad thing. It means that Enzyme can differentiate the function, right? So I don't understand why I see a warning about `ReverseDiffVJP`.
I guess I could try `autojacvec = EnzymeVJP()`, but I don't know which `sensealg` is being used.
Moreover, can I use `Enzyme` instead of `Zygote` to differentiate the `my_f` function directly? What are the advantages and disadvantages?
If I try `sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP())`, I get the error
```
ERROR: MethodError: no method matching augmented_primal(::EnzymeCore.EnzymeRules.RevConfigWidth{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…}, ::Duplicated{…}, ::Const{…}, ::Const{…})
The function `augmented_primal` exists, but no method is defined for this combination of argument types.
Closest candidates are:
  augmented_primal(::EnzymeCore.EnzymeRules.RevConfig, ::Const{typeof(mul!)}, ::Type{RT}, ::Annotation{<:StridedVecOrMat}, ::Const{<:Union{SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, SubArray{Tv, 2, <:SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int64}}, I}} where I<:(AbstractUnitRange{<:Integer})} where {Tv, Ti}}, ::Annotation{<:StridedVecOrMat}, ::Annotation{<:Number}, ::Annotation{<:Number}) where RT
   @ Enzyme ~/.julia/packages/Enzyme/azJki/src/internal_rules.jl:732
  augmented_primal(::Any, ::Const{typeof(QuadGK.quadgk)}, ::Type{RT}, ::Any, ::Annotation{T}...; kws...) where {RT, T}
   @ QuadGKEnzymeExt ~/.julia/packages/QuadGK/BjmU0/ext/QuadGKEnzymeExt.jl:6
  augmented_primal(::Any, ::Const{typeof(NNlib._dropout!)}, ::Type{RT}, ::Any, ::OutType, ::Any, ::Any, ::Any) where {OutType, RT}
   @ NNlibEnzymeCoreExt ~/.julia/packages/NNlib/CkJqS/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl:318
```
It looks like you're hitting a missing Enzyme rule for sparse `mul!`.
What are the full types in the method-match failure? And which version of Enzyme are you using (is it the latest)?
```
LoadError: MethodError: no method matching augmented_primal(::EnzymeCore.EnzymeRules.RevConfigWidth{1, false, false, (false, false, false, false, false, false), false}, ::Const{typeof(mul!)}, ::Type{Const{Vector{ComplexF64}}}, ::Duplicated{Vector{ComplexF64}}, ::Duplicated{SparseMatrixCSC{ComplexF64, Int64}}, ::Duplicated{Vector{ComplexF64}}, ::Const{Bool}, ::Const{Bool})
The function `augmented_primal` exists, but no method is defined for this combination of argument types.
Closest candidates are:
  augmented_primal(::EnzymeCore.EnzymeRules.RevConfig, ::Const{typeof(mul!)}, ::Type{RT}, ::Annotation{<:StridedVecOrMat}, ::Const{<:Union{SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, SubArray{Tv, 2, <:SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int64}}, I}} where I<:(AbstractUnitRange{<:Integer})} where {Tv, Ti}}, ::Annotation{<:StridedVecOrMat}, ::Annotation{<:Number}, ::Annotation{<:Number}) where RT
   @ Enzyme ~/.julia/packages/Enzyme/azJki/src/internal_rules.jl:732
  augmented_primal(::Any, ::Const{typeof(QuadGK.quadgk)}, ::Type{RT}, ::Any, ::Annotation{T}...; kws...) where {RT, T}
   @ QuadGKEnzymeExt ~/.julia/packages/QuadGK/BjmU0/ext/QuadGKEnzymeExt.jl:6
  augmented_primal(::Any, ::Const{typeof(NNlib._dropout!)}, ::Type{RT}, ::Any, ::OutType, ::Any, ::Any, ::Any) where {OutType, RT}
   @ NNlibEnzymeCoreExt ~/.julia/packages/NNlib/CkJqS/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl:318
```
And these are my package versions:
```
Status `~/GitHub/Research/Undef/Autodiff QuantumToolbox/Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [13f3f980] CairoMakie v0.12.16
  [b0b7db55] ComponentArrays v0.15.19
  [7da242da] Enzyme v0.13.15
  [f6369f11] ForwardDiff v0.10.38
  [1dea7af3] OrdinaryDiffEq v6.90.1
  [33c8b6b6] ProgressLogging v0.1.4
  [6c2fb7c5] QuantumToolbox v0.21.5 `~/.julia/dev/QuantumToolbox`
  [731186ca] RecursiveArrayTools v3.27.3
  [37e2e3b7] ReverseDiff v1.15.3
⌃ [0bca4576] SciMLBase v2.61.0
⌃ [1ed8b502] SciMLSensitivity v7.71.1
  [5d786b92] TerminalLoggers v0.1.7
  [e88e6eb3] Zygote v0.6.73
Info Packages marked with ⌃ have new versions available and may be upgradable.
```
All the examples in the documentation use a user-defined function for the `ODEProblem`. Instead, I need to define a parameter-dependent `SciMLOperator` (e.g., a `MatrixOperator`
```julia
using Pkg; Pkg.status()
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
```