SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

How to treat parameters in a model data structure? #520

Closed: MartinOtter closed this issue 3 years ago

MartinOtter commented 3 years ago

DifferentialEquations.jl and DiffEqFlux.jl have the function interface:

function f!(du, u, p, t) ... end

When using DifferentialEquations.jl, p can be anything, especially any type of data structure, e.g. a hierarchical struct.

It seems that DiffEqFlux.jl assumes that p is an array in which all elements are to be optimized.

In realistic applications, there is often other data that needs to be passed to the function in order to compute the derivative du, and auxiliary variables might also need to be computed and stored. The question is how to handle this with DiffEqFlux.jl.
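One pattern that is independent of DiffEqFlux is to close over the fixed data, so that the optimizer only ever sees the parameter vector p. A minimal sketch in plain Julia (here p2 is an assumed fixed scaling of the first equation, purely for illustration; the original model does not use it this way):

```julia
# Fixed data is captured in a closure; only the vector p is exposed
# to the optimizer as tunable parameters.
function make_lotka_volterra(p2::Float64)
    function lotka_volterra!(du, u, p, t)
        x, y = u
        α, β, δ, γ = p
        du[1] = p2 * (α*x - β*x*y)   # p2 is used, but never optimized
        du[2] = -δ*y + γ*x*y
        return nothing
    end
end

f! = make_lotka_volterra(2.0)
du = zeros(2)
f!(du, [1.0, 1.0], [1.5, 1.0, 3.0, 1.0], 0.0)
# du == [1.0, -2.0]
```

This handles constant data cleanly, but it does not by itself solve the problem of mutable auxiliary storage whose element type must match the AD number type, which is the harder part discussed below.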

The "copy paste" code in https://diffeqflux.sciml.ai/stable/examples/optimization_ode/ is slightly modified to demonstrate the issue:

# original code
function lotka_volterra!(du, u, p, t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, tspan, p)
...

Assume that there is a data structure that holds all data and variable values needed to evaluate the model, such as:

# code with a more general model data structure
mutable struct MyModel{FloatType}
    p::Vector{FloatType}  # to be optimized
    p2::Float64   # other data, not to be optimized
    aux::Vector{FloatType}  # other variables to be computed
end

# that is used in the model function
function lotka_volterra_with_struct!(du, u, m, t)
  p = m.p   # and m.p2, m.aux are also used (but ignored here for simplicity)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
myModel = MyModel([1.5, 1.0, 3.0, 1.0], 2.0, zeros(2))
prob = ODEProblem(lotka_volterra_with_struct!, u0, tspan, myModel)

If the remaining code in https://diffeqflux.sciml.ai/stable/examples/optimization_ode/ is left unchanged, there will be a type error when executing the code, because the AD tool passes tracked numbers that do not fit the Float64 fields of the struct. What to do?

One workaround (though not nice at all) is:

OptFloatType = ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}
p = [1.5, 1.0, 3.0, 1.0]
const myModel1 = MyModel(p, 2.0, zeros(2))
const myModel2 = MyModel{OptFloatType}(p, 2.0, zeros(2))

function lotka_volterra!(du, u, p::Vector{Float64}, t)
    myModel1.p .= p
    lotka_volterra_with_struct!(du, u, myModel1, t) 
end

function lotka_volterra!(du, u, p, t)
    myModel2.p .= p
    lotka_volterra_with_struct!(du, u, myModel2, t) 
end

prob = ODEProblem(lotka_volterra!, u0, tspan, p)

That is, using two instances of the struct with different FloatTypes and providing two different functions that copy the vector to be optimized to myModel, depending on the data type of p.

This optimization setup works and produces the same results as in the DiffEqFlux documentation.
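A variation on this workaround that avoids hard-coding the ReverseDiff tracked type is to keep the struct parametric and rebuild it inside the wrapper from eltype(p), so whatever number type the AD tool passes in propagates automatically. A minimal sketch using only base Julia (BigFloat stands in below for an AD tracked type; whether a given AD backend tolerates the per-call allocation is not checked here):

```julia
# Parametric model struct: FloatType follows whatever number type the
# optimizer/AD tool uses for the parameter vector.
mutable struct MyModel{FloatType}
    p::Vector{FloatType}      # parameters to be optimized
    p2::Float64               # other data, not optimized
    aux::Vector{FloatType}    # auxiliary variables, same eltype as p
end

function lotka_volterra_with_struct!(du, u, m::MyModel, t)
    x, y = u
    α, β, δ, γ = m.p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
    return nothing
end

# The wrapper rebuilds the struct on each call so its element type matches
# p's, instead of maintaining one hand-written instance per number type.
function lotka_volterra!(du, u, p, t)
    m = MyModel(p, 2.0, zeros(eltype(p), 2))
    lotka_volterra_with_struct!(du, u, m, t)
end

# The same function works unchanged for Float64 ...
du = zeros(2)
lotka_volterra!(du, [1.0, 1.0], [1.5, 1.0, 3.0, 1.0], 0.0)
# du == [0.5, -2.0]

# ... and for any other number type, without a second method:
dub = zeros(BigFloat, 2)
lotka_volterra!(dub, BigFloat[1.0, 1.0], BigFloat[1.5, 1.0, 3.0, 1.0], 0.0)
```

Note that this allocates a fresh struct (and aux vector) per call; for performance one could cache one instance per element type, but the type-generic reconstruction is what lets AD number types flow through without the explicit TrackedReal alias.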

Is there a better solution available with DiffEqFlux?

ChrisRackauckas commented 3 years ago

This is a duplicate of https://github.com/SciML/DiffEqFlux.jl/issues/178, which is mostly an upstream issue.