SciML / MethodOfLines.jl

Automatic Finite Difference PDE solving with Julia SciML
https://docs.sciml.ai/MethodOfLines/stable/
MIT License
157 stars 27 forks source link

Zygote not working for gradient w.r.t. parameters (with remake) #309

Open Qfl3x opened 12 months ago

Qfl3x commented 12 months ago

MWE:

using DifferentialEquations, ModelingToolkit, MethodOfLines, DomainSets

using Zygote
import AbstractDifferentiation as AD
# Method of Manufactured Solutions: exact solution
# NOTE(review): `u_exact` is not referenced again in this snippet. Also, exp(-t)*cos(x)
# satisfies u_t = c*u_xx only for c = 1, while the PDE below uses c = α + β — confirm
# whether it is meant to be the exact solution for the parameter values used here.
u_exact = (x,t) -> exp.(-t) * cos.(x)

# Parameters, variables, and derivatives
# `t` and `x` are the independent variables; `α` and `β` are the PDE coefficients that
# the gradient is later taken with respect to.
@parameters t x
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2

# 1D PDE and boundary conditions
# Diffusion-type equation whose coefficient is the sum α + β.
eq  = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
# Initial condition at t = 0 followed by the two boundary conditions at x = 0 and x = 1.
bcs = [u(0, x) ~ cos(x),
        u(t, 0) ~ exp(-t),
        u(t, 1) ~ exp(-t) * cos(1)]

# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
           x ∈ Interval(0.0, 1.0)]

# Parameters
# Default parameter values attached to the PDE system.
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)

# Method of lines discretization
# Only `x` is discretized (step `dx`); `t` is passed as the remaining continuous variable.
dx = 0.1
# NOTE(review): `order` is never passed to `MOLFiniteDifference`, so it has no effect here.
order = 2
discretization = MOLFiniteDifference([x => dx], t)

# Convert the PDE problem into an ODE problem
prob = discretize(pdesys,discretization)

# Objective for the gradient test: re-parameterise the discretised problem with
# `ps = [α_value, β_value]`, solve it, and reduce the solution to a scalar.
# The scalar is the sum of the last slice along the first axis of `sol[u(t, x)]`
# (presumably the final saved time, since `u` is declared as `u(t, x)` — confirm).
function pde_solution(ps)
    # Symbolic parameter map; kept under its own name instead of rebinding `ps`.
    param_map = [α => ps[1], β => ps[2]]
    new_prob = remake(prob; p=param_map)
    sol = solve(new_prob, Tsit5(); saveat=0.1)
    return sum(sol[u(t, x)][end, :])
end

# Differentiate `pde_solution` with respect to its parameter vector through the Zygote
# backend of AbstractDifferentiation, evaluated at a random 2-element point.
ADZyg = AD.ZygoteBackend()
grad = AD.gradient(ADZyg, pde_solution, rand(2))

I used AbstractDifferentiation above, but calling Zygote directly gives the same error:

ERROR: MethodError: no method matching size(::IRTools.Inner.Undefined)

Closest candidates are:
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted})
   @ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:581
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer)
   @ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:580
  size(::Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRPackedQ})
   @ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:584
  ...

