SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License
870 stars 157 forks source link

Multiple ODE solves from a single time series on GPU #279

Closed mkalia94 closed 4 years ago

mkalia94 commented 4 years ago

I have been trying to implement an autoencoder with an ODE solve in between, which uses several initial values from the input time series.

Say we have a time series `y` of dimension 2×100, and I'd like to solve the ODE over short time intervals [0,10] using the initial conditions `y[:,1:10:end]`. This works fine on the CPU using `hcat([Array(solve(...))]...)`; however, on the GPU I get the following error:

ERROR: LoadError: CuArray only supports bits types
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] CuArrays.CuArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}},1,P} where P(::UndefInitializer, ::Tuple{Int64}) at /home/manu/.julia/packages/CuArrays/l0gXB/src/array.jl:106
 [3] similar(::CuArrays.CuArray{Float32,1,Nothing}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}, ::Tuple{Int64}) at /home/manu/.julia/packages/CuArrays/l0gXB/src/array.jl:139
 [4] similar(::CuArrays.CuArray{Float32,1,Nothing}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}, ::Int64) at ./abstractarray.jl:628
 [5] similar(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::Int64) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/tracked.jl:325
 [6] Zygote.Buffer(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::Int64) at /home/manu/.julia/packages/Zygote/YeCEW/src/tools/buffer.jl:42
 [7] lotka_volterra(::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:15
[8] (::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing})(::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::Vararg{Any,N} where N) at /home/manu/.julia/packages/DiffEqBase/KnYSY/src/diffeqfunction.jl:248
 [9] (::DiffEqSensitivity.var"#67#74"{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}})(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:113
 [10] ReverseDiff.GradientTape(::Function, ::Tuple{CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/api/tape.jl:207
 [11] ReverseDiff.GradientTape(::Function, ::Tuple{CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Array{Float32,1}}) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/api/tape.jl:204
 [12] adjointdiffcache(::Function, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Nothing, ::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}; quad::Bool, noiseterm::Bool) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:111
 [13] adjointdiffcache at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:26 [inlined]
 [14] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(::Function, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Nothing, ::Array{Float32,1}, ::Nothing, ::NamedTuple{(:reltol, :abstol),Tuple{Float64,Float64}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/interpolating_adjoint.jl:37
 [15] ODEAdjointProblem(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::DiffEqSensitivity.var"#df#115"{CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing},Colon}, ::StepRangeLen{Float32,Float64,Float64}, ::Nothing; checkpoints::Array{Float32,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/interpolating_adjoint.jl:115
 [16] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Tsit5, ::DiffEqSensitivity.var"#df#115"{CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing},Colon}, ::StepRangeLen{Float32,Float64,Float64}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float32,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/sensitivity_interface.jl:17
 [17] adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},Bool}, kwargs::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:reltol,),Tuple{Float64}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/sensitivity_interface.jl:6
 [18] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#114"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},Bool},CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Tuple{},Colon})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/concrete_solve.jl:107
 [19] #512#back at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [20] #174 at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182 [inlined]
 [21] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{DiffEqBase.var"#512#back#457"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#114"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},Bool},CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [22] #solve#443 at /home/manu/.julia/packages/DiffEqBase/KnYSY/src/solve.jl:69 [inlined]
 [23] (::typeof(∂(#solve#443)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#174#175"{typeof(∂(#solve#443)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182
 [25] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{typeof(∂(#solve#443)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [26] (::typeof(∂(solve##kw)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [27] predict_ODE_solve at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:47 [inlined]
 [28] (::typeof(∂(predict_ODE_solve)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [29] #41 at ./none:0 [inlined]
 [30] (::typeof(∂(λ)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [31] #1187 at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:172 [inlined]
 [32] #3 at ./generator.jl:36 [inlined]
 [33] iterate at ./generator.jl:47 [inlined]
 [34] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof(∂(λ)),1},NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}}},Base.var"#3#4"{Zygote.var"#1187#1191"}}) at ./array.jl:665
 [35] map at ./abstractarray.jl:2154 [inlined]
 [36] (::Zygote.var"#1186#1190"{Array{typeof(∂(λ)),1}})(::NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:172
 [37] (::Zygote.var"#1194#1195"{Zygote.var"#1186#1190"{Array{typeof(∂(λ)),1}}})(::NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:187
 [38] loss_func at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:54 [inlined]
 [39] (::typeof(∂(loss_func)))(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [40] #16 at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:85 [inlined]
 [41] (::typeof(∂(λ)))(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#49#50"{Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
 [43] gradient(::Function, ::Params) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
 [44] macro expansion at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:84 [inlined]
 [45] macro expansion at /home/manu/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [46] train!(::typeof(loss_func), ::Params, ::Array{CuArrays.CuArray{Float32,2,Nothing},1}, ::ADAM; cb::Flux.Optimise.var"#18#26") at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [47] train!(::Function, ::Params, ::Array{CuArrays.CuArray{Float32,2,Nothing},1}, ::ADAM) at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:79
 [48] top-level scope at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:69
 [49] include(::String) at ./client.jl:439
in expression starting at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:66

here is the code:

using Pkg
Pkg.activate(".")
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations

# Training hyperparameters: `tsize` samples on [0, tspan], cut into `nbatches`
# windows of `nBatchsize` samples each.
nBatchsize, nEpochs = 10, 10
tspan, tsize = 5.0, 100
nbatches = tsize ÷ nBatchsize

# ODE solve
# Lotka–Volterra right-hand side in in-place form: writes du/dt into `du` and
# returns `nothing`. Parameters are read as p[1]..p[4].
#
# The derivatives are written directly into `du` rather than staged in a
# `Zygote.Buffer`. During the adjoint solve this RHS is traced by ReverseDiff
# (see the `ReverseDiff.GradientTape` frame in the stack trace). In that trace
# `Zygote.Buffer(u, n)` calls `similar` on a CuArray-backed TrackedArray, which
# tries to allocate a CuArray of TrackedReal and fails with
# "CuArray only supports bits types". `du` is already a mutable output, so no
# staging copy is needed.
# NOTE(review): element-wise indexing of GPU arrays may be slow — confirm if
# performance on the GPU matters here.
function lotka_volterra(du, u, p, t)
    du[1] = u[1] * (p[1] - p[2] * u[2])
    du[2] = u[2] * (p[3] * u[1] - p[4])
    return nothing
end

# Ground-truth parameters and initial condition used to generate the data.
p = Float32[2.2, 1.0, 2.0, 0.4]
u0 = Float32[0.01, 0.01]

t = range(0.0, tspan, length = tsize)

# Generate the reference trajectory once and keep a clean copy. The same
# problem was previously solved a second time with identical settings just to
# fill `y_original`.
prob = ODEProblem(lotka_volterra, u0, (0.0, tspan), p)
yy = Array(solve(prob, saveat = t))
y_original = copy(yy)
# Noisy, translated data: each column gets a small random linear combination
# of all columns added to it.
yy = yy .+ yy * (0.01 .* rand(size(yy, 2), size(yy, 2)))

data = Float32.(yy) |> gpu

# Autoencoder networks: 2-D state -> 2-D latent -> 2-D reconstruction.
NN_encode = Chain(Dense(2, 10), Dense(10, 10), Dense(10, 2)) |> gpu
NN_decode = Chain(Dense(2, 10), Dense(10, 10), Dense(10, 2)) |> gpu

# Short-horizon problem used to evolve each window's initial condition.
t_batch = range(0.0f0, Float32(tspan / nbatches), length = nBatchsize)
prob2 = ODEProblem(lotka_volterra, u0, (0.0f0, Float32(tspan / nbatches)), p)

# ODE solve to be used for training
# Integrate `prob2` with Tsit5 starting from initial condition `x`, sampling at
# `t_batch`, and return the trajectory as a dense Array (states × time points).
function predict_ODE_solve(x)
    sol = solve(prob2, Tsit5(); u0 = x, saveat = t_batch, reltol = 1e-4)
    return Array(sol)
end

# Autoencoder + latent-ODE loss. Encodes `data_` and evolves the latent state
# from the first sample of every window with `predict_ODE_solve`. It decodes
# both the ODE-evolved and the directly encoded latents and sums two
# reconstruction errors plus a small latent-consistency term. The loss is also
# stored in the global `args_` for reporting.
function loss_func(data_)
    enc_ = NN_encode(data_)
    # Note: reduce(hcat,[..]) gives a mutating arrays error
    window_sols = [predict_ODE_solve(enc_[:, (k - 1) * nBatchsize + 1]) for k in 1:nbatches]
    enc_ODE_solve = hcat(window_sols...) |> gpu
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    recon_ode = Flux.mse(data_, dec_1)
    recon_ae = Flux.mse(data_, dec_2)
    latent_fit = Flux.mse(enc_, enc_ODE_solve)
    loss = recon_ode + recon_ae + 0.001 * latent_fit
    args_["loss"] = loss
    return loss
end

# Scratch storage: `loss_func` writes the latest loss under "loss" and the loop
# below prints it. It must exist before the first `loss_func` call; it was
# never defined in this script, so that call threw an UndefVarError.
args_ = Dict{String,Any}()

opt = ADAM(0.001)

loss_func(data) # This works

# One `train!` pass per epoch over the single full-length batch.
for ep in 1:nEpochs
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode, NN_decode), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end

and here is the current status of packages

[c7e460c6] ArgParse v1.1.0
  [fbb218c0] BSON v0.2.6
  [6e4b80f9] BenchmarkTools v0.5.0
  [3895d2a7] CUDAapi v4.0.0
  [c5f51814] CUDAdrv v6.3.0
  [be33ccc6] CUDAnative v3.1.0
  [3a865a2d] CuArrays v2.2.0
  [31a5f54b] Debugger v0.6.4
  [aae7a2af] DiffEqFlux v1.12.0
  [41bf760c] DiffEqSensitivity v6.19.0
  [0c46a032] DifferentialEquations v6.14.0
  [31c24e10] Distributions v0.23.3
  [5789e2e9] FileIO v1.3.0
  [587475ba] Flux v0.10.4
  [0c68f7d7] GPUArrays v3.4.1
  [033835bb] JLD2 v0.1.13
  [429524aa] Optim v0.21.0
  [1dea7af3] OrdinaryDiffEq v5.39.1
  [91a5bcdd] Plots v1.3.6
  [8d666b04] PolyChaos v0.2.1
  [ee283ea6] Rebugger v0.3.3
  [295af30f] Revise v2.7.1
  [9f7883ad] Tracker v0.2.6
  [e88e6eb3] Zygote v0.4.20
  [9a3f8284] Random 

Is there a way to push this to the GPU efficiently? Any help would be appreciated. Thanks for the fantastic work on this package! :)

ChrisRackauckas commented 4 years ago

Hey, this is fundamentally the wrong approach. What you're looking for is https://github.com/SciML/DiffEqGPU.jl . DiffEqGPU.jl will take small ODE problems and solve them simultaneously for different parameters and initial conditions on a GPU. You'll want to mix this with forward-mode AD (i.e. use sensealg=ForwardDiffSensitivity()), which is more efficient for this small number of parameters per ODE.

mkalia94 commented 4 years ago

Ah, I see. I tried DiffEqGPU.jl to no avail; I opened a separate issue there regarding GPU usage (SciML/DiffEqGPU.jl#57). I get the following error when I try to use EnsembleThreads(), with or without ForwardDiffSensitivity:

 Need an adjoint for constructor EnsembleSolution{Float32,3,Array{ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Tuple{Float32,Float32,Float32},ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats},1}}. Gradient is of type Array{Float64,3}

while using the following code:

using Pkg
Pkg.activate(".")
using DiffEqGPU, DifferentialEquations, Flux, DiffEqSensitivity
# Lorenz system, in-place form: fills `du` with (dx/dt, dy/dt, dz/dt) for the
# state u = (x, y, z) and parameters p = (σ, ρ, β), then returns `nothing`.
function lorenz(du, u, p, t)
    @inbounds begin
        x, y, z = u[1], u[2], u[3]
        σ, ρ, β = p[1], p[2], p[3]
        du[1] = σ * (y - x)
        du[2] = x * (ρ - z) - y
        du[3] = x * y - β * z
    end
    return nothing
end

# Windowing hyperparameters: `tsize` samples on [0, tspan], cut into `nbatches`
# windows of `nBatchsize` samples.
nBatchsize = 10
nEpochs = 10
tspan = 100.0
tsize = 1000
nbatches = div(tsize, nBatchsize)
t = range(0.0, tspan, length = tsize)

# Autoencoder networks: 3-D state <-> 3-D latent.
NN_encode = Chain(Dense(3, 10, tanh), Dense(10, 10, tanh), Dense(10, 3, tanh))
NN_decode = Chain(Dense(3, 10), Dense(10, 10, tanh), Dense(10, 3, tanh))

# Ground-truth Lorenz parameters (σ, ρ, β) and initial condition.
u0 = Float32[1.0, 0.0, 0.0]
p = [10.0f0, 28.0f0, 8/3f0]
prob = ODEProblem(lorenz, u0, tspan, p)
# Solve once and copy for the clean reference; the second, identical solve
# was redundant work.
yy = Array(solve(prob, saveat = t))
y_original = copy(yy)
yy = yy .+ yy * (0.01 .* rand(size(yy, 2), size(yy, 2))) # Creates noisy, translated data
data = Float32.(yy)

# Scratch storage written by `loss_func` and read by the training loop.
args_ = Dict()

# Short-horizon problem used to evolve each window's initial condition.
t_batch = range(0.0f0, Float32(tspan / nbatches), length = nBatchsize)
prob2 = ODEProblem(lorenz, u0, (0.0f0, Float32(tspan / nbatches)), p)

# Solve `prob2` once per trajectory, with column i of `x` as trajectory i's
# initial condition. Returns the stacked Array indexed as
# (state, time, trajectory).
function ensemble_solve(x)
    reinit = (prob, i, repeat) -> remake(prob, u0 = x[:, i])
    ensemble = EnsembleProblem(prob2, prob_func = reinit)
    sol = solve(ensemble, Tsit5(), EnsembleThreads();
                trajectories = nbatches, saveat = t_batch,
                sensealg = ForwardDiffSensitivity())
    return Array(sol)
end

# Autoencoder + latent-ODE loss. Every `nBatchsize`-th encoded state seeds one
# ensemble trajectory; the trajectories are concatenated column-wise and decoded.
# The loss sums two reconstruction errors plus a small latent-consistency term
# and also stores it in the global `args_`.
function loss_func(data_)
    enc_ = NN_encode(data_)
    # Solve the ensemble once and slice it. Calling `ensemble_solve` inside the
    # comprehension would re-solve all `nbatches` trajectories for every `i`
    # and keep only the i-th (nbatches^2 ODE solves instead of nbatches).
    sols = ensemble_solve(enc_[:, 1:nBatchsize:end])   # (state, time, trajectory)
    enc_ODE_solve = hcat([sols[:, :, i] for i in 1:nbatches]...)
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_, dec_1) + Flux.mse(data_, dec_2) +
           0.001 * Flux.mse(enc_, enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

# Smoke-test the ensemble solve and the loss once before training.
ensemble_solve(data)
loss_func(data)

opt = ADAM(0.001)

# One `train!` pass per epoch over the single full-length batch; `loss_func`
# stores its latest value in `args_`, which is reported after each epoch.
for ep in 1:nEpochs
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode, NN_decode), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end
ChrisRackauckas commented 4 years ago

Hmm, Zygote is having issues here.

using ZygoteRules
# Pullback for the EnsembleSolution constructor: pass the incoming cotangent
# straight through to `sim`; `time` and `converged` get no gradient.
ZygoteRules.@adjoint EnsembleSolution(sim, time, converged) =
    EnsembleSolution(sim, time, converged), ȳ -> (ȳ, nothing, nothing)

gets you pretty far. The next issue you hit is that Zygote doesn't seem to work with @threads, so I just changed it to EnsembleSerial() to keep moving. Then, Zygote wasn't able to compile code with @warn (https://github.com/SciML/DiffEqBase.jl/blob/master/src/ensemble/basic_ensemble_solve.jl#L125-L152) because that has a try-catch in there, so I commented those out. That made it start calling the adjoint equation, which then errored for a reason I don't understand yet, but that's very close.

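So the way to solve this would be to fix:

1. Zygote not working with `@threads` (in `EnsembleThreads()`).
2. Zygote not being able to compile the `@warn` calls (which contain a `try`/`catch`) in the ensemble solve.
3. The missing adjoint for the `EnsembleSolution` constructor.
4. The error raised once the adjoint equation is called.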

The first two should get MWEs on Zygote. That adjoint should get added to DiffEq, and @dhairyagandhi96 I might want help debugging the last part. This is generally something that would be good to have, and the knock-on effects of fixing this case are likely very valuable (pmap/tmap adjoints are probably more broadly useful, so we should look at that first).

mkalia94 commented 4 years ago

I see. I'm afraid I'm presently unable to dig into the Zygote issue. I do, however, have a temporary workaround on the GPU that also trains the initial conditions. The RHS could be defined better, but gradients are being computed successfully:

using Pkg
Pkg.activate(".")
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations

# Training hyperparameters: `tsize` samples on [0, tspan], cut into `nbatches`
# windows of `nBatchsize` samples; each window carries a `statesize`-D state.
nBatchsize, nEpochs = 30, 10
tspan, tsize = 20.0, 300
nbatches = tsize ÷ nBatchsize
statesize = 2

# ODE solve
# Single 2-state Lotka–Volterra RHS: fills `du` from state u = (prey, predator)
# and parameters p[1]..p[4], then returns `du` itself.
function lotka_volterra_func!(du, u, p, t)
    a, b, c, d = p[1], p[2], p[3], p[4]
    prey, pred = u[1], u[2]
    du[1] = prey * (a - b * pred)
    du[2] = pred * (c * prey - d)
    return du
end

# Stacked Lotka–Volterra RHS. `u` holds `nbatches` independent systems back to
# back: block i occupies indices (i-1)*statesize+1 : i*statesize, and all blocks
# share the parameters `p`.
#
# Each block's derivative is written directly into a view of the matching
# slice of `du`. The previous version staged everything in a `Zygote.Buffer`,
# read uninitialised Buffer slices just to use them as scratch outputs, and
# then copied the Buffer back into `du`. `du` is already a mutable output
# array, so none of that is needed.
function lotka_volterra(du, u, p, t)
    for i in 1:nbatches
        block = (i - 1) * statesize + 1 : i * statesize
        lotka_volterra_func!(view(du, block), u[block], p, t)
    end
    return nothing
end

# Ground-truth parameters and initial condition used to generate the data.
p = Float32[2.2, 1.0, 2.0, 0.4]
u0 = Float32[0.01, 0.01]

t = range(0.0, tspan, length = tsize)

# Generate the reference trajectory once and keep a clean copy; solving the
# same problem a second time with identical settings was redundant.
prob = ODEProblem(lotka_volterra_func!, u0, (0.0, tspan), p)
yy = Array(solve(prob, saveat = t))
y_original = copy(yy)
yy = yy .+ yy * (0.01 .* rand(size(yy, 2), size(yy, 2))) # Creates noisy, translated data

data = Float32.(yy) |> gpu

# Autoencoder networks: 2-D state <-> 2-D latent.
NN_encode = Chain(Dense(2, 10), Dense(10, 10), Dense(10, 2)) |> gpu
NN_decode = Chain(Dense(2, 10), Dense(10, 10), Dense(10, 2)) |> gpu
# Trainable initial conditions for every window, stacked into one vector
# (window i occupies entries (i-1)*statesize+1 : i*statesize).
# NOTE(review): this is a Float64 CPU array while `p`/time span are Float32 and
# the data/networks are on the GPU — confirm the mix is intended.
u0_train = rand(statesize * nbatches)

# One stacked ODE problem that evolves all windows together.
t_batch = range(0.0f0, Float32(tspan / nbatches), length = nBatchsize)
prob2 = ODEProblem(lotka_volterra, u0_train, (0.0f0, Float32(tspan / nbatches)), p)

#ODE solve to be used for training
# Integrate the stacked problem `prob2` (initial state `u0_train`) with Tsit5,
# sampled at `t_batch`. Returns a dense Array: rows are the stacked state
# components, columns are time points.
function predict_ODE_solve()
    sol = solve(prob2, Tsit5(); saveat = t_batch, reltol = 1e-4)
    return Array(sol)
end

# Autoencoder + latent-ODE loss. All windows are integrated together in one
# stacked ODE whose initial conditions are the trainable `u0_train`. `enc_`
# does not seed the ODE; it enters the loss through `dec_2` and the
# latent-consistency term. The loss is also stored in the global `args_`.
function loss_func(data_)
    enc_ = NN_encode(data_)
    # Solve once and slice out each window's `statesize` rows. Calling
    # `predict_ODE_solve()` inside the comprehension would repeat the identical
    # full solve `nbatches` times.
    # Note: reduce(hcat,[..]) gives a mutating arrays error
    sol = predict_ODE_solve()
    windows = [sol[(i - 1) * statesize + 1 : i * statesize, :] for i in 1:nbatches]
    enc_ODE_solve = hcat(windows...) |> gpu
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_, dec_1) + Flux.mse(data_, dec_2) +
           0.001 * Flux.mse(enc_, enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

# Scratch storage: `loss_func` writes the latest loss under "loss" and the loop
# prints it. It was never defined in this script, so the first `loss_func`
# call threw an UndefVarError.
args_ = Dict{String,Any}()

opt = ADAM(0.001)

loss_func(data) # This works

# Train the encoder, the decoder and the stacked initial conditions `u0_train`.
for ep in 1:nEpochs
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode, NN_decode, u0_train), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end

My package list reads:

  [c7e460c6] ArgParse v1.1.0
  [fbb218c0] BSON v0.2.6
  [6e4b80f9] BenchmarkTools v0.5.0
  [3895d2a7] CUDAapi v4.0.0
  [c5f51814] CUDAdrv v6.3.0
  [be33ccc6] CUDAnative v3.1.0
  [3a865a2d] CuArrays v2.2.0
  [31a5f54b] Debugger v0.6.4
  [aae7a2af] DiffEqFlux v1.12.0
  [41bf760c] DiffEqSensitivity v6.19.0
  [0c46a032] DifferentialEquations v6.14.0
  [31c24e10] Distributions v0.23.3
  [5789e2e9] FileIO v1.3.0
  [587475ba] Flux v0.10.4
  [0c68f7d7] GPUArrays v3.4.1
  [033835bb] JLD2 v0.1.13
  [429524aa] Optim v0.21.0
  [1dea7af3] OrdinaryDiffEq v5.39.1
  [91a5bcdd] Plots v1.3.6
  [8d666b04] PolyChaos v0.2.1
  [ee283ea6] Rebugger v0.3.3
  [295af30f] Revise v2.7.1
  [9f7883ad] Tracker v0.2.6
  [e88e6eb3] Zygote v0.4.20
  [9a3f8284] Random 
ChrisRackauckas commented 4 years ago

hcat(...)

This can be written as reduce(hcat,[predict_ODE_solve()[(i-1)*statesize+1:i*statesize,:] for i in 1:nbatches])

ChrisRackauckas commented 4 years ago

Update: Zygote is now compatible with parallelism (https://github.com/FluxML/Zygote.jl/pull/728) so this should be possible now.

ChrisRackauckas commented 4 years ago

Same as user on Discourse: https://discourse.julialang.org/t/error-loaderror-need-an-adjoint-for-constructor-ensemblesolution/42611/7

ChrisRackauckas commented 4 years ago
using DifferentialEquations, Flux

# Global parameter array. The ODE right-hand side below reads pa[1] directly
# (it is not passed as the problem's `p`).
pa = [1.0]

# MWE model: builds and solves a 100-trajectory ensemble of
# du/dt = 1.01 * u * pa[1] on [0, 1], starting from u0 = 0.5.
# NOTE(review): `input` is never used, so the output does not depend on x.
# Kept deliberately as-is because this snippet exists to reproduce an AD failure.
function model1(input) 
  prob = ODEProblem((u, p, t) -> 1.01u * pa[1], 0.5, (0.0, 1.0))

  # Trajectory i starts from a random rescaling rand() * u0.
  function prob_func(prob, i, repeat)
    remake(prob, u0 = rand() * prob.u0)
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  # Threaded ensemble solve; the value of this last assignment is what the
  # function returns.
  sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 100)
end

# Dummy input; model1 ignores its argument.
Input_time_series = zeros(5, 100)
# loss function
# NOTE(review): model1 returns the ensemble solution object, so this is the MSE
# between that object and the scalar target `y` (0 below) — confirm intended.
loss(x, y) = Flux.mse(model1(x), y)

# A single (input, target) pair, repeated once per epoch.
data = Iterators.repeated((Input_time_series, 0), 1)

cb = function () # callback function to observe training
  println("Tracked Parameters: ", params(pa))
end

opt = ADAM(0.1)
println("Starting to train")
# Presumably the call that fails in this thread (missing adjoint for the
# EnsembleSolution constructor).
Flux.@epochs 10 Flux.train!(loss, params(pa), data, opt; cb = cb)

is a good MWE

ChrisRackauckas commented 4 years ago

@DhairyaLGandhi can I get help completing this? It just needs some tweaks to the adjoint definitions now:

# `Test` is needed for the `@test` at the end of this snippet. The trailing
# comma after `Flux` previously ran this line into the next `using`.
using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test

using ZygoteRules
# Work-in-progress pullbacks so Zygote can differentiate through ensemble solves.
# The constructor's `time` and `converged` arguments are not differentiable, so
# their cotangents are `nothing` rather than the numbers 0.0 / true.
ZygoteRules.@adjoint EnsembleSolution(sim,time,converged) =
    EnsembleSolution(sim,time,converged), p̄ -> (EnsembleSolution(p̄, 0.0, true), nothing, nothing)
# `literal_getproperty(sim, Val(:u))` takes two arguments, so its pullback
# returns two cotangents: one for `sim` and `nothing` for the `Val`.
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) =
    sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true), nothing)

# Trainable arrays: the ODE parameter and the shared base initial condition.
pa = [1.0]
u0 = [1.0]

# Solve a 100-trajectory serial ensemble of du/dt = 1.01 u .* p (p = pa) on
# [0, 1], saving every 0.1. Returns the vector of per-trajectory solutions.
function model1()
  prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
  # Trajectory i starts from a random rescaling of the shared u0.
  prob_func(prob, i, repeat) = remake(prob, u0 = rand() .* prob.u0)
  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  return solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end

# loss function
loss() = sum(abs2,1.0.-model1())

# Ten empty-argument steps per epoch (the loss takes no data).
data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
  @show loss()
end

opt = ADAM(0.05)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
# Training should reduce the loss. The previous `l1 < 10l2` only failed when
# the loss dropped more than 10x, so it never checked that training helped.
# The initial conditions are random, so use a plain decrease rather than a
# fixed factor.
@test l2 < l1

Then EnsembleDistributed() works because of https://github.com/FluxML/Zygote.jl/pull/728

ChrisRackauckas commented 4 years ago

https://github.com/SciML/DiffEqFlux.jl/issues/321 is somewhat related, since we really need adjoints of the solution types and the literal_getproperty calls.

This almost works if you comment out the warns (https://github.com/SciML/DiffEqBase.jl/blob/master/src/ensemble/basic_ensemble_solve.jl#L132) in there.

DhairyaLGandhi commented 4 years ago

The two adjoints defined in https://github.com/SciML/DiffEqFlux.jl/issues/279#issuecomment-658825798?

I can take a look in a bit; I'm a bit swamped at the moment.

Edit: at a quick glance, you probably want (EnsembleSolution(...), nothing) for the second adjoint?

ChrisRackauckas commented 4 years ago

alright thanks

ChrisRackauckas commented 4 years ago

https://github.com/SciML/DiffEqBase.jl/pull/557 fixes this. Final test:

# `Test` provides the `@test` assertions used later in this script.
using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test
pa = [1.0]
u0 = [3.0]
# Solve a 100-trajectory serial ensemble of du/dt = 1.01 u .* p (p = pa) on
# [0, 1], saving every 0.1, and return the ensemble solution.
function model1()
  prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)

  # Deterministic initial condition for trajectory i: 0.5 + (i/100) * u0.
  function prob_func(prob, i, repeat)
    remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100)
end

# Loss: squared distance of every saved state of every trajectory from 1.0.
loss() = sum(abs2, 1.0 .- Array(model1()))

# Ten empty-argument steps per epoch (the loss takes no data).
data = Iterators.repeated((), 10)

# Training callback: prints the current loss.
cb = () -> @show(loss())

opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa, u0]), data, opt; cb = cb)
l2 = loss()
# Training should cut the loss by more than an order of magnitude.
@test l1 > 10l2

# Same ensemble as `model1`, but returns the vector of per-trajectory
# solutions (`.u`) instead of the ensemble solution object.
function model2()
  prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
  prob_func(prob, i, repeat) = remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  return solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end
# Redefined loss over the per-trajectory solutions: sum of squares of each
# trajectory's squared error against 1.0.
loss() = sum(abs2, [sum(abs2, 1.0 .- u) for u in model2()])

# Fresh trainable arrays so this run starts from the same values as the first.
pa = [1.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa, u0]), data, opt; cb = cb)
l2 = loss()
@test l1 > 10l2