SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Mixed CPU+GPU adjoints #401

Closed ChrisRackauckas closed 3 years ago

ChrisRackauckas commented 4 years ago

MWE:

using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, CUDA, DiffEqSensitivity, Plots

u0 = [1.1; 1.1] |> gpu
tspan = (0.0f0,25.0f0)

ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1)) |> gpu
p1 = initial_params(ann) |> gpu
p2 = Float32[0.5,-0.5]
p3 = [p1;p2]
θ = Float32[u0;p3]

function dudt_(u,p,t)
    x, y = u
    [cpu(ann(gpu(u),p[1:length(p1)]))[1],p[end-1]*y + p[end]*x]
end
prob = ODEProblem{false}(dudt_,u0,tspan,p3)

function predict_adjoint(θ)
  gpu(Array(solve(prob,Tsit5(),u0=cpu(θ[1:2]),p=θ[3:end],saveat=0.0:1:25.0,sensealg=InterpolatingAdjoint())))
end
loss_adjoint(θ) = sum(abs2,predict_adjoint(θ)[2,:].-1)
l = loss_adjoint(θ)

cb = function (θ,l)
  println(l)
  #display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6)))
  return false
end

loss1 = loss_adjoint(θ)
res = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(), cb = cb, maxiters=10)

This is for the case of heavy scalar nonlinear code combined with a neural network. We'll need to figure out how to handle the backwards pass effectively.
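
For reference, when the scalar part is as simple as it is here, one can keep the whole right-hand side on the GPU and avoid the mixed case entirely. A sketch under that assumption, reusing ann, p1, p2, and tspan from the MWE above; the general mixed case is the one that still needs adjoint support:

# Sketch: fully-GPU RHS for the same system, assuming the scalar terms can be
# written with non-scalar array indexing and broadcasts. This sidesteps the
# cpu()/gpu() transfers inside the RHS, but only works when the scalar code is this simple.
function dudt_gpu(u,p,t)
  nn  = ann(u, p[1:length(p1)])                            # 1-element CuArray
  lin = p[end-1:end-1] .* u[2:2] .+ p[end:end] .* u[1:1]   # avoid scalar getindex
  vcat(nn, lin)
end
p_gpu = vcat(p1, gpu(p2))
prob_gpu_only = ODEProblem{false}(dudt_gpu, gpu(Float32[1.1,1.1]), tspan, p_gpu)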

ChrisRackauckas commented 3 years ago

Found another case:

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, CUDA, DiffEqSensitivity, Plots, Test
CUDA.allowscalar(false) # Makes sure no slow operations are occurring

# generating exogenous signal and output signal
tspan = (0.1f0, Float32(10.0))
tsteps = range(tspan[1], tspan[2], length = 100)
t_vec = collect(tsteps)
ex = vec(ones(Float32,length(tsteps), 1))
f(x) = (atan(8.0 * x - 4.0) + atan(4.0)) / (2.0 * atan(4.0))

function hammerstein_system(u)
    y = zeros(size(u))
    for k in 2:length(u)
        y[k] = 0.2 * f(u[k-1]) + 0.8 * y[k-1]
    end
    return y
end

ex = vec([ones(Float32,50,1) 2*ones(Float32,50,1)]) # exogenous signal
ex = ex'
ode_data = gpu(Float32.(hammerstein_system(ex))) # signal we want to predict

#Define the ode layer
nn_dudt = FastChain(FastDense(2, 8, tanh),FastDense(8, 1))
u0 = Float32[0.0] |> gpu
p = initial_params(nn_dudt) |> gpu

function dudt2(u,p,t,ex)
  nn_dudt(vcat(u,ex[Int(round(t*10))]), p)
end

@test vcat(u0,ex[Int(round(1.0*10))]) isa CuArray

_dudt2(u,p,t) = dudt2(u,p,t,ex)
prob_gpu = ODEProblem(_dudt2, u0, tspan, nothing)

# Runs on a GPU
function predict_neuralode(p)
  _prob_gpu = remake(prob_gpu,p=p)
  gpu(solve(_prob_gpu, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-6))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    N = length(pred)
    l = sum(abs2, ode_data[1:N]' .- pred)/N
    return l, pred
end
res0 = DiffEqFlux.sciml_train(loss_neuralode, p, ADAM(0.01), maxiters=10)

res1 = DiffEqFlux.sciml_train(loss_neuralode, res0.minimizer, ADAM(0.01), maxiters=20)

sol = predict_neuralode(res0.minimizer)
sol = Array(sol)
plot(sol')
plot!(ode_data')

DhairyaLGandhi commented 3 years ago

Simple MWE:

julia> r = rand(Float32, 3) |> gpu;

julia> gradient((x,y) -> sum(vcat(x,y)), r, 5)
ERROR: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/dhairyalgandhi/.julia/packages/GPUArrays/WV76E/src/host/indexing.jl:41
 [3] getindex at /home/dhairyalgandhi/.julia/packages/GPUArrays/WV76E/src/host/indexing.jl:96 [inlined]
 [4] pull_block_vert at /home/dhairyalgandhi/Zygote.jl/src/lib/array.jl:100 [inlined]
 [5] #437 at /home/dhairyalgandhi/Zygote.jl/src/lib/array.jl:105 [inlined]
 [6] iterate at ./generator.jl:47 [inlined]
 [7] collect_to!(::Array{CuArray{Float32,1},1}, ::Base.Generator{Base.OneTo{Int64},Zygote.var"#437#439"{CuArray{Float32,1},Tuple{CuArray{Float32,1},Int64},Array{Int64,1}}}, ::Int64, ::Int64) at ./array.jl:732
 [8] collect_to_with_first! at ./array.jl:710 [inlined]
 [9] _collect at ./array.jl:704 [inlined]
 [10] collect_similar at ./array.jl:628 [inlined]
 [11] map at ./abstractarray.jl:2162 [inlined]
 [12] #436 at /home/dhairyalgandhi/Zygote.jl/src/lib/array.jl:105 [inlined]
 [13] #2425#back at /home/dhairyalgandhi/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [14] #11 at ./REPL[29]:1 [inlined]
 [15] (::Zygote.var"#41#42"{typeof(∂(#11))})(::Float32) at /home/dhairyalgandhi/Zygote.jl/src/compiler/interface.jl:40
 [16] gradient(::Function, ::CuArray{Float32,1}, ::Vararg{Any,N} where N) at /home/dhairyalgandhi/Zygote.jl/src/compiler/interface.jl:49
 [17] top-level scope at REPL[29]:1

This makes sense, since we would be doing scalar indexing to extract the partial for the scalar argument. It seems like that should be alright, but CUDA gets mad at pulling out a single number, since vcat(::CuArray, ::Number)::CuArray.

Writing a non-scalar indexing version would look a bit weird and might not be any more performant. Maybe we could have better handling for wrapping things in CUDA.@allowscalar?
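
For example (purely a sketch of that idea), the allowscalar handling could go directly into the pullback for the scalar block:

# Hypothetical @allowscalar variant: explicitly permit the single scalar read
# the pullback needs for the Number argument of vcat.
Zygote.pull_block_vert(sz, Δ::CuArray, A::Number) = CUDA.@allowscalar Δ[sz]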

Or, the non-scalar indexing version directly:

Zygote.pull_block_vert(sz, Δ::CuArray, A::Number) = sum(Δ[sz:sz])
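
With either definition in place, the simple MWE above should no longer hit scalar getindex (a sketch of the expected call, not verified against a particular Zygote version):

r = rand(Float32, 3) |> gpu
gradient((x,y) -> sum(vcat(x,y)), r, 5)   # expected: (3-element CuArray of ones, 1.0)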