SciML / ModelingToolkit.jl

An acausal modeling framework for automatically parallelized scientific machine learning (SciML) in Julia. A computer algebra system for integrated symbolics for physics-informed machine learning and automated transformations of differential equations.
https://mtk.sciml.ai/dev/

DiffEqFlux fails with ReactionSystem, but not with system declared with a normal function #771

Closed: TorkelE closed this issue 3 years ago

TorkelE commented 3 years ago

I was trying to get parameter fitting to work for a Catalyst model. I ran through everything first with the Lotka-Volterra model and it all worked fine, but when I swapped the Lotka-Volterra function for a Catalyst model it stopped working. The error I get is:

TypeError: in Mul, in T, expected T<:Number, got Type{SymbolicUtils.Mul{Real,Int64,Dict{Any,Number}}}

This is a minimal example, using a ReactionSystem created directly:

# Fetch packages
using Catalyst, OrdinaryDiffEq, DiffEqFlux, Flux

# Declare the model
@parameters A B t
@variables X(t) Y(t)
rxs = [Reaction(A, nothing, [X], nothing, [1])
       Reaction(1., [X,Y], [X], [2,1],[3])
       Reaction(B, [X], [Y], [1],[1])
       Reaction(1., [X], nothing, [1],nothing)]
brueeslator_MTK  = ReactionSystem(rxs, t, [X,Y], [A,B])

function brusselator_function(du, u, p, t)
  X, Y = u
  A, B = p
  du[1] = dx = A + 0.5Y*X^2 -B*X -X
  du[2] = dy = B*X - 0.5Y*X^2
end

u0 = [1.0, 1.0]
tspan = (0.0, 10.0)

Then, if I optimize with the plain Julia function:

function loss(p)
    sol = solve(ODEProblem(brusselator_function, u0,tspan,p), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)

it works fine.

But if I try the MTK one, it errors:

function loss(p)
    sol = solve(ODEProblem(brueeslator_MTK, [X => 1.0, Y => 1.0],tspan,[A => 1.0, B => 1.]), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)

which produces:

TypeError: in Mul, in T, expected T<:Number, got Type{SymbolicUtils.Mul{Real,Int64,Dict{Any,Number}}}

Stacktrace:
 [1] Mul at /home/torkelloman/.julia/packages/SymbolicUtils/EDgAP/src/types.jl:620 [inlined]
 [2] _pullback(::Zygote.Context, ::Type{SymbolicUtils.Mul}, ::Type{Real}, ::SymbolicUtils.Mul{Real,Int64,Dict{Any,Number}}, ::Dict{Any,Number}) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [3] adjoint at /home/torkelloman/.julia/packages/Zygote/KpME9/src/lib/lib.jl:188 [inlined]
 [4] adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Type{T} where T, ::Tuple{DataType}, ::Tuple{SymbolicUtils.Mul{Real,Int64,Dict{Any,Number}},Dict{Any,Number}}) at ./none:0
 [5] _pullback at /home/torkelloman/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [6] * at /home/torkelloman/.julia/packages/SymbolicUtils/EDgAP/src/types.jl:672 [inlined]
 [7] _pullback(::Zygote.Context, ::typeof(*), ::Float64, ::SymbolicUtils.Mul{Real,Int64,Dict{Any,Number}}) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [8] / at /home/torkelloman/.julia/packages/SymbolicUtils/EDgAP/src/types.jl:682 [inlined]
 [9] _pullback(::Zygote.Context, ::typeof(/), ::SymbolicUtils.Mul{Real,Int64,Dict{Any,Number}}, ::Int64) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [10] #oderatelaw#382 at /home/torkelloman/.julia/packages/ModelingToolkit/2gKdw/src/systems/reaction/reactionsystem.jl:186 [inlined]
 [11] _pullback(::Zygote.Context, ::ModelingToolkit.var"##oderatelaw#382", ::Bool, ::typeof(oderatelaw), ::Reaction{Any,Int64}) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0 (repeats 2 times)
 [12] #assemble_oderhs#383 at /home/torkelloman/.julia/packages/ModelingToolkit/2gKdw/src/systems/reaction/reactionsystem.jl:196 [inlined]
 [13] _pullback(::Zygote.Context, ::ModelingToolkit.var"##assemble_oderhs#383", ::Bool, ::typeof(ModelingToolkit.assemble_oderhs), ::ReactionSystem) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [14] #assemble_drift#386 at /home/torkelloman/.julia/packages/ModelingToolkit/2gKdw/src/systems/reaction/reactionsystem.jl:213 [inlined]
 [15] _pullback(::Zygote.Context, ::ModelingToolkit.var"##assemble_drift#386", ::Bool, ::Bool, ::typeof(ModelingToolkit.assemble_drift), ::ReactionSystem) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0 (repeats 2 times)
 [16] #convert#406 at /home/torkelloman/.julia/packages/ModelingToolkit/2gKdw/src/systems/reaction/reactionsystem.jl:382 [inlined]
 [17] _pullback(::Zygote.Context, ::ModelingToolkit.var"##convert#406", ::Bool, ::typeof(convert), ::Type{ODESystem}, ::ReactionSystem) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [18] convert at /home/torkelloman/.julia/packages/ModelingToolkit/2gKdw/src/systems/reaction/reactionsystem.jl:382 [inlined]
 [19] _pullback(::Zygote.Context, ::typeof(convert), ::Type{ODESystem}, ::ReactionSystem) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [20] #ODEProblem#410 at /home/torkelloman/.julia/packages/ModelingToolkit/2gKdw/src/systems/reaction/reactionsystem.jl:477 [inlined]
 [21] _pullback(::Zygote.Context, ::ModelingToolkit.var"##ODEProblem#410", ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Type{ODEProblem}, ::ReactionSystem, ::Array{Pair{Num,Float64},1}, ::Tuple{Float64,Float64}, ::Array{Pair{Num,Float64},1}) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [22] adjoint at /home/torkelloman/.julia/packages/Zygote/KpME9/src/lib/lib.jl:188 [inlined]
 [23] adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Function, ::Tuple{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},UnionAll,ReactionSystem,Array{Pair{Num,Float64},1},Tuple{Float64,Float64},Array{Pair{Num,Float64},1}}, ::Tuple{}) at ./none:0
 [24] _pullback at /home/torkelloman/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [25] ODEProblem at /home/torkelloman/.julia/packages/ModelingToolkit/2gKdw/src/systems/reaction/reactionsystem.jl:477 [inlined]
 [26] _pullback(::Zygote.Context, ::Type{ODEProblem}, ::ReactionSystem, ::Array{Pair{Num,Float64},1}, ::Tuple{Float64,Float64}, ::Array{Pair{Num,Float64},1}) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [27] loss at ./In[14]:2 [inlined]
 [28] _pullback(::Zygote.Context, ::typeof(loss), ::Array{Float64,1}) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [29] #69 at /home/torkelloman/.julia/packages/DiffEqFlux/Bj6Is/src/train.jl:2 [inlined]
 [30] _pullback(::Zygote.Context, ::DiffEqFlux.var"#69#70"{typeof(loss)}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [31] adjoint at /home/torkelloman/.julia/packages/Zygote/KpME9/src/lib/lib.jl:188 [inlined]
 [32] _pullback at /home/torkelloman/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [33] OptimizationFunction at /home/torkelloman/.julia/packages/SciMLBase/cjif9/src/problems/basic_problems.jl:107 [inlined]
 [34] _pullback(::Zygote.Context, ::OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [35] adjoint at /home/torkelloman/.julia/packages/Zygote/KpME9/src/lib/lib.jl:188 [inlined]
 [36] adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Tuple{Array{Float64,1},SciMLBase.NullParameters}) at ./none:0
 [37] _pullback at /home/torkelloman/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [38] OptimizationFunction at /home/torkelloman/.julia/packages/SciMLBase/cjif9/src/problems/basic_problems.jl:107 [inlined]
 [39] _pullback(::Zygote.Context, ::OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#147#157"{GalacticOptim.var"#146#156"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#150#160"{GalacticOptim.var"#146#156"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#155#165",Nothing,Nothing,Nothing}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [40] adjoint at /home/torkelloman/.julia/packages/Zygote/KpME9/src/lib/lib.jl:188 [inlined]
 [41] _pullback at /home/torkelloman/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [42] #9 at /home/torkelloman/.julia/packages/GalacticOptim/OrrCM/src/solve.jl:133 [inlined]
 [43] _pullback(::Zygote.Context, ::GalacticOptim.var"#9#14"{OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#147#157"{GalacticOptim.var"#146#156"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#150#160"{GalacticOptim.var"#146#156"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#155#165",Nothing,Nothing,Nothing},Array{Float64,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}},Array{Float64,1},GalacticOptim.NullData}) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [44] pullback(::Function, ::Zygote.Params) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:167
 [45] gradient(::Function, ::Zygote.Params) at /home/torkelloman/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:48
 [46] __solve(::OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#147#157"{GalacticOptim.var"#146#156"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#150#160"{GalacticOptim.var"#146#156"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#155#165",Nothing,Nothing,Nothing},Array{Float64,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/torkelloman/.julia/packages/GalacticOptim/OrrCM/src/solve.jl:132
 [47] #solve#1 at /home/torkelloman/.julia/packages/GalacticOptim/OrrCM/src/solve.jl:46 [inlined]
 [48] sciml_train(::typeof(loss), ::Array{Float64,1}, ::ADAM, ::GalacticOptim.AutoZygote; kwargs::Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}) at /home/torkelloman/.julia/packages/DiffEqFlux/Bj6Is/src/train.jl:5
 [49] top-level scope at In[14]:6

I get a similar error when I use a Catalyst model:

brusselator_catalyst = @reaction_network begin
    A, ∅ → X
    1, 2X + Y → 3X
    B, X → Y
    1, X → ∅
end A B

no matter how I input u0/parameters:

function loss(p)
    sol = solve(ODEProblem(brusselator_catalyst, u0,tspan,p), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)

I get the same error if I try to convert to an ODESystem:

function loss(p)
    sol = solve(ODEProblem(convert(ODESystem,brueeslator_MTK), [X => 1.0, Y => 1.0],tspan,[A => 1.0, B => 1.]), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)

or to an ODEFunction:

function loss(p)
    sol = solve(ODEProblem(ODEFunction(convert(ODESystem,brueeslator_MTK)), [X => 1.0, Y => 1.0],tspan,[A => 1.0, B => 1.]), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
isaacsas commented 3 years ago

What happens if you create the ODEProblem using Catalyst outside the loss function (maybe pass it as a parameter with a closure)?
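A minimal sketch of that suggestion, assuming OrdinaryDiffEq is installed: the problem is built once, outside the loss, and the loss is produced by a closure that only calls `remake` and `solve`, so AD never sees the problem construction. To stay self-contained it uses the plain `brusselator_function` from the report; the helper `make_loss` is hypothetical, and the same pattern would apply to an `ODEProblem` built from the Catalyst model. It sums over `Array(sol)` rather than `sol` directly, as a version-agnostic way to get the raw values.

```julia
using OrdinaryDiffEq

# Brusselator right-hand side, as in the report above.
function brusselator_function(du, u, p, t)
    X, Y = u
    A, B = p
    du[1] = A + 0.5Y * X^2 - B * X - X
    du[2] = B * X - 0.5Y * X^2
    return nothing
end

# Build the problem once, outside anything that will be differentiated.
base_prob = ODEProblem(brusselator_function, [1.0, 1.0], (0.0, 10.0), [1.0, 1.0])

# Hypothetical helper: returns a loss that closes over `prob`.
# Inside the loss, only the parameters are swapped (via `remake`) before solving.
function make_loss(prob)
    return function (p)
        sol = solve(remake(prob; p = p), Rosenbrock23())
        return sum(abs2, Array(sol)), sol
    end
end

loss = make_loss(base_prob)
```

The resulting `loss` can then be handed to the optimizer in place of one that calls `ODEProblem(...)` on every evaluation.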

ChrisRackauckas commented 3 years ago

Yeah I wouldn't differentiate the symbolic construction. Just use remake.

TorkelE commented 3 years ago

Maybe I misunderstood, but this also errors:

@parameters A B t
@variables X(t) Y(t)
rxs = [Reaction(A, nothing, [X], nothing, [1])
       Reaction(1., [X,Y], [X], [2,1],[3])
       Reaction(B, [X], [Y], [1],[1])
       Reaction(1., [X], nothing, [1],nothing)]
brueeslator_MTK  = ReactionSystem(rxs, t, [X,Y], [A,B])

brusselator_catalyst = @reaction_network begin
    A, ∅ → X
    1, 2X + Y → 3X
    B, X → Y
    1, X → ∅
end A B

function brusselator_function(du, u, p, t)
  X, Y = u
  A, B = p
  du[1] = dx = A + 0.5Y*X^2 -B*X -X
  du[2] = dy = B*X - 0.5Y*X^2
end

u0 = [1.0, 1.0]
tspan = (0.0, 10.0)

prob_MTK = ODEProblem(brueeslator_MTK,u0,tspan,[1.,1])
prob_func = ODEProblem(brusselator_function,u0,tspan,[1.,1])
function loss(p)
    sol = solve(remake(prob_MTK,p=p), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)

This produces probably the most gargantuan error message I've seen. It starts off something like:

MethodError: ReverseDiff.TrackedReal{ForwardDiff.Dual{Forwa...

For reference, this works:

function loss(p)
    sol = solve(remake(prob_func,p=p), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
isaacsas commented 3 years ago

I tried to play with this on the latest release, but it seems ReactionSystems are broken right now, see https://github.com/SciML/ModelingToolkit.jl/issues/779.

I was going to see whether you could get it to work by building an ODEFunction directly from brueeslator_MTK and then creating the ODEProblem using just the rhs function within the ODEFunction. That would test whether the problem is the generated rhs function or something in the wrapping ODEProblem/ODEFunction.

TorkelE commented 3 years ago

I did the test and the error persists, which does give a hint:

mtk_odesys = convert(ODESystem,brueeslator_MTK)
mtk_odefun = ODEFunction(mtk_odesys)
prob_MTK_odefun = ODEProblem(mtk_odefun,u0,tspan,[1.,1])
function loss(p)
    sol = solve(remake(prob_MTK_odefun,p=p), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
MethodError: ReverseDiff.TrackedReal{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,ReverseDiff.GradientTape{DiffEqSensitivity.var"#75#84"{ODEFunction{true,ModelingToolkit.var"#f#225"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTKArg#554"), Symbol("##MTKArg#555"), Symbol("##MTKArg#556")),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",
...
isaacsas commented 3 years ago

Sorry I wasn't clear! I meant to use

mtk_odesys = convert(ODESystem,brueeslator_MTK)
mtk_odefun = ODEFunction(mtk_odesys)
oprob = ODEProblem(mtk_odefun.f,u0,tspan,[1.,1])

which should only be using the generated ODE rhs from MTK.

TorkelE commented 3 years ago

ahh, so it would be:

mtk_odesys = convert(ODESystem,brueeslator_MTK)
mtk_odefun = ODEFunction(mtk_odesys)
oprob = ODEProblem(mtk_odefun.f,u0,tspan,[1.,1])
function loss(p)
    sol = solve(remake(oprob,p=p), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)

(The same error still appears, though.)

isaacsas commented 3 years ago

OK, that seems to suggest it is a build_function issue. @shashi any thoughts?

ChrisRackauckas commented 3 years ago

What about with Rosenbrock23(autodiff=false)? It would be good to simplify this as much as possible.

TorkelE commented 3 years ago

Now we are getting there! Yes, this works:

mtk_odesys = convert(ODESystem,brueeslator_MTK)
mtk_odefun = ODEFunction(mtk_odesys)
oprob = ODEProblem(mtk_odefun.f,u0,tspan,[1.,1])
function loss(p)
    sol = solve(remake(oprob,p=p), Rosenbrock23(autodiff=false))
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
isaacsas commented 3 years ago

@TorkelE If you directly build the ODEs in MTK do you have this issue? (i.e. don't use ReactionSystem at all, but manually enter the symbolic ODEs.)

If it persists there, that would point to an issue with build_function and AD types, or at least with ODESystems. If it goes away, that would indicate an issue with how we are generating ODEs from ReactionSystems.

TorkelE commented 3 years ago

It might already be in the ODESystem; this generates an error:

using ModelingToolkit, OrdinaryDiffEq, DiffEqFlux, Flux 

@parameters t A B
@variables X(t) Y(t)
D = Differential(t)

eqs = [D(X) ~ A + 0.5*Y*X*X - B*X - X,
       D(Y) ~ B*X - 0.5*Y*X*X]

sys = ODESystem(eqs)
sys = ode_order_lowering(sys)

u0 = [X => 1.0, Y => 1.0]
p  = [A => 1.0, B => 1.0]
tspan = (0.0, 10.0)

prob_odesys = ODEProblem(ODESystem(eqs),u0,tspan,p)

function loss(p)
    sol = solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=true))
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
ChrisRackauckas commented 3 years ago

Is it a Catalyst issue? Or just an ODESystem issue? Could this boil down to a RuntimeGeneratedFunction issue with reverse-mode AD?

ChrisRackauckas commented 3 years ago

solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=true))

You mean false?

isaacsas commented 3 years ago

This seems like an issue with AD types in the generated rhs function. So maybe it’s a RuntimeGeneratedFunction error. I don’t think there are issues when using finite differences in the solvers.

isaacsas commented 3 years ago

@TorkelE seems to have shown this is not an issue related to Catalyst or ReactionSystems.

ChrisRackauckas commented 3 years ago

But ForwardDiff is fine? Was that isolated? It sounds like a ReverseDiff issue to me, given how it holds functions. Can you take ReverseDiff out of the picture? That is, test:

solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=true))
solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=false),sensealg=InterpolatingAdjoint(autojacvec=false))

and see which errors?

isaacsas commented 3 years ago

I'll give it a shot later this afternoon on the latest MTK master.

isaacsas commented 3 years ago

Hmm, I can't reproduce the reported errors. I tried the following two examples on the latest MTK master and both went through fine:

using ModelingToolkit, OrdinaryDiffEq, DiffEqFlux, Flux 

@parameters t A B
@variables X(t) Y(t)
D = Differential(t)

eqs = [D(X) ~ A + 0.5*Y*X*X - B*X - X,
       D(Y) ~ B*X - 0.5*Y*X*X]

sys = ODESystem(eqs)
sys = ode_order_lowering(sys)

u0 = [X => 1.0, Y => 1.0]
p  = [A => 1.0, B => 1.0]
tspan = (0.0, 10.0)

prob_odesys = ODEProblem(ODESystem(eqs),u0,tspan,p)

function loss(p)
    sol = solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=true))
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)

@parameters A B t
@variables X(t) Y(t)
rxs = [Reaction(A, nothing, [X], nothing, [1])
       Reaction(1., [X,Y], [X], [2,1],[3])
       Reaction(B, [X], [Y], [1],[1])
       Reaction(1., [X], nothing, [1],nothing)]
brueeslator_MTK  = ReactionSystem(rxs, t, [X,Y], [A,B])
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)

prob_MTK = ODEProblem(brueeslator_MTK,u0,tspan,[1.,1])
function loss(p)
    sol = solve(remake(prob_MTK,p=p), Rosenbrock23())
    loss = sum(abs2, sol)       # (Was initially optimizing against some data, but this makes the code shorter...)
    return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)

@TorkelE what MTK version are you on? Have you tried the latest release or master?

TorkelE commented 3 years ago

OK, this is getting a bit weird. I ran your code and it still errors. This is the output of Pkg.status():

Status `~/Desktop/ParamEstimExample/Project.toml`
  [479239e8] Catalyst v6.4.0
  [2445eb08] DataDrivenDiffEq v0.5.4
  [aae7a2af] DiffEqFlux v1.31.0
  [1130ab10] DiffEqParamEstim v1.19.1
  [41bf760c] DiffEqSensitivity v6.40.0
  [0c46a032] DifferentialEquations v6.16.0
  [587475ba] Flux v0.11.6
  [23fbe1c1] Latexify v0.14.7
  [961ee093] ModelingToolkit v5.6.1 `~/.julia/dev/ModelingToolkit`
  [429524aa] Optim v1.2.3
  [91a5bcdd] Plots v1.10.2
  [731186ca] RecursiveArrayTools v2.11.0

I think that's the latest MTK version as well.

isaacsas commented 3 years ago

I was running in the ModelingToolkit Project.toml environment. I get:

Project ModelingToolkit v5.6.1
Status `~/.julia/dev/ModelingToolkit/Project.toml`
  [4fba245c] ArrayInterface v3.1.1
  [864edb3b] DataStructures v0.18.9
  [2b5f629d] DiffEqBase v6.57.5
  [c894b116] DiffEqJump v6.13.0
  [b552c78f] DiffRules v1.0.2
  [31c24e10] Distributions v0.24.12
  [ffbed154] DocStringExtensions v0.8.3
  [615f187c] IfElse v0.1.0
  [2ee39098] LabelledArrays v1.5.0
  [23fbe1c1] Latexify v0.14.7
  [093fc24a] LightGraphs v1.3.5
  [1914dd2f] MacroTools v0.5.6
  [77ba4419] NaNMath v0.3.5
  [731186ca] RecursiveArrayTools v2.11.0
  [189a3867] Reexport v1.0.0
  [ae029012] Requires v1.1.2
  [7e49a35a] RuntimeGeneratedFunctions v0.5.1
  [1bc83da4] SafeTestsets v0.0.1
  [0bca4576] SciMLBase v1.7.3
  [efcf1570] Setfield v0.7.0
  [276daf66] SpecialFunctions v1.2.1
  [90137ffa] StaticArrays v1.0.1
  [d1185830] SymbolicUtils v0.8.2
  [a2a6695c] TreeViews v0.3.0
  [3a884ed6] UnPack v1.0.2
  [1986cc42] Unitful v1.5.0
  [8ba89e20] Distributed
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [2f01184e] SparseArrays
TorkelE commented 3 years ago

You would have to add DiffEqFlux and Flux to that, though? I will see if I can manage to run from that environment as well.

isaacsas commented 3 years ago

I made a new environment and added them. This is what I get; I guess this doesn't show the versions of indirect dependencies, though (which may be installed in my global 1.5 environment and reused?):

  [aae7a2af] DiffEqFlux v1.32.0
  [961ee093] ModelingToolkit v5.6.1 `~/.julia/dev/ModelingToolkit`
  [1dea7af3] OrdinaryDiffEq v5.50.2
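The project-level `Pkg.status()` output lists only direct dependencies. A minimal sketch for also listing the indirect ones, assuming Julia's standard `Pkg` API: switching the mode to the manifest prints every package recorded in the active environment's Manifest.toml, which would make it possible to compare the versions of packages such as SymbolicUtils, RuntimeGeneratedFunctions, ReverseDiff, or Zygote between the two setups.

```julia
using Pkg

# Direct dependencies of the active project (what is shown above).
Pkg.status()

# Every package in the active Manifest.toml, including indirect dependencies.
Pkg.status(; mode = Pkg.PKGMODE_MANIFEST)
```

In the Pkg REPL mode, the equivalent of the second call is `status --manifest`.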
isaacsas commented 3 years ago

Here is my global environment FWIW:

  [7d9fca2a] Arpack v0.5.1
  [4fba245c] ArrayInterface v2.14.17
  [4c555306] ArrayLayouts v0.4.12
  [aae01518] BandedMatrices v0.15.25
  [6e4b80f9] BenchmarkTools v0.5.0
  [ffab5731] BlockBandedMatrices v0.9.5
  [336ed68f] CSV v0.8.3
  [5d742f6a] CSVFiles v1.0.0
  [159f3aea] Cairo v1.0.5
  [479239e8] Catalyst v6.6.0
  [134e5e36] Catlab v0.10.2
  [a93c6f00] DataFrames v0.22.5
  [864edb3b] DataStructures v0.18.9
  [31a5f54b] Debugger v0.6.7
  [2b5f629d] DiffEqBase v6.57.5
  [459566f4] DiffEqCallbacks v2.16.0
  [aae7a2af] DiffEqFlux v1.32.0
  [c894b116] DiffEqJump v6.13.0
  [77a26b50] DiffEqNoiseProcess v5.5.2
  [0c46a032] DifferentialEquations v6.16.0
  [31c24e10] Distributions v0.24.12
  [e30172f5] Documenter v0.26.1
  [35a29f4d] DocumenterTools v0.1.9
  [497a8b3b] DoubleFloats v1.1.15
  [7a1cc6ca] FFTW v1.3.0
  [5789e2e9] FileIO v1.4.5
  [1a297f60] FillArrays v0.10.2
  [53c48c17] FixedPointNumbers v0.8.4
  [587475ba] Flux v0.11.1
  [f6369f11] ForwardDiff v0.10.16
  [28b8d3ca] GR v0.53.0
  [3c863552] Graphviz_jll v2.42.3+1
  [34004b35] HypergeometricFunctions v0.3.5
  [09f84164] HypothesisTests v0.10.2
  [7073ff75] IJulia v1.23.1
  [82e4d734] ImageIO v0.4.1
  [6218d12a] ImageMagick v1.1.6
  [916415d5] Images v0.23.3
  [d1acc4aa] IntervalArithmetic v0.17.7
  [d2bf35a9] IntervalRootFinding v0.5.5
  [e5e0dc1b] Juno v0.8.4
  [b964fa9f] LaTeXStrings v1.2.0
  [23fbe1c1] Latexify v0.14.7
  [d7e5e226] LazyBandedMatrices v0.3.6
  [093fc24a] LightGraphs v1.3.5
  [2fda8390] LsqFit v0.12.0
  [23992714] MAT v0.9.2
  [b51810bb] MatrixDepot v1.0.3
  [961ee093] ModelingToolkit v5.6.0
  [2774e3e8] NLsolve v4.5.1
  [47be7bcc] ORCA v0.5.0
  [1dea7af3] OrdinaryDiffEq v5.50.2
  [8314cec4] PGFPlotsX v1.2.10
  [ccf2f8ad] PlotThemes v2.0.1
  [58dd65bb] Plotly v0.3.0
  [a03496cd] PlotlyBase v0.4.3
  [f0f68f2c] PlotlyJS v0.14.0
  [91a5bcdd] Plots v1.10.4
  [c3e4b0f8] Pluto v0.12.20
  [7f904dfe] PlutoUI v0.6.11
  [08abe8d2] PrettyTables v0.11.0
  [c46f51b8] ProfileView v0.6.9
  [438e738f] PyCall v1.92.2
  [d330b81b] PyPlot v2.9.0
  [1fd47b50] QuadGK v2.4.1
  [be4d8f0f] Quadmath v0.5.5
  [dca85d43] QuartzImageIO v0.7.3
  [e6cf234a] RandomNumbers v1.4.0
  [731186ca] RecursiveArrayTools v2.11.0
  [295af30f] Revise v3.1.11
  [f2b01f46] Roots v1.0.8
  [1bc83da4] SafeTestsets v0.0.1
  [276daf66] SpecialFunctions v1.2.1
  [90137ffa] StaticArrays v1.0.1
  [2913bbd2] StatsBase v0.33.2
  [4c63d2b9] StatsFuns v0.9.6
  [9672c7b4] SteadyStateDiffEq v1.6.1
  [789caeaf] StochasticDiffEq v6.32.1
  [c3572dad] Sundials v4.4.1
  [286e6d88] SymRCM v0.2.1
  [d1185830] SymbolicUtils v0.7.8
  [bd369af6] Tables v1.3.2
  [a759f4b9] TimerOutputs v0.5.7
  [d94bfb22] TrackingHeaps v0.1.0 `https://github.com/henriquebecker91/TrackingHeaps.jl#master`
  [3a884ed6] UnPack v1.0.2
  [1986cc42] Unitful v1.5.0
  [44d3d7a6] Weave v0.10.6
  [0f1e0344] WebIO v0.8.15
  [1270edf5] x264_jll v2020.7.14+2
  [37e2e46d] LinearAlgebra
ChrisRackauckas commented 3 years ago

Works for me too. @TorkelE just needs to update.