This is fundamentally the wrong approach. What you're looking for is DiffEqGPU.jl, which takes small ODE problems and solves them simultaneously on a GPU for different parameters and initial conditions. You'll want to combine this with forward-mode AD (i.e. use `sensealg=ForwardDiffSensitivity()`), which is more efficient for the number of parameters per ODE you have here.
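To make the "many small ODEs at once" idea concrete without a GPU, here is a CPU-only sketch in plain Julia. It makes no DiffEqGPU calls, and a fixed-step RK4 stands in for an adaptive solver such as `Tsit5`. Each column of the state matrix holds one trajectory, so every integrator stage updates all trajectories with the same array operations. This loosely mirrors the batching idea behind ensemble GPU solvers, but it is not their implementation. `lorenz_batched`, `rk4_step` and `solve_batched` are hypothetical helper names.

```julia
# Lorenz right-hand side evaluated column-wise on a 3 × n state matrix:
# column j is the state of trajectory j.
function lorenz_batched(U::AbstractMatrix, p)
    σ, ρ, β = p
    x, y, z = U[1, :], U[2, :], U[3, :]
    return vcat(permutedims(σ .* (y .- x)),
                permutedims(x .* (ρ .- z) .- y),
                permutedims(x .* y .- β .* z))
end

# One classic fixed-step RK4 step applied to the whole batch
# (a stand-in for an adaptive solver such as Tsit5).
function rk4_step(f, U, p, dt)
    k1 = f(U, p)
    k2 = f(U .+ (dt / 2) .* k1, p)
    k3 = f(U .+ (dt / 2) .* k2, p)
    k4 = f(U .+ dt .* k3, p)
    return U .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Advance every trajectory `nsteps` steps of size `dt`.
function solve_batched(f, U0, p, dt, nsteps)
    U = copy(U0)
    for _ in 1:nsteps
        U = rk4_step(f, U, p, dt)
    end
    return U
end

# Usage: three trajectories that differ only in their initial x-coordinate.
p = (10.0, 28.0, 8 / 3)
U0 = [1.0 2.0 3.0;
      0.0 0.0 0.0;
      0.0 0.0 0.0]
U_end = solve_batched(lorenz_batched, U0, p, 0.01, 100)
```

Solving each column on its own gives the same result as the batched solve. The batching only changes how the work is laid out, not the math.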
Ah, I see. I tried DiffEqGPU.jl to no avail; I opened a separate issue there regarding GPU usage: SciML/DiffEqGPU.jl#57
I get the following error when I try to use `EnsembleThreads()`, with or without `ForwardDiffSensitivity`:

```
Need an adjoint for constructor EnsembleSolution{Float32,3,Array{ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Tuple{Float32,Float32,Float32},ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats},1}}. Gradient is of type Array{Float64,3}
```

when using the following code:
```julia
using Pkg
using DiffEqGPU, DifferentialEquations, Flux, DiffEqSensitivity

function lorenz(du,u,p,t)
    @inbounds begin
        du[1] = p[1]*(u[2]-u[1])
        du[2] = u[1]*(p[2]-u[3]) - u[2]
        du[3] = u[1]*u[2] - p[3]*u[3]
    end
end

nBatchsize = 10
nEpochs = 10
tspan = 100.0
tsize = 1000
nbatches = div(tsize,nBatchsize)
t = range(0.0,tspan,length=tsize)

NN_encode = Chain(Dense(3,10,tanh),Dense(10,10,tanh),Dense(10,3,tanh))
NN_decode = Chain(Dense(3,10),Dense(10,10,tanh),Dense(10,3,tanh))

u0 = Float32[1.0, 0.0, 0.0]
p = [10.0f0,28.0f0,8/3f0]
prob = ODEProblem(lorenz,u0,tspan,p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data
data = Float32.(yy)

args_ = Dict()
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lorenz,u0,(0.0f0,Float32(tspan/nbatches)),p)

function ensemble_solve(x)
    prob_func = (prob,i,repeat) -> remake(prob,u0=x[:,i])
    monteprob = EnsembleProblem(prob2, prob_func = prob_func)
    return Array(solve(monteprob,Tsit5(),EnsembleThreads(),trajectories=nbatches,saveat=t_batch,sensealg=ForwardDiffSensitivity()))
end

function loss_func(data_)
    enc_ = NN_encode(data_)
    enc_ODE_solve = hcat([ensemble_solve(enc_[:,1:nBatchsize:end])[:,:,i] for i in 1:nbatches]...)
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

opt = ADAM(0.001)
for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end
```
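Two properties of the indexing in `loss_func` can be checked in plain Julia. First, `enc_[:,1:nBatchsize:end]` picks one initial condition per window (columns 1, 11, 21, ...), which gives exactly `nbatches` trajectories. Second, Julia arrays are column-major, so `hcat`-ing the `[:,:,i]` slices of a state × time × trajectory array gives the same matrix as a single `reshape`.

As written, the comprehension also calls `ensemble_solve` once per `i`, which means `nbatches` full ensemble solves per loss evaluation. Solving once and reshaping the result would avoid that. The sketch below uses random stand-in data rather than a real ODE solve, and it assumes `Array(...)` of the ensemble solution is laid out as state × time × trajectory, which is consistent with the `[:,:,i]` indexing above.

```julia
nBatchsize = 10
tsize = 1000
nbatches = div(tsize, nBatchsize)   # 100 windows of 10 points each

# Stand-in for the encoded series enc_: 3 states × tsize time points.
enc = reshape(collect(1.0:3tsize), 3, tsize)

# One initial condition per window: columns 1, 11, 21, ..., 991.
x0 = enc[:, 1:nBatchsize:end]

# Stand-in for the ensemble output: state × time × trajectory.
sol = rand(3, nBatchsize, nbatches)

# Concatenating the per-trajectory slices along time ...
stacked_hcat = hcat([sol[:, :, i] for i in 1:nbatches]...)
# ... gives the same matrix as one column-major reshape of a single solve.
stacked_reshape = reshape(sol, size(sol, 1), :)
```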
Hmm, Zygote is having issues here.
```julia
using ZygoteRules
ZygoteRules.@adjoint function EnsembleSolution(sim,time,converged)
```

That gets you pretty far. The next issue you hit is that Zygote doesn't seem to work with `@threads`, so I just changed it to `EnsembleSerial()` to keep moving. Then Zygote wasn't able to compile code containing `@warn` (because that has a try-catch in it), so I commented those out. That made it start calling the adjoint equation, which then errored for a reason I don't understand yet, but it's very close.
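`EnsembleSerial()` and `EnsembleThreads()` differ mainly in how the loop over trajectories is scheduled. The Base-Julia sketch below uses hypothetical helpers, not DiffEq's actual implementation. It contrasts a plain `map` with a threaded loop that writes into a preallocated array. In-place writes like this are one of the patterns Zygote has generally been unable to differentiate. That is consistent with falling back to `EnsembleSerial()` above, though we don't claim it is the exact cause of the `@threads` failure.

```julia
using Base.Threads: @threads

# Serial "ensemble": a pure map over trajectory indices, no mutation.
serial_ensemble(f, n) = map(f, 1:n)

# Threaded "ensemble": preallocate the output, then fill its slots
# in place from several threads.
function threaded_ensemble(f, n)
    out = Vector{typeof(f(1))}(undef, n)
    @threads for i in 1:n
        out[i] = f(i)
    end
    return out
end

# Hypothetical per-trajectory "solve": a closed-form decay, for illustration.
traj(i) = exp(-0.1 * i)

serial = serial_ensemble(traj, 8)
threaded = threaded_ensemble(traj, 8)
```

Both produce the same values. Only the mutation-free `map` form is the kind of code reverse-mode AD handles naturally.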
So the way to solve this would be to fix: Zygote with `@threads`, Zygote with the try-catch inside `@warn`, the missing `EnsembleSolution` adjoint, and the error in the adjoint equation. The first two should get MWEs on Zygote. That adjoint should get added to DiffEq, and @dhairyagandhi96, I might want help debugging the last part. This is generally something that would be good to have, and the knock-on effects of fixing this case are likely very valuable (pmap/tmap adjoints are probably more broadly useful, so we should look at those first).
I see. I'm afraid I'm currently unable to dig into the Zygote issue. I do, however, have a temporary workaround on the GPU that also trains the initial conditions. The RHS could be defined better, but gradients are being computed successfully:
```julia
using Pkg
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations

# Training hyperparameters
nBatchsize = 30
nEpochs = 10
tspan = 20.0
tsize = 300
nbatches = div(tsize,nBatchsize)
statesize = 2

# ODE solve
function lotka_volterra_func!(du,u,p,t)
    du[1] = u[1]*(p[1]-p[2]*u[2])
    du[2] = u[2]*(p[3]*u[1]-p[4])
    return du
end

# Batched RHS: the state vector stacks `nbatches` copies of the 2-state system
function lotka_volterra(du,u,p,t)
    dx = Zygote.Buffer(du,size(du)[1])
    for i in 1:nbatches
        index = (i-1)*statesize
        dx[index+1:index+statesize] = lotka_volterra_func!(dx[index+1:index+statesize],u[index+1:index+statesize],p,t)
    end
    du .= copy(dx)
end

# Define parameters and initial conditions for data
p = Float32[2.2, 1.0, 2.0, 0.4]
u0 = Float32[0.01, 0.01]
t = range(0.0,tspan,length=tsize)

# Define ODE problem and generate data
prob = ODEProblem(lotka_volterra_func!,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data
data = Float32.(yy) |> gpu

# Define autoencoder networks
NN_encode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
NN_decode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
u0_train = rand(statesize*nbatches)

# Define new ODE problem for "batch" evolution
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lotka_volterra,u0_train,(0.0f0,Float32(tspan/nbatches)),p)

# ODE solve to be used for training
function predict_ODE_solve()
    return Array(solve(prob2,Tsit5(),saveat=t_batch,reltol=1e-4))
end

args_ = Dict() # stores the latest loss for logging

function loss_func(data_)
    enc_ = NN_encode(data_)
    # Solve ODE using initial values from multiple points in enc_.
    # Note: reduce(hcat,[..]) gives a mutating arrays error
    #enc_ODE_solve = hcat([predict_ODE_solve(enc_[:,(i-1)*nBatchsize+1]) for i in 1:nbatches]...) #|> gpu
    enc_ODE_solve = hcat([predict_ODE_solve()[(i-1)*statesize+1:i*statesize,:] for i in 1:nbatches]...) |> gpu
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

opt = ADAM(0.001)
loss_func(data) # This works
for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode,u0_train), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end
```
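The `Zygote.Buffer` in `lotka_volterra` is there because the RHS writes into slices. The stacked state is just `nbatches` two-state blocks laid out one after another, so the same RHS can in principle be written without mutation by viewing the vector as a 2 × `nbatches` matrix. Below is a minimal Base-Julia sketch that checks the two formulations agree. `lv`, `lv_loop` and `lv_batched` are hypothetical names, and the sketch assumes `statesize == 2`, as in the code above.

```julia
# Lotka-Volterra right-hand side for one 2-state system (out of place).
lv(u, p) = [u[1] * (p[1] - p[2] * u[2]),
            u[2] * (p[3] * u[1] - p[4])]

# Loop version mirroring `lotka_volterra` above: block i of the stacked
# state lives at indices (i-1)*2+1 : i*2.
function lv_loop(u, p, nbatches)
    du = similar(u)
    for i in 1:nbatches
        idx = (i - 1) * 2 + 1 : i * 2
        du[idx] = lv(u[idx], p)
    end
    return du
end

# Non-mutating version: column i of the 2 × nbatches reshape is block i.
function lv_batched(u, p, nbatches)
    U = reshape(u, 2, nbatches)
    x, y = U[1, :], U[2, :]
    dU = vcat(permutedims(x .* (p[1] .- p[2] .* y)),
              permutedims(y .* (p[3] .* x .- p[4])))
    return vec(dU)   # column-major, so the block layout is preserved
end

p = Float32[2.2, 1.0, 2.0, 0.4]
u = rand(Float32, 2 * 10)
```

`ODEProblem` accepts out-of-place right-hand sides, so something like `(u, p, t) -> lv_batched(u, p, nbatches)` could in principle replace the in-place `lotka_volterra`. That is untested here with the GPU setup above.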
This can be written as `reduce(hcat, [predict_ODE_solve()[(i-1)*statesize+1:i*statesize,:] for i in 1:nbatches])`.
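Both spellings build the same matrix. `reduce(hcat, xs)` avoids splatting a long argument list, since Base has a dedicated method for it. Whether a given Zygote version can differentiate it is a separate question: the comment in the workaround reports a mutating-arrays error. So treat the sketch below as a check of the array layout only. It uses random stand-in data for `predict_ODE_solve()`.

The sketch also shows a `reshape`/`permutedims` formulation of the same result, which follows from Julia's column-major layout:

- split the rows into (state, batch);
- move the batch dimension last;
- flatten time and batch together.

```julia
statesize, nbatches, nBatchsize = 2, 10, 30

# Stand-in for predict_ODE_solve(): (statesize*nbatches) × nBatchsize.
M = rand(statesize * nbatches, nBatchsize)

# Row blocks, one per batch, as in the loss function above.
blocks = [M[(i-1)*statesize+1:i*statesize, :] for i in 1:nbatches]

splatted = hcat(blocks...)      # splat form from the original code
reduced  = reduce(hcat, blocks) # reduce(hcat, ...) form

# Same layout without slicing: (state, batch, time) -> (state, time, batch) -> flatten.
permuted = reshape(permutedims(reshape(M, statesize, nbatches, nBatchsize), (1, 3, 2)), statesize, :)
```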
Update: Zygote is now compatible with parallelism, so this should be possible now.
This is the same issue a user reported on Discourse:
```julia
using DifferentialEquations, Flux

pa = [1.0]

function model1(input)
    prob = ODEProblem((u, p, t) -> 1.01u * pa[1], 0.5, (0.0, 1.0))

    function prob_func(prob, i, repeat)
        remake(prob, u0 = rand() * prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 100)
end

Input_time_series = zeros(5, 100)
# loss function
loss(x, y) = Flux.mse(model1(x), y)
data = Iterators.repeated((Input_time_series, 0), 1)

cb = function () # callback function to observe training
    println("Tracked Parameters: ", params(pa))
end

opt = ADAM(0.1)
println("Starting to train")
Flux.@epochs 10 Flux.train!(loss, params(pa), data, opt; cb = cb)
```
This is a good MWE.
@DhairyaLGandhi can I get help completing this? It just needs some tweaks to the adjoint definitions now:
```julia
using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test
using ZygoteRules

ZygoteRules.@adjoint EnsembleSolution(sim,time,converged) = EnsembleSolution(sim,time,converged), p̄ -> (EnsembleSolution(p̄, 0.0, true), 0.0, true)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) = sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true), 0.0, true)

pa = [1.0]
u0 = [1.0]

function model1()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)

    function prob_func(prob, i, repeat)
        remake(prob, u0 = rand() .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end

# loss function
loss() = sum(abs2,1.0.-model1())
data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
    @show loss()
end

opt = ADAM(0.05)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test l1 < 10l2
```
Then `EnsembleDistributed()` works because of [...]. [...] is somewhat related, since we really need adjoints of the solution types and of the `literal_getproperty` calls.
This almost works if you comment out the `@warn` calls in there.
The two adjoints defined in
I can take a look in a bit; I'm a bit swamped at the moment.
Edit: at a quick glance, you probably want `(EnsembleSolution(...), nothing)` for the second adjoint?
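The reason for the suggested `nothing` is that a pullback returns one cotangent per argument of the call it differentiates. `literal_getproperty(sim, Val(:u))` has two arguments, and the `Val` tag has no gradient.

The Base-Julia sketch below hand-writes that convention for a hypothetical `ToySolution` type, not DiffEq's `EnsembleSolution`. It uses a NamedTuple as the structural gradient of the struct, which is the form Zygote commonly uses. It doesn't call Zygote; it only mimics the shape of what an `@adjoint` returns.

```julia
# Hypothetical stand-in for a solution type with a single field.
struct ToySolution
    u::Vector{Float64}
end

# Forward pass for reading `sol.u`, plus a hand-written pullback.
# The pullback returns one entry per argument: a NamedTuple gradient
# for `sol`, and `nothing` for the non-differentiable Val tag.
function getu_pullback(sol::ToySolution, ::Val{:u})
    y = sol.u
    back(ȳ) = ((u = ȳ,), nothing)
    return y, back
end

# Chain rule by hand for L(sol) = sum(abs2, sol.u):
sol = ToySolution([1.0, 2.0, 3.0])
y, back = getu_pullback(sol, Val(:u))
ȳ = 2 .* y                   # dL/dy
sol_grad, val_grad = back(ȳ) # one entry per argument
```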
Alright, thanks, that fixes this. Final test:
```julia
using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test

pa = [1.0]
u0 = [3.0]

function model1()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)

    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100)
end

# loss function
loss() = sum(abs2,1.0.-Array(model1()))
data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
    @show loss()
end

opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1

function model2()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)

    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end

    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end

loss() = sum(abs2,[sum(abs2,1.0.-u) for u in model2()])

pa = [1.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1
```
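The test ODE is linear: u' = 1.01·p·u has the closed form u(t) = u0·exp(1.01·p·t), which gives an independent reference for the ensemble output and loss. Below is a plain-Julia sketch under two assumptions. First, `u0` has one element. Second, we read `model2`'s loss as a sum over trajectories of the squared per-trajectory error, with that error summed over the `saveat = 0.1` grid on (0, 1). `exact`, `ic` and `exact_loss` are hypothetical names.

```julia
# Closed-form solution of the test ODE u' = 1.01 * p * u, u(0) = u0.
exact(u0, p, t) = u0 * exp(1.01 * p * t)

# Initial condition for trajectory i, as in prob_func: 0.5 .+ i/100 .* prob.u0
# (scalar version, assuming a one-element u0).
ic(u0, i) = 0.5 + i / 100 * u0

# Same reduction as model2's loss, as we read it, on the grid 0.0:0.1:1.0.
function exact_loss(u0, p; ntraj = 100, ts = 0.0:0.1:1.0)
    return sum(1:ntraj) do i
        traj = [exact(ic(u0, i), p, t) for t in ts]
        sum(abs2, 1.0 .- traj)^2
    end
end
```

For example, `exact_loss(3.0, 1.0)` is the loss at the test's starting point `pa = [1.0]`, `u0 = [3.0]`. It can be compared against `loss()` from the `model2` section before training.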
I have been trying to implement an autoencoder with an ODE solve in between, which uses several initial values from the input time series.
Say we have a time series of dimension 2x100, and I'd like to solve the ODE over small time intervals `[0,10]` using initial conditions `y[:,1:10:end]`. It works fine on the CPU using `hcat([Array(solve(...))]...)`; however, using the GPU gives me an error. Here is the code:
and here is the current status of packages:
Is there a way to push this to the GPU efficiently? Any help would be appreciated. Thanks for the fantastic work on this package! :)