Open colinxs opened 4 years ago
I investigated this a little, and I think this is mainly because of the current parameter interface. If you write the dense layer in a functional style:
using Zygote, Flux, BenchmarkTools
dense(f, W, x, b) = f.(W * x + b)
its generated pullback (the forward function) should be equivalent to the following (assuming we use tanh as the activation):
function forward_dense(::typeof(tanh), W, x, b)
    x1 = W * x .+ b
    x2 = tanh.(x1)
    x2, function (delta)
        Δ = @. delta * sech(x1)^2
        Δ * x', W' * Δ, Δ
    end
end
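For reference, here is a short sketch of where the three returned terms come from, using only the chain rule and the identity d/dz tanh(z) = sech²(z). The bar notation for gradients is introduced here for illustration and is not from the original comment; ȳ stands for the incoming gradient `delta`.

```latex
\begin{aligned}
x_1 &= W x + b, \qquad y = \tanh(x_1), \\
\Delta &= \bar{y} \odot \operatorname{sech}^2(x_1), \\
\bar{W} &= \Delta\, x^{\top}, \qquad
\bar{x} = W^{\top} \Delta, \qquad
\bar{b} = \Delta .
\end{aligned}
```

These correspond, in order, to `Δ * x'`, `W' * Δ`, and `Δ` in the closure above.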
Then the benchmark results for the generated pullback and the manual pullback are:
W = rand(10, 10)
x = rand(10)
b = rand(10)
grad = rand(10)
_, generated_back = Zygote.forward(dense, tanh, W, x, b)
_, manual_back = forward_dense(tanh, W, x, b)
@benchmark generated_back($grad)
@benchmark manual_back($grad)
generated:
BenchmarkTools.Trial:
memory estimate: 1.27 KiB
allocs estimate: 6
--------------
minimum time: 541.123 ns (0.00% GC)
median time: 580.636 ns (0.00% GC)
mean time: 683.967 ns (10.54% GC)
maximum time: 10.836 μs (86.50% GC)
--------------
samples: 10000
evals/sample: 187
manual (I didn't specialize broadcast, which is why it is a bit slower):
BenchmarkTools.Trial:
memory estimate: 1.27 KiB
allocs estimate: 6
--------------
minimum time: 719.571 ns (0.00% GC)
median time: 787.782 ns (0.00% GC)
mean time: 943.974 ns (12.49% GC)
maximum time: 468.792 μs (99.71% GC)
--------------
samples: 10000
evals/sample: 140
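As a sanity check that the hand-written pullback computes the right gradients (and so is a fair baseline for the timings), it can be compared against central finite differences. This is only a sketch: `loss` and `finite_diff` are hypothetical helpers introduced here, not part of Zygote or Flux, and `forward_dense` is restated so the snippet runs on its own without any packages.

```julia
# Hand-written pullback from the comment above, restated for self-containment.
function forward_dense(::typeof(tanh), W, x, b)
    x1 = W * x .+ b
    x2 = tanh.(x1)
    x2, function (delta)
        Δ = @. delta * sech(x1)^2
        Δ * x', W' * Δ, Δ
    end
end

# Hypothetical scalar loss: the output weighted by `delta`, so that its
# gradients equal what the pullback returns for the seed `delta`.
loss(W, x, b, delta) = sum(delta .* tanh.(W * x .+ b))

# Hypothetical helper: central finite differences of `f` w.r.t. array `A`.
function finite_diff(f, A; h = 1e-6)
    g = similar(A)
    for i in eachindex(A)
        Ap = copy(A); Ap[i] += h
        Am = copy(A); Am[i] -= h
        g[i] = (f(Ap) - f(Am)) / (2h)
    end
    g
end

W, x, b, delta = rand(10, 10), rand(10), rand(10), rand(10)
_, back = forward_dense(tanh, W, x, b)
dW, dx, db = back(delta)

# Compare each analytic gradient with its numerical estimate.
ok_W = isapprox(dW, finite_diff(A -> loss(A, x, b, delta), W); atol = 1e-6)
ok_x = isapprox(dx, finite_diff(v -> loss(W, v, b, delta), x); atol = 1e-6)
ok_b = isapprox(db, finite_diff(v -> loss(W, x, v, delta), b); atol = 1e-6)
```

If all three flags are true, the manual pullback agrees with the numerical gradients to within the chosen tolerance.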
But if you use Dense with Zygote:
m = Dense(10, 10, tanh)
ps = Flux.params(m)
_, flux_back = Zygote.forward(()->m(x), ps)
the generated pullback is much slower:
BenchmarkTools.Trial:
memory estimate: 3.58 KiB
allocs estimate: 54
--------------
minimum time: 4.282 μs (0.00% GC)
median time: 4.781 μs (0.00% GC)
mean time: 5.448 μs (4.98% GC)
maximum time: 359.259 μs (95.59% GC)
--------------
samples: 10000
evals/sample: 7
And let's see how Tracker does:
out = m(x)
@benchmark Tracker.back!($out, $grad, once=false)
BenchmarkTools.Trial:
memory estimate: 2.66 KiB
allocs estimate: 48
--------------
minimum time: 8.498 μs (0.00% GC)
median time: 9.199 μs (0.00% GC)
mean time: 10.034 μs (3.10% GC)
maximum time: 1.154 ms (98.22% GC)
--------------
samples: 10000
evals/sample: 3
Thus, in principle, Zygote should be 2x faster, but in practice we might need to wait for the new parameter interface. (XRef: https://github.com/FluxML/Flux.jl/issues/628)
Just following up: this appears to be fixed (using the Manifest.toml in Flux.jl):
dense(f, W, x, b) = f.(W * x .+ b)
function test(T)
    W = rand(T, 128, 128)
    x = rand(T, 128)
    b = rand(T, 128)
    m = Flux.Dense(W, b, tanh)
    ps = params(m)
    @btime gradient(sum ∘ dense, tanh, $W, $x, $b)
    @btime gradient(m -> sum(m($x)), $m)
    @btime gradient(() -> sum($m($x)), $ps)
end
julia> test(Float32)
10.527 μs (39 allocations: 67.83 KiB)
8.753 μs (45 allocations: 68.30 KiB)
9.826 μs (47 allocations: 68.39 KiB)
while running @btime gradient(() -> sum($m($x)), $ps) (also Float32) with Tracker gives:
34.528 μs (93 allocations: 136.14 KiB)
Zygote seems to be significantly slower than Tracker recently. Running test() in the following example (assuming I'm doing something flat-out wrong here):
yields this for Flux#zygote Zygote#master:
and for Flux@v0.8.3 (only for the calls to grad2):
Here's my system info: