TorkelE closed this issue 3 years ago.
What happens if you create the ODEProblem using Catalyst outside the loss function (maybe pass it as a parameter with a closure)?
Yeah, I wouldn't differentiate through the symbolic construction. Just use remake.
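A package-free sketch of that pattern, using hypothetical stand-ins: Problem plays the role of ODEProblem, with_params the role of remake, and a fixed-step Euler loop the role of solve. The problem is built once outside the loss, and the loss closes over it and only swaps the parameter vector, so nothing symbolic is ever differentiated.

```julia
# Hypothetical stand-ins for ODEProblem / remake / solve, for illustration only.
struct Problem{F,P}
    f::F                           # in-place RHS f(du, u, p, t)
    u0::Vector{Float64}
    tspan::Tuple{Float64,Float64}
    p::P
end

# Analogue of remake(prob, p = p): reuse everything, replace only the parameters.
with_params(prob::Problem, p) = Problem(prob.f, prob.u0, prob.tspan, p)

# Fixed-step explicit Euler; returns the state at every step.
function euler(prob::Problem; dt = 0.01)
    u = copy(prob.u0)
    du = similar(u)
    traj = [copy(u)]
    t = prob.tspan[1]
    while t < prob.tspan[2] - dt / 2
        prob.f(du, u, prob.p, t)
        u .+= dt .* du
        t += dt
        push!(traj, copy(u))
    end
    return traj
end

# Same Brusselator equations as brusselator_function in the thread.
function brusselator!(du, u, p, t)
    X, Y = u
    A, B = p
    du[1] = A + 0.5 * Y * X^2 - B * X - X
    du[2] = B * X - 0.5 * Y * X^2
    return nothing
end

# Build the problem once, outside the loss...
const base_prob = Problem(brusselator!, [1.0, 1.0], (0.0, 10.0), [1.0, 1.0])

# ...and let the loss close over it, swapping only the parameters.
# Like sum(abs2, sol): sum of squared states over the whole trajectory.
make_loss(prob) = p -> sum(sum(abs2, u) for u in euler(with_params(prob, p)))
loss = make_loss(base_prob)
```

Calling loss([1.5, 2.0]) then evaluates the loss for those parameters without rebuilding the problem.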
Maybe I misunderstood, but this also errors:
@parameters A B t
@variables X(t) Y(t)
rxs = [Reaction(A, nothing, [X], nothing, [1])
Reaction(1., [X,Y], [X], [2,1],[3])
Reaction(B, [X], [Y], [1],[1])
Reaction(1., [X], nothing, [1],nothing)]
brueeslator_MTK = ReactionSystem(rxs, t, [X,Y], [A,B])
brusselator_catalyst = @reaction_network begin
A, ∅ → X
1, 2X + Y → 3X
B, X → Y
1, X → ∅
end A B
function brusselator_function(du, u, p, t)
X, Y = u
A, B = p
du[1] = dx = A + 0.5Y*X^2 -B*X -X
du[2] = dy = B*X - 0.5Y*X^2
end
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob_MTK = ODEProblem(brueeslator_MTK,u0,tspan,[1.,1])
prob_func = ODEProblem(brusselator_function,u0,tspan,[1.,1])
function loss(p)
sol = solve(remake(prob_MTK,p=p), Rosenbrock23())
loss = sum(abs2, sol) # (Was initially optimizing against some data, but this makes the code shorter...)
return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
with probably the most gargantuan error message I've seen. It starts off with something like
MethodError: ReverseDiff.TrackedReal{ForwardDiff.Dual{Forwa...
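For context on that error type: reading the (truncated) signature, ForwardDiff Dual numbers are nested inside ReverseDiff TrackedReals, so the generated RHS apparently gets called with AD number types rather than plain floats. The sketch below is a deliberately tiny hand-rolled dual number (not how ForwardDiff.jl is implemented) showing the basic contract of forward-mode AD: a generic RHS lets duals flow through and yields the Jacobian, while the same RHS restricted to Float64 throws a MethodError. This is only an analogy; it does not pin down what actually goes wrong in the generated function.

```julia
# A deliberately tiny forward-mode dual number (illustrative; not ForwardDiff.jl).
import Base: +, -, *, ^, convert, promote_rule

struct Dual <: Number
    val::Float64   # primal value
    der::Float64   # derivative along the seeded direction
end

+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
^(a::Dual, n::Int) = Dual(a.val^n, n * a.val^(n - 1) * a.der)
convert(::Type{Dual}, x::Real) = Dual(Float64(x), 0.0)   # constants have zero derivative
promote_rule(::Type{Dual}, ::Type{<:Real}) = Dual

# Generic RHS (u = [X, Y], p = [A, B]): no type annotations, so duals flow through.
brusselator_generic(u, p) = [p[1] + 0.5 * u[2] * u[1]^2 - p[2] * u[1] - u[1],
                             p[2] * u[1] - 0.5 * u[2] * u[1]^2]

# Same equations, but (over-)restricted to Float64 input.
brusselator_float64(u::Vector{Float64}, p) = brusselator_generic(u, p)

# Jacobian by seeding one input direction at a time.
function dual_jacobian(f, u, p)
    n = length(u)
    J = zeros(n, n)
    for j in 1:n
        ud = [Dual(u[i], i == j ? 1.0 : 0.0) for i in 1:n]
        J[:, j] = [d.der for d in f(ud, p)]
    end
    return J
end
```

At u = [1.0, 1.0], p = [1.0, 1.0] the generic version gives the analytic Jacobian [-1.0 0.5; 0.0 -0.5], while dual_jacobian(brusselator_float64, ...) fails with a MethodError because a Vector{Dual} is not a Vector{Float64}.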
For reference, this works:
function loss(p)
sol = solve(remake(prob_func,p=p), Rosenbrock23())
loss = sum(abs2, sol) # (Was initially optimizing against some data, but this makes the code shorter...)
return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
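Why the hand-written function has the 0.5 factors: as far as I know, Catalyst's default mass-action rate law divides by the factorial of each substrate's stoichiometry, so 2X + Y → 3X with rate 1 fires at X^2*Y/2!. A stdlib-only sketch, assuming that convention, that assembles the ODE from the four reactions (substrate and net stoichiometry written out by hand, not taken from Catalyst) and checks it against brusselator_function:

```julia
# Brusselator reactions as (rate constant, substrate stoich, net stoich),
# species ordered (X, Y); parameters p = [A, B].
reactions(p) = [
    (p[1], [0, 0], [ 1,  0]),   # A, ∅ → X
    (1.0,  [2, 1], [ 1, -1]),   # 1, 2X + Y → 3X
    (p[2], [1, 0], [-1,  1]),   # B, X → Y
    (1.0,  [1, 0], [-1,  0]),   # 1, X → ∅
]

# Mass-action rate with combinatoric scaling: k * Π x_i^s_i / s_i!
propensity(k, s, u) = k * prod(u[i]^s[i] / factorial(s[i]) for i in eachindex(u))

# du/dt = Σ_r (net stoichiometry of r) * (rate of r)
function mass_action_rhs(u, p)
    du = zeros(length(u))
    for (k, s, net) in reactions(p)
        du .+= net .* propensity(k, s, u)
    end
    return du
end

# The hand-written RHS from the thread (in-place, returns nothing).
function brusselator_function(du, u, p, t)
    X, Y = u
    A, B = p
    du[1] = A + 0.5 * Y * X^2 - B * X - X
    du[2] = B * X - 0.5 * Y * X^2
    return nothing
end
```

Comparing mass_action_rhs(u, p) with the output of brusselator_function at a few states is a quick way to confirm the two formulations being compared in this thread describe the same model.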
I tried to play with this on the latest release, but it seems ReactionSystems are broken right now; see https://github.com/SciML/ModelingToolkit.jl/issues/779.
I was going to see if you could get it to work if you directly built an ODEFunction from brueeslator_MTK and then created the ODEProblem using just the rhs function within the ODEFunction. That would test whether the problem is the generated rhs function or something in the wrapping ODEProblem/ODEFunction.
I did the test and the error persists, which does give a hint:
mtk_odesys = convert(ODESystem,brueeslator_MTK)
mtk_odefun = ODEFunction(mtk_odesys)
prob_MTK_odefun = ODEProblem(mtk_odefun,u0,tspan,[1.,1])
function loss(p)
sol = solve(remake(prob_MTK_odefun,p=p), Rosenbrock23())
loss = sum(abs2, sol) # (Was initially optimizing against some data, but this makes the code shorter...)
return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
MethodError: ReverseDiff.TrackedReal{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,ReverseDiff.GradientTape{DiffEqSensitivity.var"#75#84"{ODEFunction{true,ModelingToolkit.var"#f#225"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##MTKArg#554"), Symbol("##MTKArg#555"), Symbol("##MTKArg#556")),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",
...
Sorry I wasn't clear! I meant to use
mtk_odesys = convert(ODESystem,brueeslator_MTK)
mtk_odefun = ODEFunction(mtk_odesys)
oprob = ODEProblem(mtk_odefun.f,u0,tspan,[1.,1])
which should only be using the generated ODE rhs from MTK.
Ahh, so it would be:
mtk_odesys = convert(ODESystem,brueeslator_MTK)
mtk_odefun = ODEFunction(mtk_odesys)
oprob = ODEProblem(mtk_odefun.f,u0,tspan,[1.,1])
function loss(p)
sol = solve(remake(oprob,p=p), Rosenbrock23())
loss = sum(abs2, sol) # (Was initially optimizing against some data, but this makes the code shorter...)
return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
(The same error still appears, though.)
OK, that seems to suggest it is a build_function issue. @shashi, any thoughts?
What about with Rosenbrock23(autodiff=false)? It would be good to simplify this as much as possible.
Now we are getting there! Yes, this works:
mtk_odesys = convert(ODESystem,brueeslator_MTK)
mtk_odefun = ODEFunction(mtk_odesys)
oprob = ODEProblem(mtk_odefun.f,u0,tspan,[1.,1])
function loss(p)
sol = solve(remake(oprob,p=p), Rosenbrock23(autodiff=false))
loss = sum(abs2, sol) # (Was initially optimizing against some data, but this makes the code shorter...)
return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
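As I understand it, autodiff=false makes Rosenbrock23 approximate the Jacobian of the RHS by finite differences instead of pushing ForwardDiff Dual numbers through it, so the generated function only ever sees Float64 inputs. A stdlib-only sketch of a forward-difference Jacobian for the same Brusselator equations, checked against the analytic one; the step-size rule is a common heuristic and not necessarily what OrdinaryDiffEq uses internally:

```julia
# Same Brusselator equations as brusselator_function, in-place form.
function brusselator_rhs!(du, u, p, t)
    X, Y = u
    A, B = p
    du[1] = A + 0.5 * Y * X^2 - B * X - X
    du[2] = B * X - 0.5 * Y * X^2
    return nothing
end

# Forward-difference Jacobian: column j ≈ (f(u + h*e_j) - f(u)) / h.
# f! is only ever called with Vector{Float64} states.
function fd_jacobian(f!, u::Vector{Float64}, p, t)
    n = length(u)
    f0 = zeros(n)
    f!(f0, u, p, t)
    fh = zeros(n)
    J = zeros(n, n)
    for j in 1:n
        h = sqrt(eps(Float64)) * max(1.0, abs(u[j]))  # common step-size heuristic
        up = copy(u)
        up[j] += h
        f!(fh, up, p, t)
        J[:, j] = (fh .- f0) ./ h
    end
    return J
end

# Analytic Jacobian of the same RHS, for comparison.
function brusselator_jac(u, p)
    X, Y = u
    B = p[2]
    return [(X * Y - B - 1.0) (0.5 * X^2);
            (B - X * Y)       (-0.5 * X^2)]
end
```

The forward-difference columns agree with the analytic Jacobian to roughly 1e-8 here, which is typically accurate enough for a Rosenbrock method's linear solves.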
@TorkelE If you directly build the ODEs in MTK, do you have this issue? (I.e. don't use ReactionSystem at all, but manually enter the symbolic ODEs.)
If it persists there, it would point to some issue with build_function and AD types, or at least with ODESystems; if it goes away, that would seem to indicate an issue with how we are generating ODEs from ReactionSystems.
It might already be in the ODESystem; this generates an error:
using ModelingToolkit, OrdinaryDiffEq, DiffEqFlux, Flux
@parameters t A B
@variables X(t) Y(t)
D = Differential(t)
eqs = [D(X) ~ A + 0.5*Y*X*X - B*X - X,
D(Y) ~ B*X - 0.5*Y*X*X]
sys = ODESystem(eqs)
sys = ode_order_lowering(sys)
u0 = [X => 1.0, Y => 1.0]
p = [A => 1.0, B => 1.0]
tspan = (0.0, 10.0)
prob_odesys = ODEProblem(ODESystem(eqs),u0,tspan,p)
function loss(p)
sol = solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=true))
loss = sum(abs2, sol) # (Was initially optimizing against some data, but this makes the code shorter...)
return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
Is it a Catalyst issue, or just an ODESystem issue? This may boil down to a RuntimeGeneratedFunction issue with reverse-mode AD?
solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=true))
You mean false?
This seems like an issue with AD types in the generated rhs function. So maybe it’s a RuntimeGeneratedFunction error. I don’t think there are issues when using finite differences in the solvers.
@TorkelE seems to have shown this is not an issue related to Catalyst or ReactionSystems.
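Roughly what is under suspicion here: build_function turns the symbolic equations into a Julia expression that is compiled at runtime (MTK wraps the result in a RuntimeGeneratedFunction). A crude stdlib-only analogue, which assumes nothing about MTK's actual code generation, builds the Brusselator RHS as an Expr, evals it, and calls it through Base.invokelatest (the world-age restriction that RuntimeGeneratedFunctions exists to sidestep):

```julia
# The RHS as an expression, roughly the kind of code a generator emits
# (u = [X, Y], p = [A, B]).
rhs_expr = :(function (du, u, p, t)
    du[1] = p[1] + 0.5 * u[2] * u[1]^2 - p[2] * u[1] - u[1]
    du[2] = p[2] * u[1] - 0.5 * u[2] * u[1]^2
    return nothing
end)

# Compile it at runtime into an anonymous function.
generated_rhs! = eval(rhs_expr)

# invokelatest keeps this helper safe to call from code that started running
# before the eval above (world-age rule).
function call_generated(u, p)
    du = similar(u)
    Base.invokelatest(generated_rhs!, du, u, p, 0.0)
    return du
end
```

Because the generated body carries no type annotations, it is the element types of du, u and p at call time that decide which number types flow through it.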
But ForwardDiff is fine? Was that isolated? It sounds like a ReverseDiff issue to me, given how it holds functions. Can you remove the involvement of ReverseDiff? That is, test:
solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=true))
solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=false),sensealg=InterpolatingAdjoint(autojacvec=false))
and see which errors?
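To run both variants in one go and see which one fails, a tiny stdlib-only harness may help; which_fail is a hypothetical helper, not part of any SciML package. Each entry pairs a label with a zero-argument closure, and the function returns the labels whose closures threw.

```julia
# Hypothetical helper: run each labelled zero-argument closure, catch any
# exception, print a one-line summary, and return the labels that failed.
function which_fail(cases)
    failures = String[]
    for (label, thunk) in cases
        try
            thunk()
            println(label, ": ok")
        catch err
            println(label, ": ", typeof(err))
            push!(failures, label)
        end
    end
    return failures
end
```

In practice each closure would wrap one run of the thread's sciml_train call, using a loss built on one of the two solve lines above.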
I'll give it a shot later this afternoon on the latest MTK master.
Hmm, I can't reproduce the reported errors. I tried the following two examples on the latest MTK master and they went through fine:
using ModelingToolkit, OrdinaryDiffEq, DiffEqFlux, Flux
@parameters t A B
@variables X(t) Y(t)
D = Differential(t)
eqs = [D(X) ~ A + 0.5*Y*X*X - B*X - X,
D(Y) ~ B*X - 0.5*Y*X*X]
sys = ODESystem(eqs)
sys = ode_order_lowering(sys)
u0 = [X => 1.0, Y => 1.0]
p = [A => 1.0, B => 1.0]
tspan = (0.0, 10.0)
prob_odesys = ODEProblem(ODESystem(eqs),u0,tspan,p)
function loss(p)
sol = solve(remake(prob_odesys,p=p), Rosenbrock23(autodiff=true))
loss = sum(abs2, sol) # (Was initially optimizing against some data, but this makes the code shorter...)
return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
@parameters A B t
@variables X(t) Y(t)
rxs = [Reaction(A, nothing, [X], nothing, [1])
Reaction(1., [X,Y], [X], [2,1],[3])
Reaction(B, [X], [Y], [1],[1])
Reaction(1., [X], nothing, [1],nothing)]
brueeslator_MTK = ReactionSystem(rxs, t, [X,Y], [A,B])
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob_MTK = ODEProblem(brueeslator_MTK,u0,tspan,[1.,1])
function loss(p)
sol = solve(remake(prob_MTK,p=p), Rosenbrock23())
loss = sum(abs2, sol) # (Was initially optimizing against some data, but this makes the code shorter...)
return loss, sol
end
DiffEqFlux.sciml_train(loss,[1.,1.],ADAM(0.1),maxiters = 100)
@TorkelE what MTK version are you on? Have you tried the latest release or master?
OK, this is getting a bit weird. I ran your code and it still errors. This is the output of Pkg.status():
Status `~/Desktop/ParamEstimExample/Project.toml`
[479239e8] Catalyst v6.4.0
[2445eb08] DataDrivenDiffEq v0.5.4
[aae7a2af] DiffEqFlux v1.31.0
[1130ab10] DiffEqParamEstim v1.19.1
[41bf760c] DiffEqSensitivity v6.40.0
[0c46a032] DifferentialEquations v6.16.0
[587475ba] Flux v0.11.6
[23fbe1c1] Latexify v0.14.7
[961ee093] ModelingToolkit v5.6.1 `~/.julia/dev/ModelingToolkit`
[429524aa] Optim v1.2.3
[91a5bcdd] Plots v1.10.2
[731186ca] RecursiveArrayTools v2.11.0
I think that's the latest MTK version as well.
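Since the outcome seems to depend on which versions are actually loaded, here is a small sketch using the Pkg standard library: Pkg.dependencies() returns a Dict from UUID to PackageInfo for everything in the active manifest, and its is_tracking_path field flags dev'd checkouts like the ModelingToolkit one above. pkg_version_info is a hypothetical helper:

```julia
import Pkg

# Hypothetical helper: (version, is_dev) for `name` in the active environment's
# manifest, or nothing if the package is not present at all.
function pkg_version_info(name::AbstractString)
    for (_, info) in Pkg.dependencies()
        if info.name == name
            return (version = info.version, is_dev = info.is_tracking_path)
        end
    end
    return nothing
end
```

For example, pkg_version_info("ModelingToolkit") run in each environment being compared shows whether both sides really use the same MTK.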
I was running in the Project.toml from ModelingToolkit. I get:
Project ModelingToolkit v5.6.1
Status `~/.julia/dev/ModelingToolkit/Project.toml`
[4fba245c] ArrayInterface v3.1.1
[864edb3b] DataStructures v0.18.9
[2b5f629d] DiffEqBase v6.57.5
[c894b116] DiffEqJump v6.13.0
[b552c78f] DiffRules v1.0.2
[31c24e10] Distributions v0.24.12
[ffbed154] DocStringExtensions v0.8.3
[615f187c] IfElse v0.1.0
[2ee39098] LabelledArrays v1.5.0
[23fbe1c1] Latexify v0.14.7
[093fc24a] LightGraphs v1.3.5
[1914dd2f] MacroTools v0.5.6
[77ba4419] NaNMath v0.3.5
[731186ca] RecursiveArrayTools v2.11.0
[189a3867] Reexport v1.0.0
[ae029012] Requires v1.1.2
[7e49a35a] RuntimeGeneratedFunctions v0.5.1
[1bc83da4] SafeTestsets v0.0.1
[0bca4576] SciMLBase v1.7.3
[efcf1570] Setfield v0.7.0
[276daf66] SpecialFunctions v1.2.1
[90137ffa] StaticArrays v1.0.1
[d1185830] SymbolicUtils v0.8.2
[a2a6695c] TreeViews v0.3.0
[3a884ed6] UnPack v1.0.2
[1986cc42] Unitful v1.5.0
[8ba89e20] Distributed
[8f399da3] Libdl
[37e2e46d] LinearAlgebra
[2f01184e] SparseArrays
You would have to add DiffEqFlux and Flux to that, though? I will see if I can manage to run from that one as well.
I made a new environment and added them. This is what I get, though I guess this doesn't show the versions of the indirect dependencies (which may be installed in my global 1.5 environment and reused?):
[aae7a2af] DiffEqFlux v1.32.0
[961ee093] ModelingToolkit v5.6.1 `~/.julia/dev/ModelingToolkit`
[1dea7af3] OrdinaryDiffEq v5.50.2
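On the indirect dependencies: plain Pkg.status() only lists the project's direct dependencies, while Pkg.status(mode = Pkg.PKGMODE_MANIFEST) (st -m in the Pkg REPL) lists the whole manifest, including packages such as SymbolicUtils or RuntimeGeneratedFunctions. A small sketch (indirect_deps is a hypothetical helper) that extracts just the indirect ones with their versions:

```julia
import Pkg

# Hypothetical helper: (name, version) of every package in the active manifest
# that is NOT a direct dependency of the project, sorted by name.
function indirect_deps()
    deps = [(info.name, info.version) for (_, info) in Pkg.dependencies() if !info.is_direct_dep]
    return sort(deps; by = first)
end

# Pkg's own view of the full manifest:
# Pkg.status(mode = Pkg.PKGMODE_MANIFEST)
```

Diffing indirect_deps() between the two environments would show whether, for example, different SymbolicUtils or RuntimeGeneratedFunctions versions are in play.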
Here is my global environment FWIW:
[7d9fca2a] Arpack v0.5.1
[4fba245c] ArrayInterface v2.14.17
[4c555306] ArrayLayouts v0.4.12
[aae01518] BandedMatrices v0.15.25
[6e4b80f9] BenchmarkTools v0.5.0
[ffab5731] BlockBandedMatrices v0.9.5
[336ed68f] CSV v0.8.3
[5d742f6a] CSVFiles v1.0.0
[159f3aea] Cairo v1.0.5
[479239e8] Catalyst v6.6.0
[134e5e36] Catlab v0.10.2
[a93c6f00] DataFrames v0.22.5
[864edb3b] DataStructures v0.18.9
[31a5f54b] Debugger v0.6.7
[2b5f629d] DiffEqBase v6.57.5
[459566f4] DiffEqCallbacks v2.16.0
[aae7a2af] DiffEqFlux v1.32.0
[c894b116] DiffEqJump v6.13.0
[77a26b50] DiffEqNoiseProcess v5.5.2
[0c46a032] DifferentialEquations v6.16.0
[31c24e10] Distributions v0.24.12
[e30172f5] Documenter v0.26.1
[35a29f4d] DocumenterTools v0.1.9
[497a8b3b] DoubleFloats v1.1.15
[7a1cc6ca] FFTW v1.3.0
[5789e2e9] FileIO v1.4.5
[1a297f60] FillArrays v0.10.2
[53c48c17] FixedPointNumbers v0.8.4
[587475ba] Flux v0.11.1
[f6369f11] ForwardDiff v0.10.16
[28b8d3ca] GR v0.53.0
[3c863552] Graphviz_jll v2.42.3+1
[34004b35] HypergeometricFunctions v0.3.5
[09f84164] HypothesisTests v0.10.2
[7073ff75] IJulia v1.23.1
[82e4d734] ImageIO v0.4.1
[6218d12a] ImageMagick v1.1.6
[916415d5] Images v0.23.3
[d1acc4aa] IntervalArithmetic v0.17.7
[d2bf35a9] IntervalRootFinding v0.5.5
[e5e0dc1b] Juno v0.8.4
[b964fa9f] LaTeXStrings v1.2.0
[23fbe1c1] Latexify v0.14.7
[d7e5e226] LazyBandedMatrices v0.3.6
[093fc24a] LightGraphs v1.3.5
[2fda8390] LsqFit v0.12.0
[23992714] MAT v0.9.2
[b51810bb] MatrixDepot v1.0.3
[961ee093] ModelingToolkit v5.6.0
[2774e3e8] NLsolve v4.5.1
[47be7bcc] ORCA v0.5.0
[1dea7af3] OrdinaryDiffEq v5.50.2
[8314cec4] PGFPlotsX v1.2.10
[ccf2f8ad] PlotThemes v2.0.1
[58dd65bb] Plotly v0.3.0
[a03496cd] PlotlyBase v0.4.3
[f0f68f2c] PlotlyJS v0.14.0
[91a5bcdd] Plots v1.10.4
[c3e4b0f8] Pluto v0.12.20
[7f904dfe] PlutoUI v0.6.11
[08abe8d2] PrettyTables v0.11.0
[c46f51b8] ProfileView v0.6.9
[438e738f] PyCall v1.92.2
[d330b81b] PyPlot v2.9.0
[1fd47b50] QuadGK v2.4.1
[be4d8f0f] Quadmath v0.5.5
[dca85d43] QuartzImageIO v0.7.3
[e6cf234a] RandomNumbers v1.4.0
[731186ca] RecursiveArrayTools v2.11.0
[295af30f] Revise v3.1.11
[f2b01f46] Roots v1.0.8
[1bc83da4] SafeTestsets v0.0.1
[276daf66] SpecialFunctions v1.2.1
[90137ffa] StaticArrays v1.0.1
[2913bbd2] StatsBase v0.33.2
[4c63d2b9] StatsFuns v0.9.6
[9672c7b4] SteadyStateDiffEq v1.6.1
[789caeaf] StochasticDiffEq v6.32.1
[c3572dad] Sundials v4.4.1
[286e6d88] SymRCM v0.2.1
[d1185830] SymbolicUtils v0.7.8
[bd369af6] Tables v1.3.2
[a759f4b9] TimerOutputs v0.5.7
[d94bfb22] TrackingHeaps v0.1.0 `https://github.com/henriquebecker91/TrackingHeaps.jl#master`
[3a884ed6] UnPack v1.0.2
[1986cc42] Unitful v1.5.0
[44d3d7a6] Weave v0.10.6
[0f1e0344] WebIO v0.8.15
[1270edf5] x264_jll v2020.7.14+2
[37e2e46d] LinearAlgebra
Works for me too. @TorkelE just needs to update.
I was trying to get parameter fitting to work for a Catalyst model. I ran through everything first with the Lotka-Volterra model and it all worked fine, but when I exchanged the Lotka-Volterra function for a Catalyst model, it stopped working. The error I get is a:
This is a minimal example, using a ReactionSystem created directly:
Then if I optimise for the normal function:
it works fine.
But if I try the MTK one, it errors:
which produces a
I get a similar error when I use a Catalyst model:
and regardless of how I input u0/parameters:
I get the same error if I try to convert to an ODESystem or an ODEFunction.