SciML / SymbolicIndexingInterface.jl

A general interface for symbolic indexing of SciML objects used in conjunction with Domain-Specific Languages
https://docs.sciml.ai/SymbolicIndexingInterface/stable/
MIT License

How to use setp with ForwardDiff in cost function? #59

Closed: bgctw closed this issue 5 months ago.

bgctw commented 5 months ago


When I update parameters inside a cost function with a setter obtained from setp, I get errors when applying ForwardDiff.gradient to that cost function.

What am I doing wrong in the following MWE?

using Pkg
Pkg.activate(;temp=true)
Pkg.add(["ModelingToolkit","SymbolicIndexingInterface","ComponentArrays","ForwardDiff"])
#Pkg.develop(url="https://github.com/SciML/SymbolicIndexingInterface.jl")
#Pkg.develop("ModelingToolkit")

using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using SymbolicIndexingInterface: setp, SymbolicIndexingInterface as SII
using ComponentArrays: ComponentArrays as CA
using ForwardDiff: ForwardDiff

u1 = CA.ComponentVector(L = 10.0)
p1 = CA.ComponentVector(k_L = 1.0, k_R = 1 / 20, m = 2.0)
popt = CA.ComponentVector(par=(m = 4.0,),)
function get_sys1()
    sts = @variables L(t)
    ps = @parameters k_L, k_R, m 
    eq = [D(L) ~ 0, ]
    ODESystem(eq, t, sts, vcat(ps...); name=:sys1)
end
sys1 = structural_simplify(get_sys1())

_pmap = Dict( keys(p1) .=> collect(values(p1)) )
prob = ODEProblem(sys1, collect(u1), (0.0,1.1), _pmap)

setter! = SII.setp(sys1, [sys1.m])
setter!(prob, [popt.par.m])

tmp = (popt) -> begin
    setter!(prob, popt)
    d = prob.ps[sys1.m]
    d*d
end
tmp([popt.par.m])
res = ForwardDiff.gradient(tmp, [popt.par.m])

This results in the following error:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4", Float64}, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  Float64(::IrrationalConstants.Invsqrt2)
   @ IrrationalConstants ~/scratch/twutz/julia_cluster_depots/packages/IrrationalConstants/vp5v4/src/macro.jl:112
  ...

Stacktrace:
  [1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4", Float64}, Float64, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4", Float64}, Float64, 1}, i1::Int64)
    @ Base ./array.jl:1021
  [3] set_parameter!(p::ModelingToolkit.MTKParameters{…}, val::ForwardDiff.Dual{…}, idx::ModelingToolkit.ParameterIndex{…})
    @ ModelingToolkit ~/julia/dev/ModelingToolkit/src/systems/parameter_buffer.jl:247
  [4] set_parameter!(sys::ODEProblem{…}, val::ForwardDiff.Dual{…}, idx::ModelingToolkit.ParameterIndex{…})
    @ SymbolicIndexingInterface ~/julia/dev/SymbolicIndexingInterface/src/parameter_indexing.jl:84
  [5] (::SymbolicIndexingInterface.var"#setter!#49"{…})(sol::ODEProblem{…}, val::ForwardDiff.Dual{…})
    @ SymbolicIndexingInterface ~/julia/dev/SymbolicIndexingInterface/src/parameter_indexing.jl:251
  [6] (::SymbolicIndexingInterface.var"#53#55"{…})(s!::SymbolicIndexingInterface.var"#setter!#49"{…}, v::ForwardDiff.Dual{…})
    @ SymbolicIndexingInterface ~/julia/dev/SymbolicIndexingInterface/src/parameter_indexing.jl:263
  [7] (::Base.var"#4#5"{SymbolicIndexingInterface.var"#53#55"{…}})(a::Tuple{SymbolicIndexingInterface.var"#setter!#49"{…}, ForwardDiff.Dual{…}})
    @ Base ./generator.jl:36
  [8] iterate
    @ Base ./generator.jl:47 [inlined]
  [9] collect
    @ Base ./array.jl:834 [inlined]
 [10] map
    @ Base ./abstractarray.jl:3406 [inlined]
 [11] (::SymbolicIndexingInterface.var"#setter!#54"{…})(sol::ODEProblem{…}, val::Vector{…})
    @ SymbolicIndexingInterface ~/julia/dev/SymbolicIndexingInterface/src/parameter_indexing.jl:263
 [12] (::var"#3#4")(popt::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4", Float64}, Float64, 1}})
    @ Main ./REPL[30]:2
 [13] vector_mode_dual_eval!
    @ ~/scratch/twutz/julia_cluster_depots/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [14] vector_mode_gradient(f::var"#3#4", x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#3#4", Float64}, Float64, 1, Vector{ForwardDiff.Dual{…}}})
    @ ForwardDiff ~/scratch/twutz/julia_cluster_depots/packages/ForwardDiff/PcZ48/src/gradient.jl:89
 [15] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#3#4", Float64}, Float64, 1, Vector{ForwardDiff.Dual{…}}}, ::Val{true})
    @ ForwardDiff ~/scratch/twutz/julia_cluster_depots/packages/ForwardDiff/PcZ48/src/gradient.jl:19
 [16] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#3#4", Float64}, Float64, 1, Vector{ForwardDiff.Dual{…}}})
    @ ForwardDiff ~/scratch/twutz/julia_cluster_depots/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [17] top-level scope
    @ REPL[32]:1

Using Julia 1.10 and the following package versions:

  [b0b7db55] ComponentArrays v0.15.11
  [f6369f11] ForwardDiff v0.10.36
  [961ee093] ModelingToolkit v9.6.1 `~/julia/dev/ModelingToolkit`
  [2efcf032] SymbolicIndexingInterface v0.3.11 `~/julia/dev/SymbolicIndexingInterface`

AayushSabharwal commented 5 months ago

The problem here is that the parameter you're updating is stored in a Vector{Float64}, but ForwardDiff calls your cost function with ForwardDiff.Dual numbers, so the setter tries to store a Dual in that vector. A Dual cannot be converted to Float64, which understandably throws that error. @ChrisRackauckas how do we want to support this workflow? Should there be a function that converts all the stored parameters to duals?
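To make the failure mode concrete, here is a minimal, self-contained sketch. It assumes a hypothetical hand-rolled MyDual type as a stand-in for ForwardDiff.Dual (it is not ForwardDiff's implementation) and a plain Vector{Float64} as a stand-in for the parameter buffer inside MTKParameters. Writing the dual into the Float64 buffer raises the same kind of MethodError, while building a new buffer whose element type can hold the dual works.

```julia
# Hypothetical stand-in for ForwardDiff.Dual: a value plus one derivative component
struct MyDual <: Real
    val::Float64
    der::Float64
end
# Product rule, enough to differentiate d*d
Base.:*(a::MyDual, b::MyDual) = MyDual(a.val * b.val, a.der * b.val + a.val * b.der)

buf = [1.0, 1 / 20, 2.0]   # k_L, k_R, m stored as Float64 (stand-in for the MTK buffer)
m = MyDual(4.0, 1.0)       # m = 4 seeded with dm/dm = 1

# In-place update: setindex! calls convert(Float64, m), which has no method
err = try
    buf[3] = m
    nothing
catch e
    e
end

# Non-mutating alternative: a new buffer whose eltype can hold the dual
newbuf = [i == 3 ? m : MyDual(x, 0.0) for (i, x) in enumerate(buf)]
cost = newbuf[3] * newbuf[3]   # m^2, so cost.der is d(m^2)/dm at m = 4
```

The failed assignment leaves buf untouched, and cost.der comes out as 8.0, the derivative of m^2 at m = 4.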

AayushSabharwal commented 5 months ago

Also, this is really an MTK issue, so maybe it should be transferred there? I can't do that since I don't have permissions on the MTK repository.

ChrisRackauckas commented 5 months ago

This should be using remake. Mutation will never be the right answer if you're changing types.

AayushSabharwal commented 5 months ago

This won't work with remake either, since that just copies the MTKParameters object and calls setp.

hersle commented 5 months ago

I just ran into this issue. Here is another (simpler) example using remake, where I try to take the derivative of a DE solution that depends on 1 initial value and 1 parameter. I (naively?) expect this code to work:

using ModelingToolkit
using DifferentialEquations
using ForwardDiff

# an ODEProblem with 1 initial value and 1 parameter
@parameters k
@variables x y(x)
Dx = Differential(x)
@mtkbuild sys = ODESystem([Dx(y) ~ k], x)
prob0 = ODEProblem(sys, [y => NaN], (0, 1), [k => NaN]) # to-be-remade uninitialized problem

# a solution that depends on 1 initial value and 1 parameter
function f(y0val, kval)
    prob = remake(prob0; u0=[y => y0val], p=[k => kval])
    return solve(prob)[y][end] # == y(1)
end

f(1,2) # evaluate (works)
ForwardDiff.derivative(z -> f(z,2), 1) # derivative with respect to initial value (fails)
ForwardDiff.derivative(z -> f(1,z), 2) # derivative with respect to parameter (fails)

The last two lines fail with the same error as reported here. The derivative with respect to the initial value can be fixed by replacing u0=[y => y0val] with u0=ModelingToolkit.varmap_to_vars([y => y0val]. But I can't find a way to get the derivative with respect to the parameter to work.

Intuitively, I would expect these two cases to work symmetrically, and for the above code to work without modification.

These are just my thoughts as an "end-user" 😃

ChrisRackauckas commented 5 months ago

I would have expected this code to work too. Seems there is a promotion that is missing.
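What such a promotion amounts to can be sketched in plain Julia. This is an illustration under assumptions, not the MTK or SymbolicIndexingInterface fix: MyDual is again a hypothetical stand-in for ForwardDiff.Dual, and with_param is a hypothetical helper (not a library API) that copies a parameter vector into a new vector whose element type is promote_type of the old element type and the new value, then sets the entry. It follows the non-mutating, type-changing pattern suggested above in place of in-place mutation.

```julia
# Hypothetical stand-in for ForwardDiff.Dual
struct MyDual <: Real
    val::Float64
    der::Float64
end
Base.:*(a::MyDual, b::MyDual) = MyDual(a.val * b.val, a.der * b.val + a.val * b.der)
# Mixing Float64 and MyDual promotes to MyDual
Base.promote_rule(::Type{MyDual}, ::Type{Float64}) = MyDual
# A Float64 stored into a MyDual buffer gets a zero derivative
Base.convert(::Type{MyDual}, x::Float64) = MyDual(x, 0.0)

# Hypothetical helper: copy p, widening its eltype so that val fits, then set p[i] = val
function with_param(p::Vector, i::Integer, val)
    T = promote_type(eltype(p), typeof(val))
    q = Vector{T}(undef, length(p))
    copyto!(q, p)   # converts each existing entry to T
    q[i] = val
    return q
end

p = [1.0, 1 / 20, 2.0]   # k_L, k_R, m as in the first MWE
m = MyDual(4.0, 1.0)     # seed dm/dm = 1
q = with_param(p, 3, m)
cost = q[3] * q[3]       # m^2
```

With a plain Float64 value, promote_type returns Float64 and the helper just returns a same-type copy, so the original buffer is never mutated in either case.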

ChrisRackauckas commented 5 months ago

Duplicate of https://github.com/SciML/ModelingToolkit.jl/issues/2571