AzamatB opened 4 years ago
Comparing the full forward pass under the setup

```julia
using Flux                       # Flux.params, reset!, softmax, gradient
using Zygote: Buffer             # mutation-friendly array used inside decode
using OMEinsum: einsum, EinCode  # batched contractions used inside decode
using Statistics
using BenchmarkTools

# xs = first(Xs_train)
xs = last(Xs_train)
# xs = Xs_train[ceil(Int, (median∘eachindex)(Xs_train))]
Xs = vecofmats2tensor(xs)

encoder_dims = (
    blstm = (in = 39, out = 8),
    pblstms_out = (8, 8, 8)
)
attention_dim = 8
decoder_out_dim = 8
out_dim = 61
m = LAS(encoder_dims, attention_dim, decoder_out_dim, out_dim)
θ = Flux.params(m)

f(m, x) = m(x)
function gf(θ, m, x)
    gradient(θ) do
        sum(m(x))
        # (sum ∘ sum ∘ m)(x)
    end
end

reset!(m)
@benchmark f($m, $xs)
reset!(m)
@benchmark f($m, $Xs)
reset!(m)
@benchmark gf($θ, $m, $xs)
reset!(m)
@benchmark gf($θ, $m, $Xs)
```
with this version of decode

```julia
@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
    batch_size = size(Hs, 3)
    # initialize state for every sequence in the batch
    context    = repeat(m.state₀.context, 1, batch_size)
    decoding   = repeat(m.state₀.decoding, 1, batch_size)
    prediction = repeat(m.state₀.prediction, 1, batch_size)
    # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix,
    # passing it through the ψ dense layer in a single pass and reshaping the result back into a D′×T×B tensor
    ψHs = reshape(m.key_ψ(reshape(Hs, size(Hs,1), :)), size(m.key_ψ.W, 1), :, batch_size)
    # ψhs = m.key_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # check: all(ψhs .≈ eachslice(ψHs; dims=2))
    Ŷs = Buffer(Hs, size(prediction, 1), batch_size, maxT) # D×B×T output tensor
    @inbounds for t ∈ axes(Ŷs, 3) # time is the 3rd dimension of the D×B×T buffer
        # compute decoder state
        decoding = m.spell([decoding; prediction; context])::M
        # compute query ϕ(sᵢ)
        ϕsᵢ = m.query_ϕ(decoding)
        # compute energies via batched matrix multiplication
        # @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
        Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
        # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
        # compute attention weights
        αᵢs = softmax(Eᵢs)
        # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
        # @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
        context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
        # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] * Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs, 2)])
        # predict probability distribution over character alphabet
        Ŷs[:,:,t] = prediction = m.infer([decoding; context])
    end
    return copy(Ŷs)
end
```
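As a quick sanity check (echoing the commented checks in the code above), the reshape-based key precomputation and the two einsum contractions can be verified against naive per-slice loops on random data. A minimal sketch; the toy dimensions and the plain Dense layer standing in for m.key_ψ are hypothetical:

```julia
using Flux
using OMEinsum: einsum, EinCode

D, D′, T, B = 4, 3, 5, 2
Hs    = randn(Float32, D, T, B)
key_ψ = Dense(D, D′)  # stand-in for m.key_ψ

# keys: one big matmul over the D×TB unrolling vs. applying ψ slice by slice
ψHs = reshape(key_ψ(reshape(Hs, D, :)), D′, :, B)
@assert all(key_ψ(Hs[:,t,b]) ≈ ψHs[:,t,b] for t in 1:T, b in 1:B)

# energies: Eᵢs[t,b] = Σ_d ϕsᵢ[d,b] * ψHs[d,t,b]
ϕsᵢ = randn(Float32, D′, B)
Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))
@assert Eᵢs ≈ [sum(ϕsᵢ[:,b] .* ψHs[:,t,b]) for t in 1:T, b in 1:B]

# context: context[d,b] = Σ_t αᵢs[t,b] * Hs[d,t,b]
αᵢs = softmax(Eᵢs)
context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))
@assert context ≈ reduce(hcat, [sum(αᵢs[t,b] * Hs[:,t,b] for t in 1:T) for b in 1:B])
```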
The benchmark results for this version are:

```julia
julia> reset!(m)
julia> @benchmark f($m, $xs)
BenchmarkTools.Trial:
memory estimate: 111.88 MiB
allocs estimate: 68070
--------------
minimum time: 117.154 ms (0.00% GC)
median time: 131.959 ms (11.06% GC)
mean time: 131.458 ms (9.33% GC)
maximum time: 151.411 ms (15.55% GC)
--------------
samples: 39
evals/sample: 1
julia> reset!(m)
julia> @benchmark f($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 108.79 MiB
allocs estimate: 79693
--------------
minimum time: 121.245 ms (0.00% GC)
median time: 136.645 ms (9.86% GC)
mean time: 133.973 ms (7.76% GC)
maximum time: 146.218 ms (12.97% GC)
--------------
samples: 38
evals/sample: 1
julia> reset!(m)
julia> @benchmark gf($θ, $m, $xs)
BenchmarkTools.Trial:
memory estimate: 514.75 MiB
allocs estimate: 1468837
--------------
minimum time: 558.208 ms (9.87% GC)
median time: 608.333 ms (16.84% GC)
mean time: 630.647 ms (19.17% GC)
maximum time: 833.633 ms (32.63% GC)
--------------
samples: 8
evals/sample: 1
julia> reset!(m)
julia> @benchmark gf($θ, $m, $Xs)
BenchmarkTools.Trial:
memory estimate: 220.78 MiB
allocs estimate: 383739
--------------
minimum time: 227.210 ms (0.00% GC)
median time: 263.843 ms (14.47% GC)
mean time: 262.169 ms (13.96% GC)
maximum time: 276.629 ms (19.73% GC)
--------------
samples: 20
evals/sample: 1
```
while with the previous version of decode, which returns a vector of matrices instead of a 3D tensor:

```julia
@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
    batch_size = size(Hs, 3)
    # initialize state for every sequence in the batch
    context    = repeat(m.state₀.context, 1, batch_size)
    decoding   = repeat(m.state₀.decoding, 1, batch_size)
    prediction = repeat(m.state₀.prediction, 1, batch_size)
    # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix,
    # passing it through the ψ dense layer in a single pass and reshaping the result back into a D′×T×B tensor
    ψHs = reshape(m.key_ψ(reshape(Hs, size(Hs,1), :)), size(m.key_ψ.W, 1), :, batch_size)
    # ψhs = m.key_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # check: all(ψhs .≈ eachslice(ψHs; dims=2))
    ŷs = Buffer(Vector{M}(undef, maxT), false)
    @inbounds for t ∈ eachindex(ŷs)
        # compute decoder state
        decoding = m.spell([decoding; prediction; context])::M
        # compute query ϕ(sᵢ)
        ϕsᵢ = m.query_ϕ(decoding)
        # compute energies via batched matrix multiplication
        # @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
        Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
        # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
        # compute attention weights
        αᵢs = softmax(Eᵢs)
        # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
        # @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
        context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
        # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] * Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs, 2)])
        # predict probability distribution over character alphabet
        ŷs[t] = prediction = m.infer([decoding; context])
    end
    return copy(ŷs)
end
```
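As an aside, both versions write their outputs into a Zygote Buffer so that the loop can mutate storage under reverse-mode AD, with copy at the end yielding an ordinary array that Zygote tracks. A minimal sketch of that pattern, unrelated to LAS itself (cumprods is a toy function):

```julia
using Zygote

# Buffer permits in-place writes inside a differentiated function;
# copy(::Buffer) freezes it back into a regular, differentiable array.
function cumprods(x)
    buf = Zygote.Buffer(x)  # same shape and eltype as x, mutable under AD
    acc = one(eltype(x))
    for i in eachindex(x)
        acc *= x[i]
        buf[i] = acc
    end
    return copy(buf)
end

gradient(x -> sum(cumprods(x)), [1.0, 2.0, 3.0])  # mutating a plain Array here would error
```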
For this vec-of-mats version, the results are:

```julia
julia> reset!(m)
julia> @benchmark f($m, $xs)
BenchmarkTools.Trial:
memory estimate: 487.19 MiB
allocs estimate: 458671
--------------
minimum time: 565.769 ms (6.92% GC)
median time: 583.121 ms (10.07% GC)
mean time: 585.167 ms (10.16% GC)
maximum time: 602.533 ms (10.40% GC)
--------------
samples: 9
evals/sample: 1
julia> reset!(m)
julia> @benchmark f($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 484.09 MiB
allocs estimate: 470294
--------------
minimum time: 564.175 ms (5.48% GC)
median time: 601.307 ms (10.57% GC)
mean time: 601.711 ms (10.06% GC)
maximum time: 624.255 ms (11.36% GC)
--------------
samples: 9
evals/sample: 1
julia> reset!(m)
julia> @benchmark gf($θ, $m, $xs)
BenchmarkTools.Trial:
memory estimate: 2.25 GiB
allocs estimate: 3044982
--------------
minimum time: 2.773 s (19.52% GC)
median time: 2.876 s (20.46% GC)
mean time: 2.876 s (20.46% GC)
maximum time: 2.979 s (21.34% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m)
julia> @benchmark gf($θ, $m, $Xs)
BenchmarkTools.Trial:
memory estimate: 1.96 GiB
allocs estimate: 1959884
--------------
minimum time: 2.417 s (20.49% GC)
median time: 2.510 s (21.41% GC)
mean time: 2.510 s (21.41% GC)
maximum time: 2.603 s (22.27% GC)
--------------
samples: 2
evals/sample: 1
```
So the new version is about 11× faster on the gradient pass for the 3D tensor input (minimum time 2.417 s → 227.2 ms) and about 5× faster for the vec-of-mats input (2.773 s → 558.2 ms); the forward pass alone speeds up by roughly 4.7× in both cases.
Setup, versions considered, and results (collapsed details omitted).
Conclusion: the `decode1` version significantly (5×) outperforms the other two, owing to the cache-friendly `D×B×T` ordering of dimensions in its output tensor. Although it would be nice to keep the output tensor's dimension ordering consistent with the input tensor's `D×T×B`, the performance benefit outweighs that, so `decode1` is adopted.
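The cache-ordering argument can be illustrated with a toy benchmark (the dimensions and the `fill_*!` helpers below are hypothetical): Julia arrays are column-major, so writing one decoding step into `Ŷs[:,:,t]` of a `D×B×T` tensor fills a single contiguous block, while writing it into `Ŷs[:,t,:]` of a `D×T×B` tensor scatters it across `B` strided columns.

```julia
using BenchmarkTools

D, T, B = 8, 500, 32
step = rand(Float32, D, B)  # stand-in for one decoding step's output

# D×B×T layout: each step occupies one contiguous D*B block
function fill_dbt!(out, step)
    for t in axes(out, 3)
        out[:, :, t] = step
    end
    return out
end

# D×T×B layout: each step is split into B columns strided D*T apart
function fill_dtb!(out, step)
    for t in axes(out, 2)
        out[:, t, :] = step
    end
    return out
end

out_dbt = Array{Float32}(undef, D, B, T)
out_dtb = Array{Float32}(undef, D, T, B)
@btime fill_dbt!($out_dbt, $step)
@btime fill_dtb!($out_dtb, $step)
```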