AzamatB / ListenAttendSpell.jl

Julia implementation of the Listen, Attend and Spell model with Flux.jl
MIT License

Full experiment with ordering of dimensions in the forward pass for the input of type DenseArray{<:Real,3} #2

AzamatB opened this issue 4 years ago

AzamatB commented 4 years ago

Implementations of 6 different versions of the forward pass, one for each permutation of D (input dimension), T (time duration) and B (batch size), are given below.
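For orientation, each layout is just a permutedims of the canonical D × T × B one; a minimal sketch with made-up sizes (39 features, 200 frames, batch of 8), using the same permutations that appear in the functions below:

Xs_dtb = rand(Float32, 39, 200, 8)       # D × T × B (canonical)
Xs_dbt = permutedims(Xs_dtb, [1,3,2])    # D × B × T
Xs_tdb = permutedims(Xs_dtb, [2,1,3])    # T × D × B
Xs_btd = permutedims(Xs_dtb, [3,2,1])    # B × T × D
Xs_tbd = permutedims(Xs_dtb, [2,3,1])    # T × B × D
Xs_bdt = permutedims(Xs_dtb, [3,1,2])    # B × D × T

The snippets below assume the package's ambient imports (Flux for softmax, gpu, reset! and gradient; LinearAlgebra for diag and ⋅). The six implementations: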

"""
D × T × B
"""
function fdtb(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute the input encoding, which also serves as the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute the initial decoder state for the batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute the attended context as the attention-weighted sum of the values, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
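A usage sketch with made-up data (the feature dimension 39 matches encoder_dims further below); each element of the returned vector is an out_dim × batch_size prediction matrix, one per output step:

Xs = rand(Float32, 39, 200, 8)   # D × T × B
ŷs = fdtb(m, Xs)                 # Vector of length size(Xs, 2), one matrix per step

The remaining five versions differ only in the layout of the encoding Hs and the matching reshape and softmax orientation.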
"""
D × B × T
"""
function fdbt(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute the input encoding (also the values for the attention layer), reordered into the D × B × T layout
   Hs = permutedims(m.listen(Xs), [1,3,2])
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, :, axes(Hs,3)))
   # compute the initial decoder state for the batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute the attended context as the attention-weighted sum of the values, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, 1, batch_size, :) .* Hs; dims=3); dims=3)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
T × D × B
"""
function ftdb(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute the input encoding, which also serves as the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # reorder the encoding into the T × D × B layout, reusing Hs instead of running the (stateful) encoder a second time
   Hs = permutedims(Hs, [2,1,3])
   # compute the initial decoder state for the batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute the attended context as the attention-weighted sum of the values, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, :,1, batch_size) .* Hs; dims=1); dims=1)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
B × T × D
"""
function fbtd(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute the input encoding, which also serves as the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # reorder the encoding into the B × T × D layout, reusing Hs instead of running the (stateful) encoder a second time
   Hs = permutedims(Hs, [3,2,1])
   # compute the initial decoder state for the batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute the attended context as the attention-weighted sum of the values, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = permutedims(dropdims(sum(αᵢs .* Hs; dims=2); dims=2))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
T × B × D
"""
function ftbd(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute the input encoding, which also serves as the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # reorder the encoding into the T × B × D layout, reusing Hs instead of running the (stateful) encoder a second time
   Hs = permutedims(Hs, [2,3,1])
   # compute the initial decoder state for the batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute the attended context as the attention-weighted sum of the values, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = permutedims(dropdims(sum(αᵢs .* Hs; dims=1); dims=1))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
B × D × T
"""
function fbdt(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute the input encoding, which also serves as the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # reorder the encoding into the B × D × T layout, reusing Hs instead of running the (stateful) encoder a second time
   Hs = permutedims(Hs, [3,1,2])
   # compute the initial decoder state for the batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute the attended context as the attention-weighted sum of the values, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = permutedims(dropdims(sum(reshape(αᵢs, batch_size, 1, :) .* Hs; dims=3); dims=3))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end

function gfdtb(m, Xs, θ)
   gradient(θ) do
      sum(sum(fdtb(m, Xs)))
   end
end
function gfdbt(m, Xs, θ)
   gradient(θ) do
      sum(sum(fdbt(m, Xs)))
   end
end
function gftdb(m, Xs, θ)
   gradient(θ) do
      sum(sum(ftdb(m, Xs)))
   end
end
function gfbtd(m, Xs, θ)
   gradient(θ) do
      sum(sum(fbtd(m, Xs)))
   end
end
function gftbd(m, Xs, θ)
   gradient(θ) do
      sum(sum(ftbd(m, Xs)))
   end
end
function gfbdt(m, Xs, θ)
   gradient(θ) do
      sum(sum(fbdt(m, Xs)))
   end
end
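A usage sketch: each wrapper differentiates the summed predictions with respect to the implicit parameters θ and returns a Zygote.Grads (assuming θ = Flux.params(m), as set up below):

reset!(m)
∇ = gfdtb(m, Xs, θ)   # Grads; index with a parameter array, e.g. ∇[first(θ)]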

A smallish neural net with the following dimensions was used:

encoder_dims = (
   blstm       = (in = 39, out = 64),
   pblstms_out = (64, 64, 64)
)
attention_dim = 128
decoder_out_dims = (128, 64)
m = LAS(encoder_dims, attention_dim, decoder_out_dims, out_dim)   # out_dim: size of the character alphabet, defined elsewhere
θ = Flux.params(m)
using BenchmarkTools
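
Here Xs_train and vecofmats2tensor come from the repo. A hypothetical sketch of what the latter is assumed to do, namely pack a vector of per-utterance D × Tᵢ feature matrices into a zero-padded D × maxT × B tensor (the actual implementation may differ):

function vecofmats2tensor_sketch(xs::AbstractVector{<:AbstractMatrix{R}}) where R <: Real
   D = size(first(xs), 1)
   maxT = maximum(size.(xs, 2))
   Xs = zeros(R, D, maxT, length(xs))
   for (b, x) in enumerate(xs)
      Xs[:, 1:size(x, 2), b] = x   # zero-pad each utterance along the time axis
   end
   return Xs
end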

Results for xs = last(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m);

julia> @benchmark fdtb($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.94 GiB
  allocs estimate:  306948
  --------------
  minimum time:     3.294 s (11.02% GC)
  median time:      3.295 s (11.20% GC)
  mean time:        3.295 s (11.20% GC)
  maximum time:     3.296 s (11.39% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fdbt($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.94 GiB
  allocs estimate:  318347
  --------------
  minimum time:     3.485 s (10.91% GC)
  median time:      3.514 s (10.76% GC)
  mean time:        3.514 s (10.76% GC)
  maximum time:     3.543 s (10.61% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftdb($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  3.54 GiB
  allocs estimate:  367033
  --------------
  minimum time:     4.517 s (10.70% GC)
  median time:      4.523 s (10.57% GC)
  mean time:        4.523 s (10.57% GC)
  maximum time:     4.529 s (10.44% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbtd($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  3.55 GiB
  allocs estimate:  386793
  --------------
  minimum time:     4.498 s (10.76% GC)
  median time:      4.501 s (10.51% GC)
  mean time:        4.501 s (10.51% GC)
  maximum time:     4.504 s (10.26% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftbd($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  3.55 GiB
  allocs estimate:  366273
  --------------
  minimum time:     4.461 s (10.82% GC)
  median time:      4.469 s (10.95% GC)
  mean time:        4.469 s (10.95% GC)
  maximum time:     4.477 s (11.07% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbdt($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  3.55 GiB
  allocs estimate:  387553
  --------------
  minimum time:     5.083 s (9.19% GC)
  median time:      5.083 s (9.19% GC)
  mean time:        5.083 s (9.19% GC)
  maximum time:     5.083 s (9.19% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdtb($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  17.50 GiB
  allocs estimate:  2662077
  --------------
  minimum time:     30.478 s (64.75% GC)
  median time:      30.478 s (64.75% GC)
  mean time:        30.478 s (64.75% GC)
  maximum time:     30.478 s (64.75% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdbt($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  17.51 GiB
  allocs estimate:  2689455
  --------------
  minimum time:     30.562 s (64.84% GC)
  median time:      30.562 s (64.84% GC)
  mean time:        30.562 s (64.84% GC)
  maximum time:     30.562 s (64.84% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftdb($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  18.29 GiB
  allocs estimate:  2968508
  --------------
  minimum time:     28.648 s (57.85% GC)
  median time:      28.648 s (57.85% GC)
  mean time:        28.648 s (57.85% GC)
  maximum time:     28.648 s (57.85% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbtd($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  18.32 GiB
  allocs estimate:  3001183
  --------------
  minimum time:     28.857 s (57.49% GC)
  median time:      28.857 s (57.49% GC)
  mean time:        28.857 s (57.49% GC)
  maximum time:     28.857 s (57.49% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftbd($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  18.32 GiB
  allocs estimate:  2964703
  --------------
  minimum time:     28.671 s (57.99% GC)
  median time:      28.671 s (57.99% GC)
  mean time:        28.671 s (57.99% GC)
  maximum time:     28.671 s (57.99% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbdt($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  18.32 GiB
  allocs estimate:  3008029
  --------------
  minimum time:     28.963 s (57.72% GC)
  median time:      28.963 s (57.72% GC)
  mean time:        28.963 s (57.72% GC)
  maximum time:     28.963 s (57.72% GC)
  --------------
  samples:          1
  evals/sample:     1

Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m);

julia> @benchmark fdtb($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.72 GiB
  allocs estimate:  54890
  --------------
  minimum time:     2.456 s (7.91% GC)
  median time:      2.509 s (9.16% GC)
  mean time:        2.504 s (8.79% GC)
  maximum time:     2.548 s (9.26% GC)
  --------------
  samples:          3
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fdbt($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.72 GiB
  allocs estimate:  57409
  --------------
  minimum time:     2.532 s (8.33% GC)
  median time:      2.604 s (8.87% GC)
  mean time:        2.604 s (8.87% GC)
  maximum time:     2.676 s (9.38% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftdb($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.28 GiB
  allocs estimate:  74143
  --------------
  minimum time:     3.719 s (7.99% GC)
  median time:      3.754 s (8.27% GC)
  mean time:        3.754 s (8.27% GC)
  maximum time:     3.789 s (8.54% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbtd($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.29 GiB
  allocs estimate:  78511
  --------------
  minimum time:     3.604 s (8.44% GC)
  median time:      3.623 s (8.55% GC)
  mean time:        3.623 s (8.55% GC)
  maximum time:     3.643 s (8.66% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftbd($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.29 GiB
  allocs estimate:  73975
  --------------
  minimum time:     3.684 s (8.50% GC)
  median time:      3.710 s (8.67% GC)
  mean time:        3.710 s (8.67% GC)
  maximum time:     3.736 s (8.84% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbdt($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.29 GiB
  allocs estimate:  78679
  --------------
  minimum time:     3.624 s (8.84% GC)
  median time:      3.636 s (8.83% GC)
  mean time:        3.636 s (8.83% GC)
  maximum time:     3.647 s (8.81% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdtb($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  5.75 GiB
  allocs estimate:  362721
  --------------
  minimum time:     8.457 s (34.28% GC)
  median time:      8.457 s (34.28% GC)
  mean time:        8.457 s (34.28% GC)
  maximum time:     8.457 s (34.28% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdbt($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  5.75 GiB
  allocs estimate:  368787
  --------------
  minimum time:     8.571 s (33.81% GC)
  median time:      8.571 s (33.81% GC)
  mean time:        8.571 s (33.81% GC)
  maximum time:     8.571 s (33.81% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftdb($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  6.48 GiB
  allocs estimate:  440883
  --------------
  minimum time:     10.769 s (34.73% GC)
  median time:      10.769 s (34.73% GC)
  mean time:        10.769 s (34.73% GC)
  maximum time:     10.769 s (34.73% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbtd($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  6.50 GiB
  allocs estimate:  448102
  --------------
  minimum time:     10.603 s (35.30% GC)
  median time:      10.603 s (35.30% GC)
  mean time:        10.603 s (35.30% GC)
  maximum time:     10.603 s (35.30% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftbd($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  6.50 GiB
  allocs estimate:  440038
  --------------
  minimum time:     10.768 s (34.75% GC)
  median time:      10.768 s (34.75% GC)
  mean time:        10.768 s (34.75% GC)
  maximum time:     10.768 s (34.75% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbdt($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  6.50 GiB
  allocs estimate:  449619
  --------------
  minimum time:     10.641 s (35.38% GC)
  median time:      10.641 s (35.38% GC)
  mean time:        10.641 s (35.38% GC)
  maximum time:     10.641 s (35.38% GC)
  --------------
  samples:          1
  evals/sample:     1
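
Minimum times collected from the runs above (forward pass and gradient, for the last(Xs_train) and first(Xs_train) batches):

ordering     fwd (last)   grad (last)   fwd (first)   grad (first)
D × T × B    3.294 s      30.478 s      2.456 s        8.457 s
D × B × T    3.485 s      30.562 s      2.532 s        8.571 s
T × D × B    4.517 s      28.648 s      3.719 s       10.769 s
B × T × D    4.498 s      28.857 s      3.604 s       10.603 s
T × B × D    4.461 s      28.671 s      3.684 s       10.768 s
B × D × T    5.083 s      28.963 s      3.624 s       10.641 s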

Conclusion: the D × T × B and D × B × T orderings seem to be the most efficient, although there is not much difference between the versions; as T grows, the computation becomes dominated by garbage collection and the speed differences almost vanish.

AzamatB commented 4 years ago

Results for computing the energy vector by taking the diagonal of the product of the query and key matrices (f0) vs. their columnwise dot products (f1):

function f0(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute the input encoding, which also serves as the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute the initial decoder state for the batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute the attended context as the attention-weighted sum of the values, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end

function f1(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   batch_axis = axes(Xs,3)
   # compute the input encoding, which also serves as the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute the initial decoder state for the batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute the query ϕ(sᵢ) as a vector of per-example column views
      ϕsᵢ = view.(Ref(m.attention_ϕ(m.state.decoding)), :, batch_axis)
      # compute energies as columnwise dot products of the query and key columns
      Eᵢs = (ψh -> ϕsᵢ .⋅ view.(Ref(ψh), :, batch_axis)).(ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute the attended context as the attention-weighted sum of the values, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
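
The two energy computations are numerically equivalent; a quick check with made-up sizes:

using LinearAlgebra

A = rand(Float32, 128, 8)   # query ϕ(sᵢ):   attention_dim × batch_size
Ψ = rand(Float32, 128, 8)   # one key ψ(hᵤ): attention_dim × batch_size
E_diag = diag(A' * Ψ)                                             # f0 style
E_dot  = [dot(view(A, :, b), view(Ψ, :, b)) for b in axes(A, 2)]  # f1 style
@assert E_diag ≈ E_dot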

function gf0(m, Xs, θ)
   gradient(θ) do
      sum(sum(f0(m, Xs)))
   end
end
function gf1(m, Xs, θ)
   gradient(θ) do
      sum(sum(f1(m, Xs)))
   end
end

Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m)

julia> @benchmark f0($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.72 GiB
  allocs estimate:  54890
  --------------
  minimum time:     2.461 s (7.70% GC)
  median time:      2.514 s (8.95% GC)
  mean time:        2.511 s (8.64% GC)
  maximum time:     2.556 s (9.25% GC)
  --------------
  samples:          3
  evals/sample:     1

julia> reset!(m)

julia> @benchmark f1($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.40 GiB
  allocs estimate:  643228
  --------------
  minimum time:     2.222 s (7.58% GC)
  median time:      2.251 s (8.09% GC)
  mean time:        2.244 s (7.92% GC)
  maximum time:     2.261 s (8.09% GC)
  --------------
  samples:          3
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf0($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  5.75 GiB
  allocs estimate:  362721
  --------------
  minimum time:     8.926 s (34.64% GC)
  median time:      8.926 s (34.64% GC)
  mean time:        8.926 s (34.64% GC)
  maximum time:     8.926 s (34.64% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf1($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  97.39 GiB
  allocs estimate:  18548950
  --------------
  minimum time:     95.674 s (28.28% GC)
  median time:      95.674 s (28.28% GC)
  mean time:        95.674 s (28.28% GC)
  maximum time:     95.674 s (28.28% GC)
  --------------
  samples:          1
  evals/sample:     1

Changing from views to getindex in the energy computation step only increased the total time further, to 110.982 seconds. Conclusion: columnwise dot products (f1), whether via view or getindex, are much (more than 10×) slower than taking the diagonal of the product of the query and key matrices (f0) when it comes to computing gradients, as reported here. The allocation counts point to the cause: 18,548,950 allocations for gf1 vs 362,721 for gf0, consistent with every per-column view and dot product getting its own pullback in the backward pass.

AzamatB commented 4 years ago

Changing

ŷs = map(1:maxT) do _

to

ŷs = broadcast(1:maxT) do _

did not yield any performance gain, which is expected since map and broadcast do the same elementwise work over a range.

AzamatB commented 4 years ago

Changing αᵢs = softmax(hcat(Eᵢs...)') to

also with D × B × T ordering of input dimensions: