Hey,
This is fundamentally the wrong approach. What you're looking for is https://github.com/SciML/DiffEqGPU.jl. DiffEqGPU.jl takes small ODE problems and solves them simultaneously for different parameters and initial conditions on a GPU. You'll want to mix this with forward-mode AD (i.e. use sensealg=ForwardDiffSensitivity()), which will be more efficient for this number of parameters per ODE.
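For reference, the intended pattern looks roughly like the following. This is a minimal sketch assuming DiffEqGPU's EnsembleGPUArray ensemble algorithm (constructor details vary across versions); the sensealg=ForwardDiffSensitivity() keyword is picked up when the solve is later differentiated:

using OrdinaryDiffEq, DiffEqGPU, DiffEqSensitivity

function lorenz!(du, u, p, t)
    du[1] = p[1] * (u[2] - u[1])
    du[2] = u[1] * (p[2] - u[3]) - u[2]
    du[3] = u[1] * u[2] - p[3] * u[3]
end

u0 = Float32[1.0, 0.0, 0.0]
p = Float32[10.0, 28.0, 8/3]
prob = ODEProblem(lorenz!, u0, (0.0f0, 1.0f0), p)

# Give each trajectory its own parameters (a hypothetical perturbation here)
prob_func = (prob, i, repeat) -> remake(prob, p = p .* (0.9f0 .+ 0.2f0 .* rand(Float32, 3)))
monteprob = EnsembleProblem(prob, prob_func = prob_func)

# EnsembleGPUArray batches all trajectories into one large GPU array problem
sol = solve(monteprob, Tsit5(), EnsembleGPUArray(), trajectories = 1024,
            saveat = 0.1f0, sensealg = ForwardDiffSensitivity())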
Ah, I see. I tried DiffEqGPU.jl to no avail; I opened a separate issue there regarding GPU usage: SciML/DiffEqGPU.jl#57.
I get the following error when I try to use EnsembleThreads(), with or without ForwardDiffSensitivity:
Need an adjoint for constructor EnsembleSolution{Float32,3,Array{ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Tuple{Float32,Float32,Float32},ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats},1}}. Gradient is of type Array{Float64,3}
while using the following code:
using Pkg
Pkg.activate(".")
using DiffEqGPU, DifferentialEquations, Flux, DiffEqSensitivity

function lorenz(du,u,p,t)
    @inbounds begin
        du[1] = p[1]*(u[2]-u[1])
        du[2] = u[1]*(p[2]-u[3]) - u[2]
        du[3] = u[1]*u[2] - p[3]*u[3]
    end
    nothing
end

# Training hyperparameters
nBatchsize = 10
nEpochs = 10
tspan = 100.0
tsize = 1000
nbatches = div(tsize,nBatchsize)
t = range(0.0,tspan,length=tsize)

# Autoencoder networks
NN_encode = Chain(Dense(3,10,tanh),Dense(10,10,tanh),Dense(10,3,tanh))
NN_decode = Chain(Dense(3,10),Dense(10,10,tanh),Dense(10,3,tanh))

# Generate data from the Lorenz system
u0 = Float32[1.0, 0.0, 0.0]
p = [10.0f0,28.0f0,8/3f0]
prob = ODEProblem(lorenz,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data
data = Float32.(yy)
args_ = Dict()

# Shorter ODE problem for "batch" evolution over one sub-interval
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lorenz,u0,(0.0f0,Float32(tspan/nbatches)),p)

# Solve one trajectory per batch, with initial conditions from the columns of x
function ensemble_solve(x)
    prob_func = (prob,i,repeat) -> remake(prob,u0=x[:,i])
    monteprob = EnsembleProblem(prob2, prob_func = prob_func)
    return Array(solve(monteprob,Tsit5(),EnsembleThreads(),trajectories=nbatches,saveat=t_batch,sensealg=ForwardDiffSensitivity()))
end

function loss_func(data_)
    enc_ = NN_encode(data_)
    enc_ODE_solve = hcat([ensemble_solve(enc_[:,1:nBatchsize:end])[:,:,i] for i in 1:nbatches]...)
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

ensemble_solve(data)
loss_func(data)

opt = ADAM(0.001)
for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end
Hmm, Zygote is having issues here.
using ZygoteRules
ZygoteRules.@adjoint function EnsembleSolution(sim,time,converged)
    EnsembleSolution(sim,time,converged),y->(y,nothing,nothing)
end
This gets you pretty far. The next issue you hit is that Zygote doesn't seem to work with @threads, so I just changed it to EnsembleSerial() to keep moving. Then Zygote wasn't able to compile the code with @warn (https://github.com/SciML/DiffEqBase.jl/blob/master/src/ensemble/basic_ensemble_solve.jl#L125-L152) because it has a try-catch in there, so I commented those out. That made it start calling the adjoint equation, which then errored for a reason I don't understand yet, but that's very close.
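For context, the root cause is that Zygote of this era could not compile try/catch blocks at all. A minimal illustration (the exact error message varies by version, and newer Zygote releases added support):

using Zygote

# Differentiating through a try/catch fails on Zygote of this era with an
# error like "try/catch is not supported".
f(x) = try
    2x
catch
    zero(x)
end

Zygote.gradient(f, 1.0)  # errors here on Zygote versions of this era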
So the way to solve this would be to fix:
@warn
pmap
tmap
etc.
The first two should get MWEs on Zygote. That adjoint should get added to DiffEq, and @dhairyagandhi96, I might want help debugging the last part. This is generally something that would be good to have, and the knock-on effects of fixing this case are likely very valuable (pmap/tmap adjoints are probably more broadly useful, so we should look at that first).
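For instance, a minimal pmap MWE of the kind meant here might look like the following (hypothetical example: before a pmap adjoint existed, this threw a missing-adjoint error; after FluxML/Zygote.jl#728 it returns ([2.0, 4.0, 6.0],)):

using Distributed, Zygote

x = [1.0, 2.0, 3.0]
# Gradient of a loss that maps over indices with pmap, closing over x
Zygote.gradient(x -> sum(pmap(i -> x[i]^2, eachindex(x))), x)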
I see. I'm afraid I don't currently have the capacity to dig through the Zygote issue. I do, however, have a temporary workaround on the GPU while training with initial conditions as well. The RHS could be better defined, but gradients are being computed successfully:
using Pkg
Pkg.activate(".")
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations

# Training hyperparameters
nBatchsize = 30
nEpochs = 10
tspan = 20.0
tsize = 300
nbatches = div(tsize,nBatchsize)
statesize = 2

# ODE solve
function lotka_volterra_func!(du,u,p,t)
    du[1] = u[1]*(p[1]-p[2]*u[2])
    du[2] = u[2]*(p[3]*u[1]-p[4])
    return du
end

# Batched RHS: evolves all nbatches copies of the system at once.
# Zygote.Buffer lets the loop "mutate" without triggering Zygote's
# mutating-arrays error.
function lotka_volterra(du,u,p,t)
    dx = Zygote.Buffer(du,size(du)[1])
    for i in 1:nbatches
        index = (i-1)*statesize
        dx[index+1:index+statesize] = lotka_volterra_func!(dx[index+1:index+statesize],u[index+1:index+statesize],p,t)
    end
    du .= copy(dx)
    nothing
end

# Define parameters and initial conditions for data
p = Float32[2.2, 1.0, 2.0, 0.4]
u0 = Float32[0.01, 0.01]
t = range(0.0,tspan,length=tsize)

# Define ODE problem and generate data
prob = ODEProblem(lotka_volterra_func!,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data
data = Float32.(yy) |> gpu

# Define autoencoder networks
NN_encode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
NN_decode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
u0_train = rand(statesize*nbatches)

# Define new ODE problem for "batch" evolution
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lotka_volterra,u0_train,(0.0f0,Float32(tspan/nbatches)),p)

# ODE solve to be used for training
function predict_ODE_solve()
    return Array(solve(prob2,Tsit5(),saveat=t_batch,reltol=1e-4))
end

function loss_func(data_)
    enc_ = NN_encode(data_)
    # Solve ODE using initial values from multiple points in enc_.
    # Note: reduce(hcat,[..]) gives a mutating arrays error
    #enc_ODE_solve = hcat([predict_ODE_solve(enc_[:,(i-1)*nBatchsize+1]) for i in 1:nbatches]...) #|> gpu
    enc_ODE_solve = hcat([predict_ODE_solve()[(i-1)*statesize+1:i*statesize,:] for i in 1:nbatches]...) |> gpu
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

args_ = Dict()
opt = ADAM(0.001)
loss_func(data) # This works

for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode,u0_train), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end
My package list reads:
[c7e460c6] ArgParse v1.1.0
[fbb218c0] BSON v0.2.6
[6e4b80f9] BenchmarkTools v0.5.0
[3895d2a7] CUDAapi v4.0.0
[c5f51814] CUDAdrv v6.3.0
[be33ccc6] CUDAnative v3.1.0
[3a865a2d] CuArrays v2.2.0
[31a5f54b] Debugger v0.6.4
[aae7a2af] DiffEqFlux v1.12.0
[41bf760c] DiffEqSensitivity v6.19.0
[0c46a032] DifferentialEquations v6.14.0
[31c24e10] Distributions v0.23.3
[5789e2e9] FileIO v1.3.0
[587475ba] Flux v0.10.4
[0c68f7d7] GPUArrays v3.4.1
[033835bb] JLD2 v0.1.13
[429524aa] Optim v0.21.0
[1dea7af3] OrdinaryDiffEq v5.39.1
[91a5bcdd] Plots v1.3.6
[8d666b04] PolyChaos v0.2.1
[ee283ea6] Rebugger v0.3.3
[295af30f] Revise v2.7.1
[9f7883ad] Tracker v0.2.6
[e88e6eb3] Zygote v0.4.20
[9a3f8284] Random
The hcat(...) call here can be reduce(hcat, [predict_ODE_solve()[(i-1)*statesize+1:i*statesize,:] for i in 1:nbatches]).
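For illustration, the two forms produce the same matrix, but reduce(hcat, ...) dispatches to a specialized method in Base instead of splatting the whole collection as arguments (note, though, the earlier comment in the code that this form hit a mutating-arrays error under Zygote at the time):

blocks = [rand(2, 5) for _ in 1:10]
A = hcat(blocks...)        # splats 10 matrices as arguments
B = reduce(hcat, blocks)   # specialized non-splatting method
A == B                     # true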
Update: Zygote is now compatible with parallelism (https://github.com/FluxML/Zygote.jl/pull/728), so this should be possible now.
Same issue as this user on Discourse: https://discourse.julialang.org/t/error-loaderror-need-an-adjoint-for-constructor-ensemblesolution/42611/7
using DifferentialEquations, Flux

pa = [1.0]
function model1(input)
    prob = ODEProblem((u, p, t) -> 1.01u * pa[1], 0.5, (0.0, 1.0))
    function prob_func(prob, i, repeat)
        remake(prob, u0 = rand() * prob.u0)
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 100)
end

Input_time_series = zeros(5, 100)

# loss function
loss(x, y) = Flux.mse(model1(x), y)

data = Iterators.repeated((Input_time_series, 0), 1)
cb = function () # callback function to observe training
    println("Tracked Parameters: ", params(pa))
end
opt = ADAM(0.1)
println("Starting to train")
Flux.@epochs 10 Flux.train!(loss, params(pa), data, opt; cb = cb)
This is a good MWE.
@DhairyaLGandhi can I get help completing this? It just needs some tweaks to the adjoint definitions now:
using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test
using ZygoteRules

ZygoteRules.@adjoint EnsembleSolution(sim,time,converged) = EnsembleSolution(sim,time,converged), p̄ -> (EnsembleSolution(p̄, 0.0, true), 0.0, true)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) = sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true), 0.0, true)

pa = [1.0]
u0 = [1.0]
function model1()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
    function prob_func(prob, i, repeat)
        remake(prob, u0 = rand() .* prob.u0)
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end

# loss function
loss() = sum(abs2,1.0.-model1())

data = Iterators.repeated((), 10)
cb = function () # callback function to observe training
    @show loss()
end
opt = ADAM(0.05)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1
Then EnsembleDistributed() should work because of https://github.com/FluxML/Zygote.jl/pull/728.
https://github.com/SciML/DiffEqFlux.jl/issues/321 is somewhat related, since we really need adjoints of the solution types and the literal_getproperty calls.
This almost works if you comment out the warns (https://github.com/SciML/DiffEqBase.jl/blob/master/src/ensemble/basic_ensemble_solve.jl#L132) in there.
The two adjoints defined in https://github.com/SciML/DiffEqFlux.jl/issues/279#issuecomment-658825798?
I can take a look in a bit; I'm a bit swamped at the moment.
Edit: at a quick glance, you probably want (EnsembleSolution(...), nothing) for the second adjoint?
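Spelled out, that suggested tweak might read as follows. This is a sketch, not necessarily the form that landed in the fix; the point is that literal_getproperty takes two arguments (the solution and the Val), so its pullback should return two entries:

using ZygoteRules, DiffEqBase

# One cotangent per argument: the EnsembleSolution reconstruction for sim,
# and nothing for the Val{:u} argument.
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u})
    sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true), nothing)
end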
alright thanks
https://github.com/SciML/DiffEqBase.jl/pull/557 fixes this. Final test:
using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test

pa = [1.0]
u0 = [3.0]
function model1()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100)
end

# loss function
loss() = sum(abs2,1.0.-Array(model1()))

data = Iterators.repeated((), 10)
cb = function () # callback function to observe training
    @show loss()
end
opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1

# Same test, but indexing the solution's u field directly
function model2()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end

loss() = sum(abs2,[sum(abs2,1.0.-u) for u in model2()])

pa = [1.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1
I have been trying to implement an autoencoder with an ODE solve in between, which uses several initial values from the input time series. Say we have a time series y of dimension 2x100, and I'd like to solve the ODE over small time intervals [0,10] using initial conditions y[:,1:10:end]. It works fine on the CPU using hcat([Array(solve(...))]...); however, using the GPU gives me the error shown above. The code and the current status of my packages are also given above. Is there a way to push this to the GPU efficiently? Any help would be appreciated. Thanks for the fantastic work on this package! :)