SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/
Other
329 stars 71 forks source link

ForwardDiff works, but Zygote doesn't with my simple PDE. (MethodOfLines) #794

Closed Qfl3x closed 1 year ago

Qfl3x commented 1 year ago

Code:

#This is a test file to see if Zygote/ForwardDiff and MoL like each other
using ModelingToolkit, MethodOfLines, OrdinaryDiffEq, DomainSets, Zygote

using ForwardDiff

    # Parameters, variables, and derivatives
    # `t` and `x` are the independent variables; `α` is the coefficient in the PDE below.
    @parameters t x α
    @variables u(..)
    Dt = Differential(t)
    Dxx = Differential(x)^2

    # 1D PDE and boundary conditions
    eq  = Dt(u(t, x)) ~ α * Dxx(u(t, x)) # 1D heat (diffusion) equation: u_t = α * u_xx
    bcs = [u(0, x) ~ 0.,  # condition at t = 0
    u(t, 0) ~ 30.,  # condition at x = 0
    u(t, 1) ~ 0.]  # condition at x = 1

    # Space and time domains
    domains = [t ∈ Interval(0.0, 1.0),
    x ∈ Interval(0.0, 1.0)]

    # PDE system
    # The last argument supplies the default value `α => 1.`.
    @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], [α => 1.])

    # Method of lines discretization
    dx = 0.1  # passed to MOLFiniteDifference as the `x` entry
    dt = 0.1  # passed as the `dt` keyword of the top-level `solve` call below
    order = 2  # NOTE(review): not referenced anywhere else in the visible script
    discretization = MOLFiniteDifference([x => dx], t)

    # Convert the PDE problem into an ODE problem
    prob = discretize(pdesys,discretization)
    # Solve once with p = [1.], saving every 0.2 in time.
    # NOTE(review): this top-level `sol` does not appear to be used afterwards in the visible script.
    sol = solve(prob, Tsit5(), p=[1.], dt=dt, saveat=0.2)

"""
    result(p)

Solve the discretized problem `prob` with `Tsit5()` using parameter vector `p`
(saving every 0.2 in time) and return the solution's `u` field indexed by
`u(t, x)`.
"""
function result(p)
    # Keyword arguments placed after `;` — same call, same options.
    pde_sol = solve(prob, Tsit5(); p = p, saveat = 0.2)
    return pde_sol.u[u(t, x)]
end

result([3.]) # First pass

With ForwardDiff:

ForwardDiff.jacobian(result, [3.,])

It works fine (though the array has the wrong shape, 66x1 instead of 6x11), and the result appears to make sense.

With Zygote:

Zygote.jacobian(result, [3.,])

crashes with:

ERROR: MethodError: no method matching SciMLBase.PDETimeSeriesSolution(::SciMLBase.PDETimeSeriesSolution{Float64, 1, Dict{Num, Matrix{Float64}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, OrdinaryDiffEq.InterpolationData{ODEFunction{true,SciMLBase.AutoSpecialize,…}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Tsit5Cache{Vector{Float64},…}}, DiffEqBase.DEStats, Nothing}, Nothing, Vector{Float64}, Tuple{Vector{Float64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, Vector{SymbolicUtils.BasicSymbolic{Real}}, Vector{Num}, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, Dict{Num, Interpolations.GriddedInterpolation{Float64, 2, Matrix{Float64}, Interpolations.Gridded{Interpolations.Linear{Interpolations.Throw{Interpolations.OnGrid}}}, Tuple{Vector{Float64}, Vector{Float64}}}}}, ::MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization})

Some of the types have been truncated in the stacktrace for improved reading. To emit complete information
in the stack trace, evaluate `TruncatedStacktraces.VERBOSE[] = true` and re-run the code.

Closest candidates are:
  SciMLBase.PDETimeSeriesSolution(::SciMLBase.AbstractODESolution{T}, ::MethodOfLines.MOLMetadata) where T at ~/.julia/packages/MethodOfLines/44Kiz/src/interface/solution/timedep.jl:2
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0 [inlined]
  ... 
DhairyaLGandhi commented 1 year ago

After defining

# Reverse rule for building a `PDETimeSeriesSolution` from `(sol, meta)`.
# The pullback hands the incoming cotangent `ȳ` straight to `sol`; the type
# argument and `meta` both receive `NoTangent()`.
function ChainRulesCore.rrule(::Type{<:SciMLBase.PDETimeSeriesSolution}, sol, meta)
    wrapped = SciMLBase.PDETimeSeriesSolution(sol, meta)
    pde_solution_pullback(ȳ) =
        (ChainRulesCore.NoTangent(), ȳ, ChainRulesCore.NoTangent())
    return wrapped, pde_solution_pullback
end

# Zygote adjoint for reading the `u` field of a time-series solution.
# Returns `sol.u` together with a pullback that wraps the incoming cotangent
# `Δ` into a solution object created by `DiffEqBase.build_solution`; the
# `Val{:u}` argument receives `nothing`.
Zygote.@adjoint function Zygote.literal_getproperty(sol::DiffEqBase.AbstractTimeseriesSolution,
                                                              ::Val{:u})
    # A `Dict` cotangent is reduced to the entry stored under the first key of
    # `sol.u`; entries under any other key are not propagated.
    # NOTE(review): presumably only adequate for a single dependent variable — confirm.
    solu_adjoint(Δ::Dict) = solu_adjoint(Δ[first(keys(sol.u))])
    function solu_adjoint(Δ)
        zerou = zero(sol.prob.u0)
        # Substitute a zero state for every `nothing` entry of the cotangent.
        # `isnothing` is the identity-based check (instead of `== nothing`);
        # wrapping `zerou` in a 1-tuple keeps it from being broadcast over.
        _Δ = @. ifelse(isnothing(Δ), (zerou,), Δ)
        return (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ), nothing)
    end
    sol.u, solu_adjoint
