Open frankschae opened 3 years ago
Updated code based on https://github.com/SciML/DifferentialEquations.jl/issues/646 (moving the callback to the problem type instead of the solve
call). If I choose the number of trajectories numtraj
>= Threads.nthreads()
, some trajectories are different and if numtraj
is increased to far, I obtain the error "Tried to add a tstop that is behind the current time.. ". ( ensembleloss(..)
and ensembleloss2(..)
show the same behaviour.)
# load packages
using Flux, DiffEqFlux, DiffEqSensitivity
using DifferentialEquations. ### version: DifferentialEquations v6.15.0
using Plots
using LinearAlgebra
using Test, Random
function lotka_volterra!(du, u, p, t)
x, y = u
α, β, δ, γ = p
du[1] = dx = α*x - β*x*y
du[2] = dy = -δ*y + γ*x*y
end
# Initial condition
u0 = [1.0, 1.0]
# Simulation interval and intermediary points
tspan = (0.0, 10.0)
dt = 0.1
Random.seed!(10)
nn = Chain(Dense(2, 1, relu))
p_nn, re = Flux.destructure(nn)
pars = [1.5, 1.0, 3.0, 1.0]
function affect!(integrator)
integrator.p[2]=(re(p_nn)(integrator.u))[1]
end
cb = PeriodicCallback(affect!, dt; initial_affect = true, save_positions=(false,false))
prob = ODEProblem{true}(lotka_volterra!, u0, tspan, pars, callback=cb)
sol = solve(prob, Tsit5(), adaptive=true, dt=0.001, saveat=dt)
plot(sol)
@show sum(abs2, sol.-1)
function loss(p; sensealg=ForwardDiffSensitivity())
function affect2!(integrator)
Ω = (re(p)(integrator.u))[1]
integrator.p[2] = Ω
end
cb2 = PeriodicCallback(affect2!,dt;initial_affect=true,save_positions=(false,false))
tmp_prob = remake(prob, callback=cb2)
sol = solve(tmp_prob, Tsit5(), sensealg=sensealg, saveat = dt, adaptive=true, dt=0.001)
loss = sum(abs2, sol.-1)
return loss, sol
end
l1, sol1 = loss(p_nn)
plot(sol1)
function ensembleloss(p; numtraj=5, sensealg=ForwardDiffSensitivity())
function affect3!(integrator)
integrator.p[2]=(re(p)(integrator.u))[1]
end
cb3 = PeriodicCallback(affect3!, dt; initial_affect = true, save_positions=(false,false))
function prob_func(prob,i,repeat)
remake(prob,callback = cb3)
end
ensembleprob = EnsembleProblem(prob,
prob_func = prob_func, safetycopy = true
)
sol = solve(ensembleprob, Tsit5(), ensemblealg=EnsembleThreads(),
sensealg = sensealg,
saveat = dt,
adaptive=true, dt=0.001, trajectories = numtraj)
loss = sum(abs2, sol.-1)/numtraj
return loss, sol
end
l2, sol2 = ensembleloss(p_nn, numtraj=2)
plot!(sol2)
@test isapprox(l1, l2, atol=1e-10)
function ensembleloss2(p; numtraj=5, sensealg=ForwardDiffSensitivity())
function affect3!(integrator)
integrator.p[2]=(re(p)(integrator.u))[1]
end
cb3 = PeriodicCallback(affect3!, dt; initial_affect = true, save_positions=(false,false))
tmp_prob = remake(prob, callback = cb3)
ensembleprob = EnsembleProblem(tmp_prob)
sol = solve(ensembleprob, Tsit5(), ensemblealg=EnsembleThreads(),
sensealg = sensealg,
saveat = dt,
adaptive=true, dt=0.001, trajectories = numtraj)
loss = sum(abs2, sol.-1)/numtraj
return loss, sol
end
l3, sol3 = ensembleloss2(p_nn, numtraj=10)
plot(sol3)
@test isapprox(l1, l3, atol=1e-10)
@show Threads.nthreads() ### 4
I'd like to apply a controller periodically in an ensemble simulation. The controller should act based on the current state of the integrator and change a parameter of the DE. However,
PeriodicCallback
doesn't seem to be thread save, despite usingsafetycopy = true
.with
Threads.nthreads() = 1
, the loss valuesl1
andl2
are the same (up to 1e-13). withThreads.nthreads() = 4
, the loss values and the associated trajectories can be different.Sometimes (?), I get an error: