Results for computing the energy vector by taking the diagonal of the product of the query and keys matrices (`f0`) vs. their column-wise dot products (`f1`):
function f0(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs,3)
    # compute input encoding, which are also values for the attention layer
    Hs = m.listen(Xs)
    # precompute keys ψ(H)
    ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # compute initial decoder state for a batch
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
    ŷs = broadcast(1:maxT) do _
        # compute query ϕ(sᵢ)
        ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
        # compute energies: diagonal of the query and keys product, one length-B energy vector per encoder time step
        Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
        # compute attention weights
        # αᵢs = softmax(hcat(Eᵢs...); dims=2)
        αᵢs = softmax(hcat(Eᵢs...)')
        # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
        # αᵢs = softmax(reduce(hcat, Eᵢs)')
        # αᵢs = softmax(vcat(Eᵢs'...))
        # αᵢs = softmax(reduce(vcat, Eᵢs'))
        # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
        # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
        m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
        # predict probability distribution over character alphabet
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # compute next decoder state
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end
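`f0` computes each energy vector as the diagonal of the query and keys product. A minimal standalone sketch of that step and of its column-wise equivalent used by `f1` below (sizes and names here are made up, independent of the model code):

using LinearAlgebra

D, B = 4, 3
ϕs = rand(Float32, D, B)   # one query column per batch element
ψh = rand(Float32, D, B)   # one key column per batch element

E_diag = diag(ϕs' * ψh)                                        # f0-style: one matrix product, then the diagonal
E_dots = [dot(view(ϕs, :, b), view(ψh, :, b)) for b in 1:B]    # f1-style: B small dot products

E_diag ≈ E_dots   # true: the two formulations agree in the forward pass

The difference is purely in how the work is organized: the diagonal version pays for a full B × B product but in a single fused call, while the dot-product version does only the B needed reductions but as many small operations.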
function f1(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs,3)
    batch_axis = axes(Xs,3)
    # compute input encoding, which are also values for the attention layer
    Hs = m.listen(Xs)
    # precompute keys ψ(H)
    ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # compute initial decoder state for a batch
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
    ŷs = broadcast(1:maxT) do _
        # compute query ϕ(sᵢ)
        ϕsᵢ = view.(Ref(m.attention_ϕ(m.state.decoding)), :, batch_axis)
        # compute energies: column-wise dot products of query and key columns, one length-B vector per encoder time step
        Eᵢs = (ψh -> ϕsᵢ .⋅ view.(Ref(ψh), :, batch_axis)).(ψhs)
        # compute attention weights
        # αᵢs = softmax(hcat(Eᵢs...); dims=2)
        αᵢs = softmax(hcat(Eᵢs...)')
        # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
        # αᵢs = softmax(reduce(hcat, Eᵢs)')
        # αᵢs = softmax(vcat(Eᵢs'...))
        # αᵢs = softmax(reduce(vcat, Eᵢs'))
        # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
        # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
        m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
        # predict probability distribution over character alphabet
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # compute next decoder state
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end
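Not benchmarked here, but worth noting as a sketch: the per-key energy vector can also be written as an elementwise product followed by a column sum, which avoids both the B × B intermediate of `f0` and the per-column `dot` calls of `f1` (the helper name `column_energies` is made up):

# hypothetical alternative energy step: for a query matrix ϕsᵢ and a key matrix ψh,
# both D × B, this yields the same length-B vector as diag(ϕsᵢ' * ψh)
column_energies(ϕsᵢ, ψh) = vec(sum(ϕsᵢ .* ψh; dims=1))
# Eᵢs = column_energies.(Ref(m.attention_ϕ(m.state.decoding)), ψhs)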
function gf0(m, Xs, θ)
    gradient(θ) do
        sum(sum(f0(m, Xs)))
    end
end

function gf1(m, Xs, θ)
    gradient(θ) do
        sum(sum(f1(m, Xs)))
    end
end
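A usage sketch, under the assumption that `θ` is the model's implicit parameter collection (the original does not show how `θ` was built; `Flux.params` is assumed here):

using Flux   # assumed; not shown in the original

θ = Flux.params(m)   # implicit parameter collection of the model
gs = gf0(m, Xs, θ)   # gs[p] holds the gradient array for each parameter p ∈ θ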
Results for `xs = first(Xs_train); Xs = vecofmats2tensor(xs)`:
julia> reset!(m)
julia> @benchmark f0($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 1.72 GiB
allocs estimate: 54890
--------------
minimum time: 2.461 s (7.70% GC)
median time: 2.514 s (8.95% GC)
mean time: 2.511 s (8.64% GC)
maximum time: 2.556 s (9.25% GC)
--------------
samples: 3
evals/sample: 1
julia> reset!(m)
julia> @benchmark f1($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 1.40 GiB
allocs estimate: 643228
--------------
minimum time: 2.222 s (7.58% GC)
median time: 2.251 s (8.09% GC)
mean time: 2.244 s (7.92% GC)
maximum time: 2.261 s (8.09% GC)
--------------
samples: 3
evals/sample: 1
julia> reset!(m)
julia> @benchmark gf0($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 5.75 GiB
allocs estimate: 362721
--------------
minimum time: 8.926 s (34.64% GC)
median time: 8.926 s (34.64% GC)
mean time: 8.926 s (34.64% GC)
maximum time: 8.926 s (34.64% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m)
julia> @benchmark gf1($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 97.39 GiB
allocs estimate: 18548950
--------------
minimum time: 95.674 s (28.28% GC)
median time: 95.674 s (28.28% GC)
mean time: 95.674 s (28.28% GC)
maximum time: 95.674 s (28.28% GC)
--------------
samples: 1
evals/sample: 1
Changing from `view`s to `getindex`es in the energy computation step only increased the total time, to 110.982 seconds.
Conclusion: column-wise dot products (`f1`), whether via `view` or via `getindex`, are much (more than 10x) slower (reported here) than taking the diagonal of the product of the query and keys matrices (`f0`) when it comes to computing gradients.
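A likely reason for the gap: the diagonal formulation differentiates a single fused matrix product per encoder time step, whereas the column-wise formulation creates a separate pullback for every small `dot` and `view`/`getindex` call. A toy sketch of the two gradient paths (sizes made up, unrelated to the benchmarks above):

using Zygote, LinearAlgebra

A = rand(Float32, 64, 32)   # stand-in for the query matrix
K = rand(Float32, 64, 32)   # stand-in for one key matrix

# f0-style path: one matrix product plus diag, a handful of pullbacks in total
grad_diag = gradient(a -> sum(diag(a' * K)), A)[1]
# f1-style path: one small pullback per column slice and per dot product
grad_dots = gradient(a -> sum(map(b -> dot(a[:, b], K[:, b]), 1:size(a, 2))), A)[1]

grad_diag ≈ grad_dots   # true: same gradient, very different amounts of AD bookkeeping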
Changing `ŷs = map(1:maxT) do _` to `ŷs = broadcast(1:maxT) do _` did not yield any performance gain.
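This is expected: for a `UnitRange` argument, `map` and `broadcast` build the same `Vector`, e.g.:

map(x -> x^2, 1:3) == broadcast(x -> x^2, 1:3)   # true: both return [1, 4, 9]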
Changing `αᵢs = softmax(hcat(Eᵢs...)')` to `αᵢs = softmax(stack(Eᵢs)')` slightly reduces allocations, so the latter was given preference.
`αᵢs = softmax(reduce(hcat, Eᵢs)')` throws `ERROR: Can't differentiate gc_preserve_end expression`.
`αᵢs = softmax(reduce(vcat, Eᵢs'))` throws `ERROR: Can't differentiate loopinfo expression`.
Also, with the `D × B × T` ordering of the input dimensions:
`αᵢs = softmax(stack(Eᵢs); dims=2)` slightly increases the runtime.
`αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)` still throws `ERROR: Can't differentiate gc_preserve_end expression`.
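For reference, a standalone sketch of this attention-weight step (sizes made up; `softmax` comes from NNlib and is re-exported by Flux): the `T` length-`B` energy vectors are stacked into a `T × B` matrix and each column is normalized over the `T` encoder time steps.

using NNlib: softmax

T, B = 5, 3
Eᵢs = [rand(Float32, B) for _ in 1:T]   # one energy vector per encoder time step
αᵢs = softmax(hcat(Eᵢs...)')            # T × B; softmax is taken along dims=1, i.e. over time
sum(αᵢs; dims=1)                        # ≈ ones(1, B)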
Implementations of 6 different versions of the forward pass, one for each permutation of `D` (input dimension), `T` (time duration) and `B` (batch size):

A smallish neural net with the following dimensions was used.
Results for `xs = last(Xs_train); Xs = vecofmats2tensor(xs)`:

Results for `xs = first(Xs_train); Xs = vecofmats2tensor(xs)`:

Conclusion: the `D × T × B` and `D × B × T` orderings seem to be the most efficient ones, although there is not much difference between any of the versions, and as `T` grows all computations become dominated by garbage collection and the speed differences almost vanish.
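For concreteness, assuming the `D × T × B` layout used by `f0`/`f1` as the baseline, the other orderings can be produced with `permutedims`; a sketch with made-up sizes:

Xs_DTB = rand(Float32, 8, 10, 4)           # D × T × B baseline
Xs_DBT = permutedims(Xs_DTB, (1, 3, 2))    # D × B × T
Xs_TDB = permutedims(Xs_DTB, (2, 1, 3))    # T × D × B

Note that `permutedims` copies, so in practice the layout would be chosen once when the batch tensor is built rather than converted per call.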