Issue (status: open) — opened by jiweiqi 3 years ago
sensalg=QuadratureAdjoint()
julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial:
memory estimate: 309.96 MiB
allocs estimate: 5911182
--------------
minimum time: 169.876 ms (21.08% GC)
median time: 178.562 ms (21.55% GC)
mean time: 181.266 ms (23.14% GC)
maximum time: 205.677 ms (28.76% GC)
--------------
samples: 28
evals/sample: 1
sensalg=InterpolatingAdjoint()
julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial:
memory estimate: 309.96 MiB
allocs estimate: 5911181
--------------
minimum time: 168.827 ms (21.62% GC)
median time: 173.501 ms (22.09% GC)
mean time: 177.136 ms (23.68% GC)
maximum time: 193.983 ms (29.57% GC)
--------------
samples: 29
evals/sample: 1
sensalg=BacksolveAdjoint()
julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial:
memory estimate: 309.96 MiB
allocs estimate: 5911181
--------------
minimum time: 166.015 ms (21.90% GC)
median time: 168.504 ms (22.19% GC)
mean time: 173.059 ms (23.47% GC)
maximum time: 194.327 ms (19.96% GC)
--------------
samples: 29
evals/sample: 1
"""
    cellbox!(du, u, p, t)

In-place ODE right-hand side: `du .= tanh.(W * u .- μ) .- α .* u`.

- `p[:, 1]` is `α` (elementwise factor on `u` in the subtracted term).
- `p[:, 2:ns + 1]` is the matrix `W` multiplied with `u`.
- `t` is unused.

`ns` and `μ` are read from the enclosing global scope.
NOTE(review): non-`const` globals are untyped, which hurts performance; consider making
them `const` or passing them in via `p` / a closure.
"""
function cellbox!(du, u, p, t)
    α = view(p, :, 1)
    w = view(p, :, 2:(ns + 1))
    # A single fused broadcast: `w * u` is the only intermediate array allocated.
    # The previous partly dotted form also allocated temporaries for `- μ`, `tanh.`,
    # `α .* u`, and the outer `-`.
    du .= tanh.(w * u .- μ) .- α .* u
end
julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial:
memory estimate: 347.82 MiB
allocs estimate: 6643787
--------------
minimum time: 185.929 ms (20.90% GC)
median time: 191.027 ms (20.84% GC)
mean time: 197.185 ms (23.75% GC)
maximum time: 219.907 ms (26.68% GC)
--------------
samples: 26
evals/sample: 1
"""
    cellbox!(du, u, p, t)

Augmented-state variant of the ODE right-hand side. The state is
`u = [x; μ]` with `x = u[1:ns]` and `μ = u[ns + 1:2ns]`:

- `du[1:ns] .= tanh.(W * x .- μ) .- α .* x`
- `du[ns + 1:end] .= 0`, so the `μ` part of the state never changes.

`p[:, 1]` is `α` and `p[:, 2:ns + 1]` is `W`. `t` is unused. `ns` is read from the
enclosing global scope. The local `μ` deliberately shadows any global `μ`.
"""
function cellbox!(du, u, p, t)
    x = view(u, 1:ns)
    μ = view(u, (ns + 1):(2 * ns))  # μ carried as (constant) ODE state
    α = view(p, :, 1)
    w = view(p, :, 2:(ns + 1))
    # `du[1:ns] .= …` writes through a view of `du` (no copy). The RHS is fully fused,
    # so only `w * x` allocates.
    du[1:ns] .= tanh.(w * x .- μ) .- α .* x
    du[(ns + 1):end] .= 0
end
# Augmented initial state: `ns` zeros for `x`, followed by `μ`. The augmented `cellbox!`
# holds the `μ` part fixed via `du[ns + 1:end] .= 0`.
u0 = [zeros(ns); μ];
julia> @benchmark Zygote.gradient(x -> loss_neuralode(x), p)
BenchmarkTools.Trial:
memory estimate: 695.78 MiB
allocs estimate: 15997658
--------------
minimum time: 1.081 s (8.44% GC)
median time: 1.101 s (9.75% GC)
mean time: 1.099 s (9.60% GC)
maximum time: 1.113 s (10.28% GC)
--------------
samples: 5
evals/sample: 1