SciML / SymbolicIndexingInterface.jl

A general interface for symbolic indexing of SciML objects used in conjunction with Domain-Specific Languages
https://docs.sciml.ai/SymbolicIndexingInterface/stable/
MIT License
14 stars 6 forks source link

Make `getp` AD friendly #77

Open DhairyaLGandhi opened 3 months ago

DhairyaLGandhi commented 3 months ago

Describe the bug 🐞

Sister issue to #69.

The current plan is a major rewrite of `getp` by @AayushSabharwal to handle hybrid continuous/discrete systems. In the process, we will track the behaviour of `getp` to make sure it is differentiated correctly by AD, since there are known bugs in how it currently works.

This issue documents the current situation and tracks progress on it. We also need to add AD tests for `getp`/`getu` in SciMLBase and SII (SymbolicIndexingInterface) so that this behaviour is checked over time.

Expected behavior

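Gradients through `getp` should be correct, including when it is given a vector of parameters. Each distinct parameter should receive its own cotangent, and a parameter listed more than once should have its contributions summed.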

Minimal Reproducible Example 👇


# Packages needed by this example and by the REPL session that follows
# (`getp` comes from SymbolicIndexingInterface, `gradient` from Zygote).
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Zygote
# ModelingToolkit v9 provides a unit-less independent variable and its derivative.
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters σ ρ β
@variables x(t) y(t) z(t) w(t)

# Lorenz-style system with a second-order equation for `x` and a purely
# algebraic equation for `w`.
eqs = [D(D(x)) ~ σ * (y - x),
    D(y) ~ x * (ρ - z) - y,
    D(z) ~ x * y - β * z,
    w ~ x + y + z + 2 * β]

@mtkbuild sys = ODESystem(eqs, t)

# Print the observed equations; `w` is expected to have been eliminated into
# these by structural simplification (confirm in the printed output).
display(ModelingToolkit.observed(sys))

# `D(x) => 2.0` initializes the extra state that (presumably) MTK introduces when
# lowering the second-order equation for `x` to first order.
u0 = [D(x) => 2.0,
    x => 1.0,
    y => 0.0,
    z => 0.0]

p = [σ => 28.0,
    ρ => 10.0,
    β => 8 / 3]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())

Error & Stacktrace ⚠️

julia> pf = getp(sol, sys.β)
SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))

julia> gradient(sol) do sol # correct
         sum(pf(sol))
       end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

julia> pf2 = getp(sol, [sys.β]) 
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))])

julia> gradient(sol) do sol # still correct, note single element in vector
         sum(pf2(sol))
       end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

julia> pf3 = getp(sol, [sys.β, sys.β])
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2))
), SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))])

julia> gradient(sol) do sol # incorrect, should be [0.0, 2.0, 0.0]
         sum(pf3(sol))
       end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 3.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

julia> pf4 = getp(sol, [sys.β, sys.ρ])
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}
}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2))), SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 1)))])

julia> gradient(sol) do sol # incorrect, should be [1.0, 1.0, 0.0]
         sum(pf4(sol))
       end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([2.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

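The gradients are incorrect when `getp` is given a vector of more than one parameter. The single-parameter cases (`pf` and `pf2`) correctly give `[0.0, 1.0, 0.0]`. However, `[sys.β, sys.β]` gives `[0.0, 3.0, 0.0]` instead of `[0.0, 2.0, 0.0]`, and `[sys.β, sys.ρ]` gives `[2.0, 1.0, 0.0]` instead of `[1.0, 1.0, 0.0]`.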

Environment (please complete the following information):

```julia
(SciMLSensitivity) pkg> st
Project SciMLSensitivity v7.57.0
Status `~/arpa/jsmo/clone/SciMLSensitivity.jl/Project.toml`
⌅ [47edcb42] ADTypes v0.2.7
  [79e6a3ab] Adapt v4.0.4
  [4fba245c] ArrayInterface v7.10.0
  [082447d4] ChainRules v1.66.0 `https://github.com/JuliaDiff/ChainRules.jl.git#main`
  [d360d2e6] ChainRulesCore v1.23.0
  [2b5f629d] DiffEqBase v6.149.1 `https://github.com/DhairyaLGandhi/DiffEqBase.jl.git#dg/kw`
  [459566f4] DiffEqCallbacks v3.6.2
  [77a26b50] DiffEqNoiseProcess v5.21.0
  [31c24e10] Distributions v0.25.108
  [da5c29d0] EllipsisNotation v1.8.0
  [7da242da] Enzyme v0.12.6
  [6a86dc24] FiniteDiff v2.23.1
  [f6369f11] ForwardDiff v0.10.36
  [f62d2435] FunctionProperties v0.1.2
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [d9f16b24] Functors v0.4.10
  [46192b85] GPUArraysCore v0.1.6
  [7ed4a6bd] LinearSolve v2.30.0
  [961ee093] ModelingToolkit v9.13.0
  [1dea7af3] OrdinaryDiffEq v6.76.0
  [d96e819e] Parameters v0.12.3
  [d236fae5] PreallocationTools v0.4.21
  [1fd47b50] QuadGK v2.9.4
  [e6cf234a] RandomNumbers v1.5.3
  [731186ca] RecursiveArrayTools v3.18.1 `../RecursiveArrayTools.jl`
  [189a3867] Reexport v1.2.2
  [37e2e3b7] ReverseDiff v1.15.3
  [0bca4576] SciMLBase v2.36.1 `https://github.com/DhairyaLGandhi/SciMLBase.jl#dg/obsfn`
  [c0aeaf25] SciMLOperators v0.3.8
  [53ae85a6] SciMLStructures v1.2.0
⌃ [47a9eef4] SparseDiffTools v2.18.0
  [90137ffa] StaticArrays v1.9.3
  [1e83bf80] StaticArraysCore v1.4.2
  [789caeaf] StochasticDiffEq v6.65.1
  [2efcf032] SymbolicIndexingInterface v0.3.21
  [9f7883ad] Tracker v0.2.34
  [781d530d] TruncatedStacktraces v1.4.0
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
  [d6f4376e] Markdown
  [9a3f8284] Random
  [10745b16] Statistics v1.10.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`
```
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 64 × AMD EPYC 7513 32-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
  Threads: 1 on 64 virtual cores

Additional context

Add any other context about the problem here.