end

gives

ERROR: DimensionMismatch("variable with size(x) == (9, 6) cannot have a gradient with size(dx) == (66,)")
Stacktrace:
 [1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})(dx::ODESolution{Float64, 2, Matrix{Float64}, Nothing, Nothing, Vector{Float64}, Nothing, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, SciMLBase.LinearInterpolation{Vector{Float64}, Matrix{Float64}}, Nothing, Nothing})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:227
 [2] _project
   @ ~/.julia/packages/Zygote/g2w9o/src/compiler/chainrules.jl:189 [inlined]
 [3] map(f::typeof(Zygote._project), t::Tuple{ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, OrdinaryDiffEq.InterpolationData{ODEFunction{true,SciMLBase.AutoSpecialize,…}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Tsit5Cache{Vector{Float64},…}}, DiffEqBase.DEStats, Nothing}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, s::Tuple{ODESolution{Float64, 2, Matrix{Float64}, Nothing, Nothing, Vector{Float64}, Nothing, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, SciMLBase.LinearInterpolation{Vector{Float64}, Matrix{Float64}}, Nothing, Nothing}, Nothing})
   @ Base ./tuple.jl:247
 [4] gradient(::Function, ::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{true,Vector{Float64},Tuple{Float64, Float64},…}, Tsit5{Static.False,…}, OrdinaryDiffEq.InterpolationData{ODEFunction{true,SciMLBase.AutoSpecialize,…}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Tsit5Cache{Vector{Float64},…}}, DiffEqBase.DEStats, Nothing}, ::Vararg{Any})
   @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:98
 [5] top-level scope
   @ REPL[13]:1

It is curious that we call

SciMLBase.wrap_sol(::PDETimeSeriesSolution, metadata)

at all.

ChrisRackauckas commented 1 year ago

It is curious that we call

Why? It's required for wrapping the PDE solution.

solu_adjoint(Δ::Dict) = solu_adjoint(Δ[first(keys(sol.u))])

What is this one for?

The last piece of the puzzle may be the missing https://github.com/SciML/SciMLSensitivity.jl/issues/760

DhairyaLGandhi commented 1 year ago

Why? It's required for wrapping the PDE solution

It is forwarding to the PDETimeSeriesSolution constructor. Is it defined to accept a PDETimeSeriesSolution? I see it defined for AbstractODESolution

What is this one for?

The pullback you get in this case is a dictionary with the observed variables. It does seem like it's missing an adjoint for those.

ChrisRackauckas commented 1 year ago

It is forwarding to the PDETimeSeriesSolution constructor. Is it defined to accept a PDETimeSeriesSolution? I see it defined for AbstractODESolution

Oh yes, it's never called on a PDE solution.

BernhardAhrens commented 1 year ago

I tried to find a workaround for this by using the keyword argument wrap = Val(false) in solve to get the original ODE solution back; compare https://docs.sciml.ai/MethodOfLines/stable/solutions/#Original-solution

Then, just summing the solution and taking the gradient works with ForwardDiff.gradient, but not with Zygote.gradient which errors with

MethodError: no method matching +(::Num, ::Matrix{Float64})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array

I changed the last lines to

"""
    result(p)

Solve `prob` with `Tsit5()` for parameter vector `p`, passing
`wrap = Val(false)`, and apply `sum` twice to the solution's `u` field.
"""
function result(p)
    # All keyword arguments gathered after `;` — same call, same options.
    raw_sol = solve(prob, Tsit5(); p = p, saveat = 0.2, wrap = Val(false))
    # The nested form is kept as-is: inner `sum` combines the entries of `u`,
    # outer `sum` totals that result.
    return sum(sum(raw_sol.u))
end

result([3.]) # First pass

ForwardDiff.gradient(result, [3.,])
Zygote.gradient(result, [3.,])

I don't quite get what is going wrong.

