Qfl3x closed this issue 1 year ago.
After defining
function ChainRulesCore.rrule(::Type{<:SciMLBase.PDETimeSeriesSolution}, sol, meta)
    function PDETimeSeriesSolutionAdjoint(ȳ)
        (ChainRulesCore.NoTangent(), ȳ, ChainRulesCore.NoTangent())
    end
    SciMLBase.PDETimeSeriesSolution(sol, meta), PDETimeSeriesSolutionAdjoint
end

Zygote.@adjoint function Zygote.literal_getproperty(sol::DiffEqBase.AbstractTimeseriesSolution,
                                                     ::Val{:u})
    solu_adjoint(Δ::Dict) = solu_adjoint(Δ[first(keys(sol.u))])
    function solu_adjoint(Δ)
        zerou = zero(sol.prob.u0)
        _Δ = @. ifelse(Δ == nothing, (zerou,), Δ)
        (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ), nothing)
    end
    sol.u, solu_adjoint
end
gives
ERROR: DimensionMismatch("variable with size(x) == (9, 6) cannot have a gradient with size(dx) == (66,)")
Stacktrace:
[1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})(dx::ODESolution{Float64, 2, Matrix{Float64}, Nothing, Nothing, Vector{Float64}, Nothing, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, SciMLBase.LinearInterpolation{Vector{Float64}, Matrix{Float64}}, Nothing, Nothing})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:227
[2] _project
@ ~/.julia/packages/Zygote/g2w9o/src/compiler/chainrules.jl:189 [inlined]
[3] map(f::typeof(Zygote._project), t::Tuple{ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, OrdinaryDiffEq.InterpolationData{ODEFunction{true,SciMLBase.AutoSpecialize,…}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Tsit5Cache{Vector{Float64},…}}, DiffEqBase.DEStats, Nothing}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, s::Tuple{ODESolution{Float64, 2, Matrix{Float64}, Nothing, Nothing, Vector{Float64}, Nothing, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, SciMLBase.LinearInterpolation{Vector{Float64}, Matrix{Float64}}, Nothing, Nothing}, Nothing})
@ Base ./tuple.jl:247
[4] gradient(::Function, ::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, OrdinaryDiffEq.InterpolationData{ODEFunction{true,SciMLBase.AutoSpecialize,…}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Tsit5Cache{Vector{Float64},…}}, DiffEqBase.DEStats, Nothing}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:98
[5] top-level scope
@ REPL[13]:1
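The two sizes in the error can be reproduced by hand from the example setup: 6 saved time points (saveat = 0.2 on [0, 1]) and 11 grid points (dx = 0.1 on [0, 1]) give the 66 entries of the cotangent, while the ODE solution is 9 states × 6 times. The sketch below uses only Base Julia and illustrates the shape bookkeeping only. It assumes (this is not confirmed in the thread) that the discretized ODE state holds just the 9 interior points x = 0.1, ..., 0.9 because the Dirichlet boundary values are eliminated, and that the observed array is laid out time × space. It is not how SciMLSensitivity actually maps the cotangent.

```julia
# Shapes from the example: 6 saved times (saveat = 0.2 on [0, 1]) and
# 11 grid points (dx = 0.1 on [0, 1]).
nt, nx = 6, 11

# Stand-in for the cotangent of sol[u(t, x)], laid out as time × space.
Δgrid = rand(nt, nx)

# Assumption: the ODE state contains only the interior points, so the
# boundary columns (x = 0 and x = 1) have no matching state entry.
Δinterior = Δgrid[:, 2:nx-1]

# The ODE solution stores states × saved times, hence the transpose.
Δstate = permutedims(Δinterior)

@assert length(Δgrid) == 66      # size of what the pullback produced
@assert size(Δstate) == (9, 6)   # size the ODE solution expects
```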
It is curious that we call SciMLBase.wrap_sol(::PDETimeSeriesSolution, metadata) at all.
Why? It's required for wrapping the PDE solution.
solu_adjoint(Δ::Dict) = solu_adjoint(Δ[first(keys(sol.u))])
What is this one for?
The last piece of the puzzle may be the missing https://github.com/SciML/SciMLSensitivity.jl/issues/760
Why? It's required for wrapping the PDE solution
It is forwarding to the PDETimeSeriesSolution constructor. Is it defined to accept a PDETimeSeriesSolution? I see it defined for AbstractODESolution.
What is this one for?
The pullback you get in this case is a dictionary with the observed variables. It does seem like it's missing an adjoint for those.
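A small Base-only sketch of what the solu_adjoint(Δ::Dict) method does with such a dictionary: it looks up the cotangent stored under the first key of the solution's Dict and ignores any other entries. Symbols stand in for the ModelingToolkit Num keys, and the arrays are hypothetical stand-ins for the observed solution and its cotangent.

```julia
# Stand-ins: Symbols replace the ModelingToolkit Num keys, and the arrays
# replace the observed-variable solution and its cotangent.
solu = Dict(:u => zeros(6, 11))
Δ = Dict(:u => ones(6, 11))

# Same pattern as solu_adjoint(Δ::Dict): fetch the cotangent stored under
# the first key of the solution Dict and drop everything else.
first_cotangent(Δ::Dict, solu::Dict) = Δ[first(keys(solu))]

Δu = first_cotangent(Δ, solu)
@assert size(Δu) == (6, 11)
```

With a single dependent variable this is harmless, but with several only one cotangent would be forwarded, and Dict iteration order is not specified.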
It is forwarding to the PDETimeSeriesSolution constructor. Is it defined to accept a PDETimeSeriesSolution? I see it defined for AbstractODESolution
Oh yes, it's never called on a PDE solution.
I tried to work around this by passing the keyword argument wrap = Val(false) to solve to get the original ODE solution back; compare https://docs.sciml.ai/MethodOfLines/stable/solutions/#Original-solution
Then, just summing the solution and taking the gradient works with ForwardDiff.gradient, but not with Zygote.gradient, which errors with
MethodError: no method matching +(::Num, ::Matrix{Float64})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
I changed the last lines to
function result(p)
    # Solve ODE problem
    sol = solve(prob, Tsit5(), p=p, saveat=0.2; wrap = Val(false))
    sum(sum(sol.u))
end
result([3.]) # First pass
ForwardDiff.gradient(result, [3.,])
Zygote.gradient(result, [3.,])
I don't quite get what is going wrong.
Dear @DhairyaLGandhi, could you share whether you made any modifications to the code in the original post? When I add your lines, I still get
ERROR: MethodError: no method matching SciMLBase.PDETimeSeriesSolution(::SciMLBase.PDETimeSeriesSolution{Float64, 1, Dict{Num, Matrix{Float64}}, MethodOfLines.MOLMetadata{Val{true}()
for Zygote.jacobian(result, [3.,])
Here is the code I used:
# Test file: check whether Zygote/ForwardDiff and MethodOfLines work together
using ModelingToolkit, MethodOfLines, OrdinaryDiffEq, DomainSets, Zygote
using ChainRulesCore
using SciMLSensitivity
using ForwardDiff
# Parameters, variables, and derivatives
@parameters t x α
@variables u(..)
Dt = Differential(t)
Dxx = Differential(x)^2
# 1D PDE and boundary conditions
eq = Dt(u(t, x)) ~ α * Dxx(u(t, x)) # Heat conduction (diffusion) in a pipe
bcs = [u(0, x) ~ 0.,
       u(t, 0) ~ 30.,
       u(t, 1) ~ 0.]
# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
           x ∈ Interval(0.0, 1.0)]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], [α => 1.])
# Method of lines discretization
dx = 0.1
dt = 0.1
order = 2
discretization = MOLFiniteDifference([x => dx], t)
# Convert the PDE problem into an ODE problem
prob = discretize(pdesys,discretization)
sol = solve(prob, Tsit5(), p=[1.], dt=dt, saveat=0.2)
function result(p)
    # Solve ODE problem
    sol = solve(prob, Tsit5(), p=p, saveat=0.2)
    sol.u[u(t,x)]
end
result([3.]) # First pass
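# rrule for the PDETimeSeriesSolution constructor: pass the cotangent straight
# through to the wrapped ODE solution; the type and metadata get no tangent.
function ChainRulesCore.rrule(::Type{<:SciMLBase.PDETimeSeriesSolution}, sol, meta)
    function PDETimeSeriesSolutionAdjoint(ȳ)
        (ChainRulesCore.NoTangent(), ȳ, ChainRulesCore.NoTangent())
    end
    SciMLBase.PDETimeSeriesSolution(sol, meta), PDETimeSeriesSolutionAdjoint
end

# Custom pullback for sol.u on time-series solutions.
Zygote.@adjoint function Zygote.literal_getproperty(sol::DiffEqBase.AbstractTimeseriesSolution,
                                                     ::Val{:u})
    # PDE solutions give a Dict cotangent; forward only the first variable's entry
    solu_adjoint(Δ::Dict) = solu_adjoint(Δ[first(keys(sol.u))])
    function solu_adjoint(Δ)
        # Replace `nothing` entries with zeros shaped like u0
        zerou = zero(sol.prob.u0)
        _Δ = @. ifelse(Δ == nothing, (zerou,), Δ)
        (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ), nothing)
    end
    sol.u, solu_adjoint
end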
ForwardDiff.jacobian(result, [3.,])
Zygote.jacobian(result, [3.,])
We have other examples of MethodOfLines.jl working fine now, so we can close this. The remaining issue is https://github.com/SciML/SciMLSensitivity.jl/issues/760 but that's its own issue.
Could you link these working examples of MethodOfLines with Zygote? Thanks 🙂
@xtalax ?
Code:
With ForwardDiff:
It works fine (though the array has the wrong shape, 66x1 instead of 6x11), and the result appears to make sense.
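The 66x1 shape is expected from ForwardDiff.jacobian: it flattens a matrix-valued output column-major into one column per input, and result has a single parameter. With 6 saved times and 11 grid points, a plain reshape should recover the 6x11 layout, assuming result returns a time × space matrix. The sketch uses a hypothetical stand-in J (entries 1 to 66) instead of the real Jacobian so it runs with Base Julia alone.

```julia
# Hypothetical stand-in for ForwardDiff.jacobian(result, [3.0]):
# one column (one parameter), 66 rows (6 times × 11 grid points).
nt, nx = 6, 11
J = reshape(collect(1.0:66.0), 66, 1)

# Undo the column-major flattening: rows are saved times, columns are x points.
Jgrid = reshape(J, nt, nx)

@assert size(Jgrid) == (6, 11)
@assert Jgrid[2, 1] == 2.0   # second time point, first grid point
@assert Jgrid[1, 2] == 7.0   # first time point, second grid point
```

With the real code, the same idea would read reshape(ForwardDiff.jacobian(result, [3.0]), 6, 11).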
With Zygote:
It crashes with: