DesmondYuan / cellbox-julia-dev-6338

1 stars 1 forks source link

Profile of different AD method #1

Open jiweiqi opened 3 years ago

jiweiqi commented 3 years ago
using OrdinaryDiffEq, Flux, Optim, Random, Plots
using DiffEqSensitivity
using Zygote
using ForwardDiff
using LinearAlgebra, Statistics
using ProgressBars, Printf
using Flux.Optimise: update!, ExpDecay
using Flux.Losses: mae
using Distributions
using StatsBase
using BSON: @save, @load

Random.seed!(1234);

# Arguments
is_restart = false;
n_epoch = 10000;
n_plot = 10;  # frequency of callback
alg = Tsit5();

opt = ADAMW(5.f-3, (0.9, 0.999), 1.f-4);
ns = 100;  # number of nodes / species
tfinal = 10.0;
nsample = 10;  # number of samples for each perturbation

function gen_network(m; weight_params=(0., 1.), sparsity=0.)
    w = rand(Normal(weight_params[1], weight_params[2]), (m, m))
    p = [sparsity, 1 - sparsity]
    w .*= sample([0, 1], weights(p), (m, m), replace=true)
    α = abs.(rand(Normal(weight_params[1], weight_params[2]), (m)))
    return hcat(α, w)
end

function cellbox!(du, u, p, t)
    du .= tanh.(view(p, :, 2:ns + 1) * u - μ) - view(p, :, 1) .* u
end

u0 = zeros(ns);
tspan = (0, tfinal);
ts = 0:tspan[2] / nsample:tspan[2];
prob = ODEProblem(cellbox!, u0, tspan, saveat=ts);

p_gold = gen_network(ns; weight_params=(0.0, 1.0), sparsity=0.8);
p = gen_network(ns; weight_params=(0.0, 0.1), sparsity=0);

μ = rand(ns);
sol = solve(prob, alg, u0=u0, p=p_gold);
ode_data = Array(sol);

function predict_neuralode(u0, p)
    _prob = remake(prob, u0=u0, p=p)
    pred = Array(solve(_prob, alg, saveat=ts, sensalg=QuadratureAdjoint()))
    return pred
end
predict_neuralode(u0, p);

using BenchmarkTools

function loss_neuralode()
    pred = predict_neuralode(u0, p)
    loss = mae(ode_data, pred)
    return loss
end
loss_neuralode()

Zygote.gradient(x -> loss_neuralode(x), p)
@benchmark Zygote.gradient(x -> loss_neuralode(x), p)
jiweiqi commented 3 years ago

sensalg=QuadratureAdjoint()

 julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial: 
  memory estimate:  309.96 MiB
  allocs estimate:  5911182
  --------------
  minimum time:     169.876 ms (21.08% GC)
  median time:      178.562 ms (21.55% GC)
  mean time:        181.266 ms (23.14% GC)
  maximum time:     205.677 ms (28.76% GC)
  --------------
  samples:          28
  evals/sample:     1
jiweiqi commented 3 years ago

sensalg=InterpolatingAdjoint()

julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial: 
  memory estimate:  309.96 MiB
  allocs estimate:  5911181
  --------------
  minimum time:     168.827 ms (21.62% GC)
  median time:      173.501 ms (22.09% GC)
  mean time:        177.136 ms (23.68% GC)
  maximum time:     193.983 ms (29.57% GC)
  --------------
  samples:          29
  evals/sample:     1
jiweiqi commented 3 years ago

BacksolveAdjoint()

julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial: 
  memory estimate:  309.96 MiB
  allocs estimate:  5911181
  --------------
  minimum time:     166.015 ms (21.90% GC)
  median time:      168.504 ms (22.19% GC)
  mean time:        173.059 ms (23.47% GC)
  maximum time:     194.327 ms (19.96% GC)
  --------------
  samples:          29
  evals/sample:     1
jiweiqi commented 3 years ago
function cellbox!(du, u, p, t)
    α = view(p, :, 1)
    w = view(p, :, 2:ns + 1)
    du .= tanh.(w * u - μ) - α .* u
    # du .= tanh.(view(p, :, 2:ns + 1) * u - μ) - view(p, :, 1) .* u
end
julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial: 
  memory estimate:  347.82 MiB
  allocs estimate:  6643787
  --------------
  minimum time:     185.929 ms (20.90% GC)
  median time:      191.027 ms (20.84% GC)
  mean time:        197.185 ms (23.75% GC)
  maximum time:     219.907 ms (26.68% GC)
  --------------
  samples:          26
  evals/sample:     1
jiweiqi commented 3 years ago

function cellbox!(du, u, p, t)
    x = view(u, 1:ns)
    μ = view(u, (ns + 1):2 * ns)
    α = view(p, :, 1)
    w = view(p, :, 2:ns + 1)
    du[1:ns] .= tanh.(w * x - μ) - α .* x
    du[ns + 1:end] .= 0
    # du .= tanh.(view(p, :, 2:ns + 1) * u - μ) - view(p, :, 1) .* u
end

u0 = vcat(zeros(ns), μ);
julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial: 
  memory estimate:  695.78 MiB
  allocs estimate:  15997658
  --------------
  minimum time:     1.081 s (8.44% GC)
  median time:      1.101 s (9.75% GC)
  mean time:        1.099 s (9.60% GC)
  maximum time:     1.113 s (10.28% GC)
  --------------
  samples:          5
  evals/sample:     1