BernhardAhrens commented 1 year ago

Dear @DhairyaLGandhi, could you share whether you made any modifications to the code in the original post? When I add your lines, I still get

ERROR: MethodError: no method matching SciMLBase.PDETimeSeriesSolution(::SciMLBase.PDETimeSeriesSolution{Float64, 1, Dict{Num, Matrix{Float64}}, MethodOfLines.MOLMetadata{Val{true}()

for Zygote.jacobian(result, [3.,])

Here is the code I used:

#This is a test file to see if Zygote/ForwardDiff and MoL like each other
using ModelingToolkit, MethodOfLines, OrdinaryDiffEq, DomainSets, Zygote
using ChainRulesCore
using SciMLSensitivity
using ForwardDiff

# Parameters, variables, and derivatives
# `t` and `x` are the independent variables; `α` is the coefficient in the PDE below.
@parameters t x α
@variables u(..)
Dt = Differential(t)
Dxx = Differential(x)^2
# 1D PDE and boundary conditions
eq  = Dt(u(t, x)) ~ α * Dxx(u(t, x)) # 1D heat (diffusion) equation: u_t = α * u_xx
bcs = [u(0, x) ~ 0.,  # condition at t = 0
u(t, 0) ~ 30.,  # condition at x = 0
u(t, 1) ~ 0.]  # condition at x = 1

# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
x ∈ Interval(0.0, 1.0)]

# PDE system
# The last argument supplies the default value `α => 1.`.
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], [α => 1.])

# Method of lines discretization
dx = 0.1  # passed to MOLFiniteDifference as the `x` entry
dt = 0.1  # passed as the `dt` keyword of the top-level `solve` call below
order = 2  # NOTE(review): not referenced anywhere else in the visible script
discretization = MOLFiniteDifference([x => dx], t)

# Convert the PDE problem into an ODE problem
prob = discretize(pdesys,discretization)
# Solve once with p = [1.], saving every 0.2 in time.
# NOTE(review): this top-level `sol` does not appear to be used afterwards in the visible script.
sol = solve(prob, Tsit5(), p=[1.], dt=dt, saveat=0.2)

"""
    result(p)

Solve the discretized problem `prob` with `Tsit5()` using parameter vector `p`
(saving every 0.2 in time) and return the solution's `u` field indexed by
`u(t, x)`.
"""
function result(p)
    # Keyword arguments placed after `;` — same call, same options.
    pde_sol = solve(prob, Tsit5(); p = p, saveat = 0.2)
    return pde_sol.u[u(t, x)]
end

result([3.]) # First pass

# Reverse rule for building a `PDETimeSeriesSolution` from `(sol, meta)`.
# The pullback hands the incoming cotangent `ȳ` straight to `sol`; the type
# argument and `meta` both receive `NoTangent()`.
function ChainRulesCore.rrule(::Type{<:SciMLBase.PDETimeSeriesSolution}, sol, meta)
    wrapped = SciMLBase.PDETimeSeriesSolution(sol, meta)
    pde_solution_pullback(ȳ) =
        (ChainRulesCore.NoTangent(), ȳ, ChainRulesCore.NoTangent())
    return wrapped, pde_solution_pullback
end

# Zygote adjoint for reading the `u` field of a time-series solution.
# Returns `sol.u` together with a pullback that wraps the incoming cotangent
# `Δ` into a solution object created by `DiffEqBase.build_solution`; the
# `Val{:u}` argument receives `nothing`.
Zygote.@adjoint function Zygote.literal_getproperty(sol::DiffEqBase.AbstractTimeseriesSolution,
                                                              ::Val{:u})
    # A `Dict` cotangent is reduced to the entry stored under the first key of
    # `sol.u`; entries under any other key are not propagated.
    # NOTE(review): presumably only adequate for a single dependent variable — confirm.
    solu_adjoint(Δ::Dict) = solu_adjoint(Δ[first(keys(sol.u))])
    function solu_adjoint(Δ)
        zerou = zero(sol.prob.u0)
        # Substitute a zero state for every `nothing` entry of the cotangent.
        # `isnothing` is the identity-based check (instead of `== nothing`);
        # wrapping `zerou` in a 1-tuple keeps it from being broadcast over.
        _Δ = @. ifelse(isnothing(Δ), (zerou,), Δ)
        return (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ), nothing)
    end
    sol.u, solu_adjoint
end

ForwardDiff.jacobian(result, [3.,])

Zygote.jacobian(result, [3.,])
ChrisRackauckas commented 1 year ago

We have other examples of MethodOfLines.jl working fine now, so we can close this. The remaining issue is https://github.com/SciML/SciMLSensitivity.jl/issues/760 but that's its own issue.

BernhardAhrens commented 1 year ago

Could you link these working examples of MethodOfLines with Zygote? Thanks 🙂

ChrisRackauckas commented 1 year ago

@xtalax ?