SciML / DifferentialEquations.jl

Multi-language suite for high-performance solvers of differential equations and scientific machine learning (SciML) components. Ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), differential-algebraic equations (DAEs), and more in Julia.
https://docs.sciml.ai/DiffEqDocs/stable/
Other
2.85k stars 226 forks source link

Thread-savety for PeriodicCallback #692

Open frankschae opened 3 years ago

frankschae commented 3 years ago

I'd like to apply a controller periodically in an ensemble simulation. The controller should act based on the current state of the integrator and change a parameter of the DE. However, PeriodicCallback doesn't seem to be thread save, despite using safetycopy = true.

using Flux, DiffEqFlux, DiffEqSensitivity
using DifferentialEquations
using Plots

function lotka_volterra!(du, u, p, t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end

# Initial condition
u0 = [1.0, 1.0]

# Simulation interval and intermediary points
tspan = (0.0, 10.0)
dt = 0.1

nn = Chain(Dense(2, 1, tanh))
p_nn, re = Flux.destructure(nn)

pars = [1.5, 1.0, 3.0, 1.0]

prob = ODEProblem{true}(lotka_volterra!, u0, tspan, pars)

sol = solve(prob, Tsit5(), adaptive=false, dt=0.05)

function loss(p)
  function affect!(integrator)
    integrator.p[2]=(re(p)(integrator.u))[1]
  end
  cb = PeriodicCallback(affect!, dt; initial_affect = true, save_positions=(false,false))

  sol = solve(prob, Tsit5(), saveat = dt, callback=cb, adaptive=false, dt=0.05)
  loss = sum(abs2, sol.-1)
  return loss, sol
end

l1, sol1 = loss(p_nn)

function ensembleloss(p)
  function affect!(integrator)
    integrator.p[2]=(re(p)(integrator.u))[1]
  end
  cb = PeriodicCallback(affect!, dt; initial_affect = true, save_positions=(false,false))

  ensembleprob = EnsembleProblem(prob,
    safetycopy = true
    )

  sol = solve(ensembleprob, Tsit5(), ensemblealg=EnsembleThreads(), saveat = dt,
     callback=cb, adaptive=false, dt=0.05, trajectories = 5)
  loss = sum(abs2, sol.-1)/5
  return loss, sol
end

Threads.nthreads() = 1
l2, sol2 = ensembleloss(p_nn)

with Threads.nthreads() = 1, the loss values l1 and l2 are the same (up to 1e-13). with Threads.nthreads() = 4, the loss values and the associated trajectories can be different.

Sometimes (?), I get an error:

ERROR: TaskFailedException:
Tried to add a tstop that is behind the current time. This is strictly forbidden
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] add_tstop! at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/integrators/integrator_interface.jl:96 [inlined]
 [3] (::DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}})(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Tuple{},Float64,Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}) at /Users/frank/.julia/packages/DiffEqCallbacks/b4ahb/src/iterative_and_periodic.jl:84
 [4] apply_discrete_callback! at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/callbacks.jl:830 [inlined]
 [5] handle_callbacks!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Tuple{},Float64,Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}) at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/integrators/integrator_utils.jl:259
 [6] _loopfooter!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Tuple{},Float64,Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}) at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/integrators/integrator_utils.jl:220
 [7] loopfooter! at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/integrators/integrator_utils.jl:166 [inlined]
 [8] solve!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Tuple{},Float64,Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}) at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/solve.jl:429
 [9] #__solve#391 at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/solve.jl:5 [inlined]
 [10] solve_call(::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}, ::Tsit5; merge_callbacks::Bool, kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/solve.jl:92
 [11] #solve_up#461 at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/solve.jl:114 [inlined]
 [12] #solve#460 at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/solve.jl:102 [inlined]
 [13] batch_func(::Int64, ::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5; kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:146
 [14] #363 at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:180 [inlined]
 [15] iterate at ./generator.jl:47 [inlined]
 [16] _collect(::UnitRange{Int64}, ::Base.Generator{UnitRange{Int64},DiffEqBase.var"#363#364"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Tsit5}}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:699
 [17] collect_similar at ./array.jl:628 [inlined]
 [18] map at ./abstractarray.jl:2162 [inlined]
 [19] solve_batch(::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5, ::EnsembleSerial, ::UnitRange{Int64}, ::Int64; kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:179
 [20] (::DiffEqBase.var"#367#369"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Tsit5,UnitRange{Int64},Int64,Int64})(::Int64) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:206
 [21] macro expansion at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:214 [inlined]
 [22] (::DiffEqBase.var"#509#threadsfor_fun#370"{DiffEqBase.var"#367#369"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Tsit5,UnitRange{Int64},Int64,Int64},Tuple{UnitRange{Int64}},Array{Any,1},UnitRange{Int64}})(::Bool) at ./threadingconstructs.jl:81
 [23] (::DiffEqBase.var"#509#threadsfor_fun#370"{DiffEqBase.var"#367#369"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Tsit5,UnitRange{Int64},Int64,Int64},Tuple{UnitRange{Int64}},Array{Any,1},UnitRange{Int64}})() at ./threadingconstructs.jl:48
Stacktrace:
 [1] wait at ./task.jl:267 [inlined]
 [2] threading_run(::Function) at ./threadingconstructs.jl:34
 [3] macro expansion at ./threadingconstructs.jl:93 [inlined]
 [4] tmap(::Function, ::UnitRange{Int64}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:213
 [5] solve_batch(::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5, ::EnsembleThreads, ::UnitRange{Int64}, ::Int64; kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:200
 [6] batch_function at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:108 [inlined]
 [7] macro expansion at ./timing.jl:233 [inlined]
 [8] __solve(::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5, ::EnsembleThreads; trajectories::Int64, batch_size::Int64, pmap_batch_size::Int64, kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:112
 [9] __solve(::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5; kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{6,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt, :trajectories),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64,Int64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:87
 [10] #solve#462 at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/solve.jl:128 [inlined]
 [11] ensembleloss(::Array{Float32,1}) at /Users/frank/switchdrive/Institution/stochastic_control/ODE_control/threadsafetytest.jl:59
 [12] top-level scope at none:1
frankschae commented 3 years ago

Updated code based on https://github.com/SciML/DifferentialEquations.jl/issues/646 (moving the callback to the problem type instead of the solve call). If I choose the number of trajectories numtraj >= Threads.nthreads(), some trajectories are different and if numtraj is increased to far, I obtain the error "Tried to add a tstop that is behind the current time.. ". ( ensembleloss(..) and ensembleloss2(..) show the same behaviour.)

# load packages
using Flux, DiffEqFlux, DiffEqSensitivity
using DifferentialEquations. ### version: DifferentialEquations v6.15.0
using Plots
using LinearAlgebra
using Test, Random

function lotka_volterra!(du, u, p, t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end

# Initial condition
u0 = [1.0, 1.0]

# Simulation interval and intermediary points
tspan = (0.0, 10.0)
dt = 0.1

Random.seed!(10)
nn = Chain(Dense(2, 1, relu))
p_nn, re = Flux.destructure(nn)

pars = [1.5, 1.0, 3.0, 1.0]

function affect!(integrator)
  integrator.p[2]=(re(p_nn)(integrator.u))[1]
end
cb = PeriodicCallback(affect!, dt; initial_affect = true, save_positions=(false,false))
prob = ODEProblem{true}(lotka_volterra!, u0, tspan, pars, callback=cb)

sol = solve(prob, Tsit5(), adaptive=true, dt=0.001, saveat=dt)

plot(sol)
@show sum(abs2, sol.-1)

function loss(p; sensealg=ForwardDiffSensitivity())
  function affect2!(integrator)
    Ω = (re(p)(integrator.u))[1]
    integrator.p[2] = Ω
  end
  cb2 = PeriodicCallback(affect2!,dt;initial_affect=true,save_positions=(false,false))
  tmp_prob = remake(prob, callback=cb2)

  sol = solve(tmp_prob, Tsit5(), sensealg=sensealg, saveat = dt, adaptive=true, dt=0.001)
  loss = sum(abs2, sol.-1)
  return loss, sol
end

l1, sol1 = loss(p_nn)

plot(sol1)

function ensembleloss(p; numtraj=5, sensealg=ForwardDiffSensitivity())
  function affect3!(integrator)
    integrator.p[2]=(re(p)(integrator.u))[1]
  end
  cb3 = PeriodicCallback(affect3!, dt; initial_affect = true, save_positions=(false,false))

  function prob_func(prob,i,repeat)
    remake(prob,callback = cb3)
  end

  ensembleprob = EnsembleProblem(prob,
    prob_func = prob_func, safetycopy = true
    )

  sol = solve(ensembleprob, Tsit5(), ensemblealg=EnsembleThreads(),
    sensealg = sensealg,
    saveat = dt,
    adaptive=true, dt=0.001, trajectories = numtraj)
  loss = sum(abs2, sol.-1)/numtraj
  return loss, sol
end

l2, sol2 = ensembleloss(p_nn, numtraj=2)

plot!(sol2)
@test isapprox(l1, l2, atol=1e-10)

function ensembleloss2(p; numtraj=5, sensealg=ForwardDiffSensitivity())
  function affect3!(integrator)
    integrator.p[2]=(re(p)(integrator.u))[1]
  end
  cb3 = PeriodicCallback(affect3!, dt; initial_affect = true, save_positions=(false,false))

  tmp_prob = remake(prob, callback = cb3)

  ensembleprob = EnsembleProblem(tmp_prob)

  sol = solve(ensembleprob, Tsit5(), ensemblealg=EnsembleThreads(),
    sensealg = sensealg,
    saveat = dt,
    adaptive=true, dt=0.001, trajectories = numtraj)
  loss = sum(abs2, sol.-1)/numtraj
  return loss, sol
end

l3, sol3 = ensembleloss2(p_nn, numtraj=10)

plot(sol3)
@test isapprox(l1, l3, atol=1e-10)

@show Threads.nthreads() ### 4