SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Performance tips for Neural SDEs #312

Open · metanoid opened this issue 4 years ago

metanoid commented 4 years ago

The docs for DifferentialEquations.jl include a nice FAQ section with performance tips here. Is there anything like that covering common performance pitfalls for sciml_train specifically?

I'm struggling to get good performance on the code below and I'm looking for resources that would help me understand what else to try.


using OrdinaryDiffEq
using Plots

using DiffEqFlux
using Flux
using Optim: LBFGS
using DifferentialEquations
using DiffEqSensitivity

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.5f0)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1], tspan[2], length = datasize)
trueprob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(trueprob, Tsit5(), saveat = t)) .+ (randn(2, datasize) .* 0.3) # noise added after solve

layers = FastChain(FastDense(2,50, tanh),FastDense(50,2))
p_layers = initial_params(layers)
θ = [u0; p_layers; rand(2)]

function dudt(u, p, t)
    return layers(u.^3, @view p[1:252])
end
function noisy(u,p,t)
    noise_params = @view p[253:254]
    return noise_params .* u .* 0.1
end

prob = SDEProblem(dudt, noisy, u0, tspan)

predict_n_ode(θ) = solve(prob, u0 = θ[1:2], p = θ[3:end], saveat=t, sensealg=BacksolveAdjoint())

function loss_n_ode(θ)
    return sum(abs2, ode_data .- predict_n_ode(θ))
end
loss_n_ode(θ)

cb = function (θ, loss)
    display(loss)
    # plot current prediction against data
    # pl = scatter(t, ode_data[1,:], label = "data")
    # scatter!(pl, t, pred[1,:], label = "prediction")
    # display(plot(pl))
    return false
end

cb(θ, loss_n_ode(θ)...)

data = Iterators.repeated((), 10)

# res1 = sciml_train(loss_n_ode, θ, ADAM(0.05), data; cb=cb, maxiters=2)
# cb(res1.minimizer, loss_n_ode(res1.minimizer)...; doplot=true)
# 
# res2 = sciml_train(loss_n_ode, res1.minimizer, LBFGS(); cb=cb)
# cb(res2.minimizer, loss_n_ode(res2.minimizer)...; doplot=true)

@time dudt(u0, θ, t) #  0.000018 seconds (52 allocations: 9.812 KiB)
@time noisy(u0, θ, t) #   0.000005 seconds (8 allocations: 368 bytes)
@time predict_n_ode(θ) #  0.009365 seconds (184.82 k allocations: 13.316 MiB)
@code_warntype predict_n_ode(θ)
@time loss_n_ode(θ) # 0.010592 seconds (182.07 k allocations: 13.118 MiB)
@time cb(θ, loss_n_ode(θ)...) #0.050798 seconds (212.00 k allocations: 14.597 MiB)
@time sciml_train(loss_n_ode, θ, ADAM(0.05), data; cb=cb, maxiters=2) #190.420638 seconds (1.40 G allocations: 108.593 GiB, 12.50% gc time)
ChrisRackauckas commented 4 years ago

You're using a setup which is still not documented. Adjoints of SDEs were just created and released like last week. We'll need to take another round of benchmarking to even know ourselves what the performance tips are. I think you're hitting a performance snag somewhere, but I don't know where it is, so it would be good to do some profiling here.
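A quick way to get that information (a minimal sketch using the standard-library Profile module; the loop count is arbitrary):

using Profile
loss_n_ode(θ)                  # run once first so compilation time is excluded
@profile for _ in 1:10
    loss_n_ode(θ)
end
Profile.print(mincount = 50)   # only show call sites hit at least 50 times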

Pinging @frankschae

frankschae commented 4 years ago

Hopefully I will have some more time next week to profile the performance. Looking at your code example, it could be that the bottleneck currently comes from the computation of the parameter Jacobian for the noise process. Eq. (14) in https://arxiv.org/pdf/2001.01328.pdf defines the adjoint for diagonal noise processes that is needed here. Of all 256 available parameters, only the last two are actually used in the noise term in your example. However, we currently compute the full vjp column by column (i.e., each column consists of 256 values, of which 254 could be fixed to zero in this example).
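To see this concretely on the example above, one can check the noise-parameter Jacobian directly (a sketch; ones(2) is just an arbitrary non-degenerate state):

using ForwardDiff
# Jacobian of the noise term with respect to the 254 parameters passed to solve (θ[3:end]):
Jp = ForwardDiff.jacobian(p -> noisy(ones(2), p, 0f0), θ[3:end])
size(Jp)             # (2, 254)
count(!iszero, Jp)   # 2: only the entries for the two noise parameters are nonzero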

ChrisRackauckas commented 4 years ago

Yes, we need to benchmark in comparison to reverse-mode AD approaches as well.

However, we currently compute the full vjp column by column (i.e., each column consists of 256 values, of which 254 could be fixed to zero in this example).

Indeed, I wonder if we could better utilize sparsity detection here.
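One possible direction, sketched by hand here rather than wired into the adjoint: hand the known sparsity pattern to SparseDiffTools.jl so the colored forward-mode Jacobian of the noise term needs a single dual sweep instead of one per parameter (ones(2) is again just an arbitrary evaluation state):

using SparseArrays, SparseDiffTools
# Structural pattern of d(noise)/dp for this example: rows 1 and 2 only hit columns 253 and 254.
pattern = sparse([1, 2], [253, 254], [1.0, 1.0], 2, 254)
colors  = matrix_colors(pattern)               # all 254 columns fit into a single color here
g!(out, p) = (out .= noisy(ones(2), p, 0f0))   # in-place wrapper around the noise term
J = similar(pattern)
forwarddiff_color_jacobian!(J, g!, θ[3:end], colorvec = colors)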

frankschae commented 4 years ago

I finally started to profile the performance. @ChrisRackauckas should the code and plots for this go directly in here? I first cleaned up a few Float64 vs. Float32 issues. For example, here:

function noisy(u,p,t)
    noise_params = @view p[253:254]
    return noise_params .* u .* 0.1
end

The * 0.1 is a problem (the 0.1 literal is a Float64 while everything else is Float32), and Tracker will even throw an error that it doesn't know how to do the conversion. Also, in the definition of the parameters, we should use:

θ = [u0; p_layers; rand(Float32,2)] 

to make the types consistent. That appeared to be the first performance killer.
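For the noise term itself, the corresponding fix is simply a Float32 literal (my sketch of the corrected function):

function noisy(u, p, t)
    noise_params = @view p[253:254]
    return noise_params .* u .* 0.1f0   # 0.1f0 keeps everything Float32, no promotion to Float64
end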

When using in-place functions, a fully non-allocating form in the spirit of https://github.com/SciML/DiffEqGPU.jl is more than 5 times faster on my machine:

true_p = Float32[-0.1,2.0,-2.0,-0.1,0.1,0.1]

function trueSDEfunc(du, u, p, t)
  @inbounds begin
    du[1] = p[1]*u[1]^3 + p[3]*u[2]^3
    du[2] = p[2]*u[1]^3 + p[4]*u[2]^3
  end
  nothing
end

function true_noise_func(du, u, p, t)
  @inbounds begin
    du[1] = p[5]*u[1]
    du[2] = p[6]*u[2]
  end
  nothing
end

From a first timing, it looks to me like the adjoint is already pretty fast (though I noticed that the gradients don't match perfectly).
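For reference, the timings below assume roughly the following setup on top of the code above (a sketch; dtsolver, tsteps, sde_data, and lossAD are my reconstructions, not part of the measurements themselves):

using BenchmarkTools, ForwardDiff, ReverseDiff, Zygote
dtsolver = 0.01f0
tsteps   = range(tspan[1], tspan[2], length = datasize)
# ground-truth data from the in-place true model above
sde_data = Array(solve(SDEProblem(trueSDEfunc, true_noise_func, u0, tspan, true_p),
                       EulerHeun(), dt = dtsolver, saveat = tsteps))
# loss of the (out-of-place) neural SDE from the original post against that data
lossAD(θ) = sum(abs2, sde_data .- Array(solve(prob, EulerHeun(), dt = dtsolver,
                                              u0 = θ[1:2], p = θ[3:end], saveat = tsteps)))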

@btime ForwardDiff.gradient(lossAD, θ)  # 123.140 ms (173165 allocations: 173.45 MiB)
@btime ReverseDiff.gradient(lossAD, θ)  # 46.798 ms (880397 allocations: 29.73 MiB)
@btime Zygote.gradient(
    θ -> sum(abs2, sde_data .- Array(solve(prob, EulerHeun(), dt = dtsolver, p = θ[3:end],
                                           saveat = tsteps, sensealg = TrackerAdjoint()))),
    θ)[1]  # 70.068 ms (276299 allocations: 11.52 MiB)
@btime Zygote.gradient(
    θ -> sum(abs2, sde_data .- Array(solve(prob, EulerHeun(), dt = dtsolver, u0 = θ[1:2], p = θ[3:end],
                                           saveat = tsteps, sensealg = BacksolveAdjoint()))),
    θ)[1]  # 19.875 ms (93166 allocations: 7.45 MiB)
ChrisRackauckas commented 4 years ago

@smitch151 follow this.

ChrisRackauckas commented 4 years ago

@frankschae looking great! I think for robustness the InterpolatingAdjoint version needs to be completed, since BacksolveAdjoint can always be a little shaky in terms of what it computes. But other than that, it looks to be the new best option, which is great.

What's the current bottleneck computation? Anything that can be done there?

When using in-place functions, a fully non-allocating form in the spirit of https://github.com/SciML/DiffEqGPU.jl is more than 5 times faster on my machine:

Note that Zygote VJPs will be much faster on neural networks in out-of-place form, so you'll want to do some timings on the neural SDE example as well; it should have different performance characteristics.
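Something along these lines would be the out-of-place neural SDE timing with a Zygote VJP (a sketch; the solver, dt, and autojacvec choice are my assumptions, not measured numbers):

using BenchmarkTools, Zygote, DiffEqSensitivity
@btime Zygote.gradient(
    θ -> sum(abs2, ode_data .- Array(solve(prob, EulerHeun(), dt = 0.01f0,
                                           u0 = θ[1:2], p = θ[3:end], saveat = t,
                                           sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())))),
    θ)[1]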

ChrisRackauckas commented 4 years ago

I finally started to profile the performance. @ChrisRackauckas should the code and plots for this go directly in here?

Yes, it would be nice to have a record of this.