AzamatB / ListenAttendSpell.jl

Julia implementation of Listen, Attend and Spell model with Flux.jl
MIT License

Experiment with the ordering of dimensions of the output tensor of the decode function #8


AzamatB commented 4 years ago

Setup

using Statistics
using BenchmarkTools
# plus the repo environment: Flux, Zygote (gradient, Buffer), OMEinsum (einsum, EinCode),
# and the LAS model & Xs_train data definitions

# xs = first(Xs_train)
# xs = last(Xs_train)
xs = Xs_train[ceil(Int, (median∘eachindex)(Xs_train))]
Xs = vecofmats2tensor(xs)

encoder_dims = (
   blstm       = (in = 39, out = 8),
   pblstms_out = (8, 8, 8)
)
attention_dim = 8
decoder_out_dim = 8
out_dim = 61
m = LAS(encoder_dims, attention_dim, decoder_out_dim, out_dim)
θ = Flux.params((m.state₀, m.listen, m.key_ψ, m.query_ϕ, m.spell, m.infer))

Hs = m.listen(Xs)
maxT = size(Xs,2)
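
For reference, `vecofmats2tensor` packs the batch of variable-length D×Tᵢ feature matrices into a single padded D×T×B tensor. The actual helper lives in the repo; a minimal sketch of the assumed semantics (the name suffix, zero-padding, and signature here are illustrative):

# Hypothetical sketch: pack a vector of D×Tᵢ matrices into a zero-padded
# D×T×B tensor with T = maximum(Tᵢ); the real helper is defined in the repo.
function vecofmats2tensor_sketch(xs::AbstractVector{<:AbstractMatrix{T}}) where T <: Real
   D = size(first(xs), 1)
   maxT = maximum(size.(xs, 2))
   Xs = zeros(T, D, maxT, length(xs))
   for (b, x) ∈ enumerate(xs)
      Xs[:, axes(x, 2), b] = x   # zero-pad shorter sequences along the time axis
   end
   return Xs
end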

Versions considered

@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
   batch_size = size(Hs, 3)
   # initialize state for every sequence in a batch
   context    = repeat(m.state₀.context,    1, batch_size)
   decoding   = repeat(m.state₀.decoding,   1, batch_size)
   prediction = repeat(m.state₀.prediction, 1, batch_size)
   # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
   # passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
   ψHs = reshape(m.key_ψ(reshape(Hs, size(Hs,1), :)), size(m.key_ψ.W, 1), :, batch_size)
   # ψhs = m.key_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # check: all(ψhs .≈ eachslice(ψHs; dims=2))
   ŷs = Buffer(Vector{M}(undef, maxT), false)
   @inbounds for t ∈ eachindex(ŷs)
      # compute decoder state
      decoding = m.spell([decoding; prediction; context])::M
      # compute query ϕ(sᵢ)
      ϕsᵢ = m.query_ϕ(decoding)
      # compute energies via batch matrix multiplication
      # @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
      Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
      # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
      # compute attention weights
      αᵢs = softmax(Eᵢs)
      # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
      context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
      # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
      # predict probability distribution over character alphabet
      ŷs[t] = prediction = m.infer([decoding; context])
   end
   return copy(ŷs)
end
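
All three versions rely on the same key-precomputation trick: a Dense layer acts columnwise, so folding the time axis into the batch axis, applying ψ once, and unfolding back is equivalent to applying ψ to each D×B time slice, which is exactly what the commented-out check asserts. A self-contained illustration with a throwaway Dense layer (names and shapes here are made up for the demo):

using Flux

ψ = Dense(8, 4)                    # stand-in for m.key_ψ
H = randn(Float32, 8, 10, 3)       # D×T×B
# fold T into the batch axis, apply ψ in one pass, unfold back to D′×T×B
ψH = reshape(ψ(reshape(H, size(H, 1), :)), 4, :, size(H, 3))
# slice-by-slice reference computation, as in the commented check above
ψh = ψ.(getindex.(Ref(H), :, axes(H, 2), :))
@assert all(ψh .≈ eachslice(ψH; dims=2))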

@inline function decode0(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
   batch_size = size(Hs, 3)
   # initialize state for every sequence in a batch
   context    = repeat(m.state₀.context,    1, batch_size)
   decoding   = repeat(m.state₀.decoding,   1, batch_size)
   prediction = repeat(m.state₀.prediction, 1, batch_size)
   # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
   # passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
   ψHs = reshape(m.key_ψ(reshape(Hs, size(Hs,1), :)), size(m.key_ψ.W, 1), :, batch_size)
   # ψhs = m.key_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # check: all(ψhs .≈ eachslice(ψHs; dims=2))
   Ŷs = Buffer(Hs, size(prediction, 1), maxT, batch_size) # D×T×B output tensor
   @inbounds for t ∈ axes(Ŷs, 2)
      # compute decoder state
      decoding = m.spell([decoding; prediction; context])::M
      # compute query ϕ(sᵢ)
      ϕsᵢ = m.query_ϕ(decoding)
      # compute energies via batch matrix multiplication
      # @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
      Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
      # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
      # compute attention weights
      αᵢs = softmax(Eᵢs)
      # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
      context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
      # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
      # predict probability distribution over character alphabet
      Ŷs[:,t,:] = prediction = m.infer([decoding; context])
   end
   return copy(Ŷs)
end
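
The two einsum calls are batched contractions over the shared batch index, spelled out by the commented checks: Eᵢs[t,b] = Σ_d ϕsᵢ[d,b]·ψHs[d,t,b] and context[d,b] = Σ_t αᵢs[t,b]·Hs[d,t,b]. A standalone illustration with toy shapes (names here are illustrative):

using LinearAlgebra: ⋅
using NNlib: softmax
using OMEinsum: einsum, EinCode

ϕ  = randn(Float32, 4, 3)       # query ϕ(sᵢ): d×B
ψH = randn(Float32, 4, 10, 3)   # keys  ψ(H):  d×T×B
H  = randn(Float32, 8, 10, 3)   # annotations: D×T×B
# energies: E[t,b] = Σ_d ϕ[d,b] * ψH[d,t,b]
E = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕ, ψH))
@assert E ≈ [ϕ[:,b] ⋅ ψH[:,t,b] for t ∈ 1:10, b ∈ 1:3]
# attended context: c[d,b] = Σ_t α[t,b] * H[d,t,b]
α = softmax(E)
c = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (α, H))
@assert c ≈ reduce(hcat, [sum(α[t,b] * H[:,t,b] for t ∈ 1:10) for b ∈ 1:3])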

@inline function decode1(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
   batch_size = size(Hs, 3)
   # initialize state for every sequence in a batch
   context    = repeat(m.state₀.context,    1, batch_size)
   decoding   = repeat(m.state₀.decoding,   1, batch_size)
   prediction = repeat(m.state₀.prediction, 1, batch_size)
   # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
   # passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
   ψHs = reshape(m.key_ψ(reshape(Hs, size(Hs,1), :)), size(m.key_ψ.W, 1), :, batch_size)
   # ψhs = m.key_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # check: all(ψhs .≈ eachslice(ψHs; dims=2))
   Ŷs = Buffer(Hs, size(prediction, 1), batch_size, maxT) # D×B×T output tensor
   @inbounds for t ∈ axes(Ŷs, 3) # time is the 3rd dimension of the D×B×T buffer
      # compute decoder state
      decoding = m.spell([decoding; prediction; context])::M
      # compute query ϕ(sᵢ)
      ϕsᵢ = m.query_ϕ(decoding)
      # compute energies via batch matrix multiplication
      # @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
      Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
      # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
      # compute attention weights
      αᵢs = softmax(Eᵢs)
      # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
      context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
      # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
      # predict probability distribution over character alphabet
      Ŷs[:,:,t] = prediction = m.infer([decoding; context])
   end
   return copy(Ŷs)
end
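
Before benchmarking, the three variants can be checked for agreement up to the layout permutation; a quick sanity check along these lines (reset! is needed between calls because m.spell is stateful):

reset!(m); ŷs  = decode(m, Hs, maxT)    # Vector of D×B matrices, length maxT
reset!(m); Ŷs₀ = decode0(m, Hs, maxT)   # D×T×B tensor
reset!(m); Ŷs₁ = decode1(m, Hs, maxT)   # D×B×T tensor
@assert all(ŷs[t] ≈ Ŷs₀[:,t,:] ≈ Ŷs₁[:,:,t] for t ∈ 1:maxT)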

function gdecode(θ, m, Hs, maxT)
   gradient(θ) do
      sum(sum(decode(m, Hs, maxT)))
   end
end
function gdecode0(θ, m, Hs, maxT)
   gradient(θ) do
      sum(decode0(m, Hs, maxT))
   end
end
function gdecode1(θ, m, Hs, maxT)
   gradient(θ) do
      sum(decode1(m, Hs, maxT))
   end
end
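
These wrappers reduce the same predictions to the same scalar loss; only the container handling differs: sum(sum(ŷs)) first adds the D×B matrices elementwise and then sums the entries, matching a plain sum over either tensor layout. Reusing the outputs from the check above:

# hypothetical sanity check, reusing ŷs, Ŷs₀, Ŷs₁ from the consistency check
@assert sum(sum(ŷs)) ≈ sum(Ŷs₀) ≈ sum(Ŷs₁)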

Results

julia> reset!(m)

julia> @benchmark decode($m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  220.92 MiB
  allocs estimate:  174733
  --------------
  minimum time:     303.722 ms (6.46% GC)
  median time:      313.715 ms (7.65% GC)
  mean time:        316.241 ms (8.87% GC)
  maximum time:     330.550 ms (11.97% GC)
  --------------
  samples:          16
  evals/sample:     1

julia> reset!(m)

julia> @benchmark decode0($m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  226.36 MiB
  allocs estimate:  174734
  --------------
  minimum time:     302.272 ms (5.79% GC)
  median time:      314.070 ms (6.85% GC)
  mean time:        317.616 ms (8.89% GC)
  maximum time:     343.369 ms (16.08% GC)
  --------------
  samples:          16
  evals/sample:     1

julia> reset!(m)

julia> @benchmark decode1($m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  55.99 MiB
  allocs estimate:  39872
  --------------
  minimum time:     64.015 ms (0.00% GC)
  median time:      66.393 ms (0.00% GC)
  mean time:        72.101 ms (8.44% GC)
  maximum time:     90.428 ms (21.28% GC)
  --------------
  samples:          70
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gdecode($θ, $m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  923.94 MiB
  allocs estimate:  702845
  --------------
  minimum time:     1.003 s (8.77% GC)
  median time:      1.025 s (9.95% GC)
  mean time:        1.137 s (17.53% GC)
  maximum time:     1.484 s (36.93% GC)
  --------------
  samples:          5
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gdecode0($θ, $m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  934.91 MiB
  allocs estimate:  707420
  --------------
  minimum time:     1.035 s (9.75% GC)
  median time:      1.047 s (9.63% GC)
  mean time:        1.149 s (17.27% GC)
  maximum time:     1.527 s (38.10% GC)
  --------------
  samples:          5
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gdecode1($θ, $m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  221.71 MiB
  allocs estimate:  161734
  --------------
  minimum time:     227.258 ms (7.33% GC)
  median time:      237.676 ms (10.23% GC)
  mean time:        238.977 ms (11.00% GC)
  maximum time:     261.431 ms (18.72% GC)
  --------------
  samples:          21
  evals/sample:     1

Conclusion: the decode1 version significantly (≈5×) outperforms the other two, thanks to the cache-friendly D×B×T ordering of the output tensor's dimensions. Although it would be nice to keep the output tensor's dimension ordering consistent with the input tensor's D×T×B, the performance benefit outweighs that consideration, so decode1 is adopted.
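
The cache effect is easy to reproduce in isolation: Julia arrays are column-major, so Ŷs[:,:,t] writes one contiguous D×B block per step, whereas Ŷs[:,t,:] scatters the same values across B strided columns. A minimal demonstration, independent of the model (dimensions here are illustrative):

using BenchmarkTools

D, T, B = 8, 500, 32
src   = randn(Float32, D, B)
A_dtb = zeros(Float32, D, T, B)   # time slices A_dtb[:,t,:] are strided
A_dbt = zeros(Float32, D, B, T)   # time slices A_dbt[:,:,t] are contiguous
@btime (for t ∈ 1:$T; $A_dtb[:,t,:] = $src; end)   # strided writes
@btime (for t ∈ 1:$T; $A_dbt[:,:,t] = $src; end)   # contiguous writes, typically faster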

AzamatB commented 4 years ago

Comparing the full forward pass under the setup below.

using Statistics
using BenchmarkTools
# same environment assumptions as in the setup of the first comment

# xs = first(Xs_train)
xs = last(Xs_train)
# xs = Xs_train[ceil(Int, (median∘eachindex)(Xs_train))]
Xs = vecofmats2tensor(xs)

encoder_dims = (
   blstm       = (in = 39, out = 8),
   pblstms_out = (8, 8, 8)
)
attention_dim = 8
decoder_out_dim = 8
out_dim = 61
m = LAS(encoder_dims, attention_dim, decoder_out_dim, out_dim)
θ = Flux.params(m)

f(m, x) = m(x)
function gf(θ, m, x)
   gradient(θ) do
      sum(m(x))
      # for the vec-of-mats output version use instead: (sum ∘ sum ∘ m)(x)
   end
end

reset!(m)
@benchmark f($m, $xs)
reset!(m)
@benchmark f($m, $Xs)

reset!(m)
@benchmark gf($θ, $m, $xs)
reset!(m)
@benchmark gf($θ, $m, $Xs)
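
For context, f(m, Xs) exercises the full pipeline; judging from the first comment's setup (Hs = m.listen(Xs); decode(m, Hs, maxT)), the model's functor presumably composes the encoder and decoder roughly as follows (a sketch, not the repo's actual definition):

# rough sketch of the benchmarked forward pass; the real functor lives in the repo
function forward_sketch(m::LAS, Xs::DenseArray{<:Real,3})
   Hs = m.listen(Xs)                 # encode the padded D×T×B input
   return decode(m, Hs, size(Xs, 2)) # attend & spell for maxT steps
end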

With the new (D×B×T tensor output) version of decode

@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
   batch_size = size(Hs, 3)
   # initialize state for every sequence in a batch
   context    = repeat(m.state₀.context,    1, batch_size)
   decoding   = repeat(m.state₀.decoding,   1, batch_size)
   prediction = repeat(m.state₀.prediction, 1, batch_size)
   # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
   # passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
   ψHs = reshape(m.key_ψ(reshape(Hs, size(Hs,1), :)), size(m.key_ψ.W, 1), :, batch_size)
   # ψhs = m.key_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # check: all(ψhs .≈ eachslice(ψHs; dims=2))
   Ŷs = Buffer(Hs, size(prediction, 1), batch_size, maxT) # D×B×T output tensor
   @inbounds for t ∈ axes(Ŷs, 3) # time is the 3rd dimension of the D×B×T buffer
      # compute decoder state
      decoding = m.spell([decoding; prediction; context])::M
      # compute query ϕ(sᵢ)
      ϕsᵢ = m.query_ϕ(decoding)
      # compute energies via batch matrix multiplication
      # @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
      Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
      # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
      # compute attention weights
      αᵢs = softmax(Eᵢs)
      # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
      context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
      # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
      # predict probability distribution over character alphabet
      Ŷs[:,:,t] = prediction = m.infer([decoding; context])
   end
   return copy(Ŷs)
end

the results are

julia> reset!(m)

julia> @benchmark f($m, $xs)
BenchmarkTools.Trial: 
  memory estimate:  111.88 MiB
  allocs estimate:  68070
  --------------
  minimum time:     117.154 ms (0.00% GC)
  median time:      131.959 ms (11.06% GC)
  mean time:        131.458 ms (9.33% GC)
  maximum time:     151.411 ms (15.55% GC)
  --------------
  samples:          39
  evals/sample:     1

julia> reset!(m)

julia> @benchmark f($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  108.79 MiB
  allocs estimate:  79693
  --------------
  minimum time:     121.245 ms (0.00% GC)
  median time:      136.645 ms (9.86% GC)
  mean time:        133.973 ms (7.76% GC)
  maximum time:     146.218 ms (12.97% GC)
  --------------
  samples:          38
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf($θ, $m, $xs)
BenchmarkTools.Trial: 
  memory estimate:  514.75 MiB
  allocs estimate:  1468837
  --------------
  minimum time:     558.208 ms (9.87% GC)
  median time:      608.333 ms (16.84% GC)
  mean time:        630.647 ms (19.17% GC)
  maximum time:     833.633 ms (32.63% GC)
  --------------
  samples:          8
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf($θ, $m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  220.78 MiB
  allocs estimate:  383739
  --------------
  minimum time:     227.210 ms (0.00% GC)
  median time:      263.843 ms (14.47% GC)
  mean time:        262.169 ms (13.96% GC)
  maximum time:     276.629 ms (19.73% GC)
  --------------
  samples:          20
  evals/sample:     1

while with the (previous, i.e. vec of mats output) version of the decode function

@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
   batch_size = size(Hs, 3)
   # initialize state for every sequence in a batch
   context    = repeat(m.state₀.context,    1, batch_size)
   decoding   = repeat(m.state₀.decoding,   1, batch_size)
   prediction = repeat(m.state₀.prediction, 1, batch_size)
   # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
   # passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
   ψHs = reshape(m.key_ψ(reshape(Hs, size(Hs,1), :)), size(m.key_ψ.W, 1), :, batch_size)
   # ψhs = m.key_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # check: all(ψhs .≈ eachslice(ψHs; dims=2))
   ŷs = Buffer(Vector{M}(undef, maxT), false)
   @inbounds for t ∈ eachindex(ŷs)
      # compute decoder state
      decoding = m.spell([decoding; prediction; context])::M
      # compute query ϕ(sᵢ)
      ϕsᵢ = m.query_ϕ(decoding)
      # compute energies via batch matrix multiplication
      # @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
      Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
      # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
      # compute attention weights
      αᵢs = softmax(Eᵢs)
      # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
      context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
      # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
      # predict probability distribution over character alphabet
      ŷs[t] = prediction = m.infer([decoding; context])
   end
   return copy(ŷs)
end

the results are

julia> reset!(m)

julia> @benchmark f($m, $xs)
BenchmarkTools.Trial: 
  memory estimate:  487.19 MiB
  allocs estimate:  458671
  --------------
  minimum time:     565.769 ms (6.92% GC)
  median time:      583.121 ms (10.07% GC)
  mean time:        585.167 ms (10.16% GC)
  maximum time:     602.533 ms (10.40% GC)
  --------------
  samples:          9
  evals/sample:     1

julia> reset!(m)

julia> @benchmark f($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  484.09 MiB
  allocs estimate:  470294
  --------------
  minimum time:     564.175 ms (5.48% GC)
  median time:      601.307 ms (10.57% GC)
  mean time:        601.711 ms (10.06% GC)
  maximum time:     624.255 ms (11.36% GC)
  --------------
  samples:          9
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf($θ, $m, $xs)
BenchmarkTools.Trial: 
  memory estimate:  2.25 GiB
  allocs estimate:  3044982
  --------------
  minimum time:     2.773 s (19.52% GC)
  median time:      2.876 s (20.46% GC)
  mean time:        2.876 s (20.46% GC)
  maximum time:     2.979 s (21.34% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf($θ, $m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.96 GiB
  allocs estimate:  1959884
  --------------
  minimum time:     2.417 s (20.49% GC)
  median time:      2.510 s (21.41% GC)
  mean time:        2.510 s (21.41% GC)
  maximum time:     2.603 s (22.27% GC)
  --------------
  samples:          2
  evals/sample:     1

so, comparing the gradient passes, the new version is ~11× faster for the 3D tensor input and ~5× faster for the vec-of-mats input; the forward pass improves by ~4.7–4.8× in both cases.
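
Summarizing the minimum times from the runs above:

                 vec-of-mats decode   D×B×T tensor decode   speedup
f(m, xs)             565.769 ms           117.154 ms         ≈4.8×
f(m, Xs)             564.175 ms           121.245 ms         ≈4.7×
gf(θ, m, xs)           2.773 s            558.208 ms         ≈5.0×
gf(θ, m, Xs)           2.417 s            227.210 ms         ≈10.6×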