Stacktrace:
  [1] axes(A::IRTools.Inner.Undefined)
    @ Base ./abstractarray.jl:98
  [2] _tryaxes(x::IRTools.Inner.Undefined)
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/lib/array.jl:188
  [3] map
    @ ./tuple.jl:274 [inlined]
  [4] adjoint
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/lib/array.jl:322 [inlined]
  [5] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
  [6] _pullback
    @ ./iterators.jl:370 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::typeof(zip), ::IRTools.Inner.Undefined, ::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [8] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/ModelingToolkit/dkLCE/src/utils.jl:635 [inlined]
  [9] _pullback(::Zygote.Context{false}, ::typeof(ModelingToolkit.mergedefaults), ::Dict{Any, Any}, ::Vector{Float64}, ::IRTools.Inner.Undefined)
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [10] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/ModelingToolkit/dkLCE/src/variables.jl:149 [inlined]
 [11] _pullback(::Zygote.Context{false}, ::typeof(SciMLBase.process_p_u0_symbolic), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, ::Vector{Pair{Num, Float64}}, ::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [12] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/remake.jl:78 [inlined]
 [13] _pullback(::Zygote.Context{false}, ::SciMLBase.var"##remake#617", ::Missing, ::Missing, ::Missing, ::Vector{Pair{Num, Float64}}, ::Missing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [14] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/remake.jl:52 [inlined]
 [15] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:p,), Tuple{Vector{Pair{Num, Float64}}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [16] _pullback
    @ ~/GitProjects/hyco/scripts/mwe.jl:40 [inlined]
 [17] _pullback(ctx::Zygote.Context{false}, f::typeof(pde_solution), args::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [18] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:44
 [19] pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:42 [inlined]
 [20] gradient(f::Function, args::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:96
 [21] top-level scope
    @ REPL[7]:1
Qfl3x commented 11 months ago

So I've modified remake in SciMLBase by essentially removing the symbolic checks here: https://github.com/SciML/SciMLBase.jl/blob/75605b1a8754bc4452761100589a6020b8e9f035/src/remake.jl#L71-L79

Once that is done, I get a new error:

ERROR: No matching function wrapper was found!

that starts at:

  [7] (::ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem})(::Vector{Float64}, ::Vararg{Any})
    @ SciMLBase /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/scimlfunctions.jl:2267
Qfl3x commented 11 months ago

FunctionWrappersWrappers's first non-_call stacktrace:

  [6] (::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false})(::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64}, ::Float64)
    @ FunctionWrappersWrappers /Net/Groups/BGI/people/mchettouh/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10

I didn't put the entire stacktrace as it's very long.

BernhardAhrens commented 11 months ago

This is linked to the following issue: https://github.com/SciML/SciMLSensitivity.jl/issues/794

BernhardAhrens commented 11 months ago

Could a workaround be to avoid the solution wrapping altogether by passing the wrap = Val(false) keyword argument to solve?

https://docs.sciml.ai/MethodOfLines/stable/solutions/#Original-solution

Qfl3x commented 11 months ago

@BernhardAhrens It's currently crashing at remake, it hasn't reached solve yet.

Qfl3x commented 11 months ago

I've tried to pinpoint the error more clearly and got this new MWE:

using DifferentialEquations, ModelingToolkit, MethodOfLines, DomainSets

using PDEBase: add_metadata!
using ModelingToolkit: get_metadata

using Zygote
import AbstractDifferentiation as AD
# Method of Manufactured Solutions: exact solution
# NOTE(review): `u_exact` is not referenced again in this snippet. Also, exp(-t)*cos(x)
# satisfies u_t = c*u_xx only for c = 1, while the PDE below uses c = α + β — confirm
# whether it is meant to be the exact solution for the parameter values used here.
u_exact = (x,t) -> exp.(-t) * cos.(x)

# Parameters, variables, and derivatives
# `x` and `t` are the independent variables; `α` and `β` are the PDE coefficients that
# the gradient is later taken with respect to.
@parameters x t
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2

# 1D PDE and boundary conditions
# Diffusion-type equation whose coefficient is the sum α + β.
eq  = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
# Initial condition at t = 0 followed by the two boundary conditions at x = 0 and x = 1.
bcs = [u(0, x) ~ cos(x),
        u(t, 0) ~ exp(-t),
        u(t, 1) ~ exp(-t) * cos(1)]

# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
           x ∈ Interval(0.0, 1.0)]

# Parameters
# Default parameter values attached to the PDE system.
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)

# Method of lines discretization
# Only `x` is discretized (step `dx`); `t` is passed as the remaining continuous variable.
dx = 0.1
# NOTE(review): `order` is never passed to `MOLFiniteDifference`, so it has no effect here.
order = 2
discretization = MOLFiniteDifference([x => dx], t)

# Convert the PDE problem into an ODE problem
# Done step by step instead of calling `discretize`, so that the specialization level
# (`FullSpecialize`) can be chosen explicitly when the `ODEProblem` is constructed.
sys,tspan = symbolic_discretize(pdesys,discretization)
simpsys = structural_simplify(sys)
# Attach metadata to the simplified system, passing the unsimplified `sys` along.
add_metadata!(get_metadata(simpsys), sys)
# NOTE(review): `SciMLBase` is referenced without a direct `using`/`import` in this
# snippet — presumably brought into scope by one of the loaded packages; confirm.
prob = ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, Pair[], tspan; discretization.kwargs...)

# Simplified stand-in for `remake`: build a fresh in-place, fully specialized
# `ODEProblem` from `simpsys` with the parameter set `p`, reusing the initial state
# (`prob.u0`) and time span (`prob.tspan`) of the existing problem.
#
# Arguments:
#   prob    - existing `ODEProblem`; only its `u0` and `tspan` fields are read.
#   p       - parameters forwarded unchanged to the `ODEProblem` constructor.
#   simpsys - system to build the problem from; defaults to the global `simpsys`.
#
# Returns the newly constructed `ODEProblem`.
# NOTE(review): `prob.problem_type` is not forwarded to the new problem; an earlier
# draft of this call had it commented out — confirm that dropping it is intended.
function remake_p(prob::ODEProblem, p; simpsys=simpsys)
    return ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, prob.u0, prob.tspan, p)
end

# Objective for the gradient test: rebuild the problem via `remake_p` with
# `ps = [α_value, β_value]`, solve it, and reduce the solution to a scalar.
# The scalar is the sum of the last slice along the first axis of `sol[u(t, x)]`
# (presumably the final saved time, since `u` is declared as `u(t, x)` — confirm).
function pde_solution2(ps)
    # Symbolic parameter map; kept under its own name instead of rebinding `ps`.
    param_map = [α => ps[1], β => ps[2]]
    new_prob = remake_p(prob, param_map)
    sol = solve(new_prob, Tsit5(); saveat=0.1)
    final_slice = sol[u(t, x)][end, :]
    return sum(final_slice)
end

# Plain (non-differentiated) evaluation first, with the result suppressed by `;`.
pde_solution2([1.2,.3]);
# Then differentiate the same function through the Zygote backend of
# AbstractDifferentiation at a random 2-element point.
ADzyg = AD.ZygoteBackend()
AD.gradient(ADzyg, pde_solution2, rand(2))

This one avoids some of the Zygote errors (it forces FullSpecialize and uses a simplified remake), but it still crashes with:


ERROR: Compiling Tuple{Type{Dict}, Vector{Pair{Num, Float64}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] macro expansion
    @ home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::Type{Dict}, args::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101
  [3] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/utils.jl:633 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(ModelingToolkit.mergedefaults), ::Dict{Any, Any}, ::Vector{Pair{Num, Float64}}, ::IRTools.Inner.Undefined)
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [5] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:694 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##get_u0_p#578", ::Bool, ::Bool, ::Bool, ::typeof(ModelingToolkit.get_u0_p), ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [7] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:684 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:tofloat, :use_union, :symbolic_u0), Tuple{Bool, Bool, Bool}}, ::typeof(ModelingToolkit.get_u0_p), ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [9] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:724 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##process_DEProblem#579", ::Bool, ::Nothing, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Symbolics.SerialForm, ::Bool, ::Bool, ::Bool, ::Bool, ::Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:t, :has_difference, :check_length), Tuple{Float64, Bool, Bool}}}, ::typeof(ModelingToolkit.process_DEProblem), ::Type{ODEFunction{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [11] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:707 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:t, :has_difference, :check_length), Tuple{Float64, Bool, Bool}}, ::typeof(ModelingToolkit.process_DEProblem), ::Type{ODEFunction{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [13] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:834 [inlined]
 [14] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##_#586", ::Nothing, ::Bool, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::Type{ODEProblem{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Tuple{Float64, Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [15] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:827 [inlined]
 [16] _pullback(::Zygote.Context{false}, ::Type{ODEProblem{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Tuple{Float64, Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
.... 

The crash is happening in the stdlib here: https://github.com/JuliaLang/julia/blob/b74daf501619ac4be061c67d80608c4c8822fc36/base/dict.jl#L114-L126


# Quoted from Julia Base (base/dict.jl): untyped `Dict` constructor from an iterable.
# The `try`/`catch` below is the construct named in the Zygote error message
# ("try/catch is not supported") when compiling `Dict(::Vector{Pair{Num, Float64}})`.
function Dict(kv)
    try
        # Build the dictionary, passing `eltype(kv)` along to `dict_with_eltype`.
        dict_with_eltype((K, V) -> Dict{K, V}, kv, eltype(kv))
    catch
        # On failure, raise a clearer ArgumentError if `kv` is not iterable or contains
        # something other than tuples/pairs; otherwise propagate the original error.
        if !isiterable(typeof(kv)) || !all(x->isa(x,Union{Tuple,Pair}),kv)
            throw(ArgumentError("Dict(kv): kv needs to be an iterator of tuples or pairs"))
        else
            rethrow()
        end
    end
end
Qfl3x commented 11 months ago

(Given up on remake)

After following Chris's advice here, calling solve with parameters, and using wrap=Val(false), I've reached error-parity between Zygote and ReverseDiff here: https://github.com/SciML/RecursiveArrayTools.jl/blob/0965fc1f69424b9623f1150221d64889185189a3/src/vector_of_array.jl#L113

ERROR: ArgumentError: broadcasting over dictionaries and `NamedTuple`s is reserved
Qfl3x commented 11 months ago

I can't figure out a way to add a breakpoint inside the gradient calculation; it would be nice if someone had an idea.

xtalax commented 11 months ago

Gradient errors are really hard to debug, maybe @ChrisRackauckas knows a strategy?

Qfl3x commented 11 months ago

I've abandoned the previous approach (I was extending too many functions in the standard library that weren't supposed to be extended) and am instead looking for a working example in ModelingToolkit to see where it's going wrong. This Lorenz system example works fine (I forced FullSpecialize to make sure both problems are specialized similarly). The types of both ODEProblems are:

MethodOfLines:

ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#549"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x558fcb15, 0xbc5f765b, 0x5dbef508, 0xc5018616, 0x76ff7ee2), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#630#generated_observed#559"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}

ModelingToolkit:

ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#549"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xdf8f3253, 0x53462d65, 0x6e40626a, 0xb391e8a4, 0x9f84fe2a), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xfad0093e, 0xafa14235, 0x0e0713c0, 0x13e5b2a3, 0x1da971c5), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#___jac#555"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xc7653a39, 0xff0cdb6a, 0x7e27b429, 0x0d9163ea, 0x97ee01d3), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xd983ef08, 0x57e02881, 0x3cbc6046, 0x1441c5e4, 0xfda2cffd), Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#630#generated_observed#559"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}

The problem appears to originate from MethodOfLines somehow, then propagates to give a weird stacktrace elsewhere in the ecosystem.

Qfl3x commented 11 months ago

Current closest to working mwe:

using ModelingToolkit, MethodOfLines, DomainSets, OrdinaryDiffEq

using PDEBase: add_metadata!
using ModelingToolkit: get_metadata

using Zygote
using ReverseDiff
import AbstractDifferentiation as AD
using SciMLSensitivity
# Method of Manufactured Solutions: exact solution
# NOTE(review): `u_exact` is not referenced again in this snippet. Also, exp(-t)*cos(x)
# satisfies u_t = c*u_xx only for c = 1, while the PDE below uses c = α + β — confirm
# whether it is meant to be the exact solution for the parameter values used here.
u_exact = (x,t) -> exp.(-t) * cos.(x)

# Parameters, variables, and derivatives
# `x` and `t` are the independent variables; `α` and `β` are the PDE coefficients that
# the gradient is later taken with respect to.
@parameters x t
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2

# 1D PDE and boundary conditions
# Diffusion-type equation whose coefficient is the sum α + β.
eq  = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
# Initial condition at t = 0 followed by the two boundary conditions at x = 0 and x = 1.
bcs = [u(0, x) ~ cos(x),
        u(t, 0) ~ exp(-t),
        u(t, 1) ~ exp(-t) * cos(1)]

# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
           x ∈ Interval(0.0, 1.0)]

# Parameters
# Default parameter values attached to the PDE system.
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)

# Method of lines discretization
# Only `x` is discretized (step `dx`); `t` is passed as the remaining continuous variable.
dx = 0.1
# NOTE(review): `order` is never passed to `MOLFiniteDifference`, so it has no effect here.
order = 2
discretization = MOLFiniteDifference([x => dx], t)

# Convert the PDE problem into an ODE problem
# Done step by step instead of calling `discretize`, so that the specialization level
# (`FullSpecialize`) can be chosen explicitly when the `ODEProblem` is constructed.
sys,tspan = symbolic_discretize(pdesys,discretization)
simpsys = structural_simplify(sys)
# Attach metadata to the simplified system, passing the unsimplified `sys` along.
add_metadata!(get_metadata(simpsys), sys)
# NOTE(review): `SciMLBase` is referenced without a direct `using`/`import` in this
# snippet — presumably brought into scope by one of the loaded packages; confirm.
prob= ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, Pair[], tspan)
# Index vector used by `pde_solution2` to reorder a plain parameter vector.
# NOTE(review): the map assigns 1 and 2 to `param_vars` in their own order and is then
# evaluated against the same `param_vars`, so this looks like an identity permutation —
# presumably the second argument was meant to be the system's parameter ordering; confirm.
param_vars = [α, β]
idxs = ModelingToolkit.varmap_to_vars([param_vars[1] => 1, param_vars[2] => 2], param_vars)

# Quick check of the reordering on a sample vector; the result of the indexing
# expression is not assigned to anything.
test_p = [1.2, 1.4]
test_p[Int.(idxs)]
# Objective for the gradient test: solve `prob` with the plain parameter vector `ps`
# (reordered through the global `idxs`) passed straight to `solve`, with the solution
# wrapping disabled via `wrap=Val(false)`, and reduce the result to a scalar.
# NOTE(review): the scalar sums `sol.u[1]`, i.e. the first element of `sol.u`, not the
# last — confirm this is the intended quantity.
function pde_solution2(ps)
    ordered_p = ps[Int.(idxs)]
    sol = solve(prob, Tsit5(); saveat=0.1, p=ordered_p, wrap=Val(false))
    return sum(sol.u[1])
end

# Plain (non-differentiated) evaluation first, then set up the Zygote backend of
# AbstractDifferentiation that `grad` uses.
pde_solution2([1.2,.3])
ADzyg = AD.ZygoteBackend()
# Gradient of `pde_solution2` at `ps`, computed through the global `ADzyg` backend.
# Returns whatever `AD.gradient` returns for that call.
grad(ps) = AD.gradient(ADzyg, pde_solution2, ps)
# Evaluate the gradient at a random 2-element parameter vector.
grad(rand(2))

It crashes at sum(sol.u[1]), so it actually finishes computing the solution. Note that with ReverseDiff it crashes in the solver. From the many attempts I made, it seems that sol.u is somehow a dictionary of matrices, and calling getindex on it crashes. Interpolation doesn't work either (even without AD).

Infiltrator isn't helping me much as Zygote is trying to compile the macro and crashes.

xtalax commented 11 months ago

The sol interface is different for PDEs. If you don't care about the shape, pass wrap = Val(false) to solve and this will work.

Qfl3x commented 11 months ago

I'm already passing wrap=Val(false); without it, it crashes sooner.

xtalax commented 11 months ago

What is the type of sol?

Qfl3x commented 11 months ago
ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#545"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x8704bc7a, 0xbf573882, 0x55a791fb, 0xb2445ff7, 0x35b3dd5b), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#637#generated_observed#555"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#545"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x8704bc7a, 0xbf573882, 0x55a791fb, 0xb2445ff7, 
0x35b3dd5b), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#637#generated_observed#555"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.Stats, Nothing}