Full experiment on the ordering of dimensions in the forward pass for inputs of type `DenseArray{<:Real,3}`

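Implementations of six versions of the forward pass, one for each permutation of D (input dimension), T (time duration) and B (batch size). In every version the input `Xs` itself is indexed as D × T × B (`size(Xs,2)` time steps, `size(Xs,3)` batch size). The ordering refers only to the layout of the encoder output `Hs = m.listen(Xs)`, which serves as the values of the attention layer: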

D × T × B
"""
    fdtb(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))

Forward pass of the attention decoder with the encoder output `Hs = m.listen(Xs)`
kept in D × T × B order (assumed layout of `m.listen`'s output — confirm).

`size(Xs,3)` is used as the batch size and `size(Xs,2)` as the default number of
decoding steps `maxT`. Each of the `maxT` steps overwrites `m.state.context`,
`m.state.prediction` and `m.state.decoding`. Returns the vector of per-step
predictions.
"""
function fdtb(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encoding, which are also values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H), one matrix per slice along dimension 2 (time)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute initial decoder state for a batch; `.+ zeros(...)` gives it batch_size columns
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies: diagonal of ϕ(sᵢ)ᵀψ(hᵤ), i.e. one dot product per batch column
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights as a T × B matrix, normalized along T (softmax's default dims=1)
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      # weights reshaped to 1 × T × B, weighted sum over dimension 2 (time) → D × B
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   return ŷs
end
D × B × T
"""
    fdbt(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))

Forward pass of the attention decoder with the encoder output permuted to
D × B × T order (`permutedims(m.listen(Xs), [1,3,2])`, assuming `m.listen` returns
D × T × B — confirm).

`size(Xs,3)` is used as the batch size and `size(Xs,2)` as the default number of
decoding steps `maxT`. Each of the `maxT` steps overwrites `m.state.context`,
`m.state.prediction` and `m.state.decoding`. Returns the vector of per-step
predictions.
"""
function fdbt(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encoding (values for the attention layer), permuted to D × B × T
   Hs = permutedims(m.listen(Xs), [1,3,2])
   # precompute keys ψ(H), one matrix per slice along dimension 3 (time)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, :, axes(Hs,3)))
   # compute initial decoder state for a batch; `.+ zeros(...)` gives it batch_size columns
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies: diagonal of ϕ(sᵢ)ᵀψ(hᵤ), i.e. one dot product per batch column
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights as a B × T matrix, normalized along T (dims=2)
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      # weights reshaped to 1 × B × T, weighted sum over dimension 3 (time) → D × B
      m.state.context = dropdims(sum(reshape(αᵢs, 1, batch_size, :) .* Hs; dims=3); dims=3)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   return ŷs
end
T × D × B
"""
    ftdb(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))

Forward pass of the attention decoder with the attention values permuted to
T × D × B order (`permutedims(Hs, [2,1,3])` of `Hs = m.listen(Xs)`, assuming
`m.listen` returns D × T × B — confirm).

The encoder is run once. The keys are computed from the unpermuted encoding and the
values are its permutation. `size(Xs,3)` is used as the batch size and `size(Xs,2)`
as the default number of decoding steps `maxT`. Each step overwrites
`m.state.context`, `m.state.prediction` and `m.state.decoding`. Returns the vector
of per-step predictions.
"""
function ftdb(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encoding once, which are also values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H) from the unpermuted encoding, one matrix per time slice
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # reorder the values to T × D × B; permuting the existing Hs (instead of calling
   # m.listen again) avoids a second encoder pass and keeps keys and values consistent
   Hs = permutedims(Hs, [2,1,3])
   # compute initial decoder state for a batch; `.+ zeros(...)` gives it batch_size columns
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies: diagonal of ϕ(sᵢ)ᵀψ(hᵤ), i.e. one dot product per batch column
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights as a T × B matrix, normalized along T (softmax's default dims=1)
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      # weights reshaped to T × 1 × B, weighted sum over dimension 1 (time) → D × B
      m.state.context = dropdims(sum(reshape(αᵢs, :,1, batch_size) .* Hs; dims=1); dims=1)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   return ŷs
end
B × T × D
"""
    fbtd(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))

Forward pass of the attention decoder with the attention values permuted to
B × T × D order (`permutedims(Hs, [3,2,1])` of `Hs = m.listen(Xs)`, assuming
`m.listen` returns D × T × B — confirm).

The encoder is run once. The keys are computed from the unpermuted encoding and the
values are its permutation. `size(Xs,3)` is used as the batch size and `size(Xs,2)`
as the default number of decoding steps `maxT`. Each step overwrites
`m.state.context`, `m.state.prediction` and `m.state.decoding`. Returns the vector
of per-step predictions.
"""
function fbtd(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encoding once, which are also values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H) from the unpermuted encoding, one matrix per time slice
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # reorder the values to B × T × D; permuting the existing Hs (instead of calling
   # m.listen again) avoids a second encoder pass and keeps keys and values consistent
   Hs = permutedims(Hs, [3,2,1])
   # compute initial decoder state for a batch; `.+ zeros(...)` gives it batch_size columns
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies: diagonal of ϕ(sᵢ)ᵀψ(hᵤ), i.e. one dot product per batch column
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights as a B × T matrix, normalized along T (dims=2)
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      # B × T weights broadcast over D, weighted sum over dimension 2 (time) → B × D, transposed to D × B
      m.state.context = permutedims(dropdims(sum(αᵢs .* Hs; dims=2); dims=2))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   return ŷs
end
T × B × D
"""
    ftbd(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))

Forward pass of the attention decoder with the attention values permuted to
T × B × D order (`permutedims(Hs, [2,3,1])` of `Hs = m.listen(Xs)`, assuming
`m.listen` returns D × T × B — confirm).

The encoder is run once. The keys are computed from the unpermuted encoding and the
values are its permutation. `size(Xs,3)` is used as the batch size and `size(Xs,2)`
as the default number of decoding steps `maxT`. Each step overwrites
`m.state.context`, `m.state.prediction` and `m.state.decoding`. Returns the vector
of per-step predictions.
"""
function ftbd(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encoding once, which are also values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H) from the unpermuted encoding, one matrix per time slice
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # reorder the values to T × B × D; permuting the existing Hs (instead of calling
   # m.listen again) avoids a second encoder pass and keeps keys and values consistent
   Hs = permutedims(Hs, [2,3,1])
   # compute initial decoder state for a batch; `.+ zeros(...)` gives it batch_size columns
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies: diagonal of ϕ(sᵢ)ᵀψ(hᵤ), i.e. one dot product per batch column
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights as a T × B matrix, normalized along T (softmax's default dims=1)
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      # T × B weights broadcast over D, weighted sum over dimension 1 (time) → B × D, transposed to D × B
      m.state.context = permutedims(dropdims(sum(αᵢs .* Hs; dims=1); dims=1))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   return ŷs
end
B × D × T
"""
    fbdt(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))

Forward pass of the attention decoder with the attention values permuted to
B × D × T order (`permutedims(Hs, [3,1,2])` of `Hs = m.listen(Xs)`, assuming
`m.listen` returns D × T × B — confirm).

The encoder is run once. The keys are computed from the unpermuted encoding and the
values are its permutation. `size(Xs,3)` is used as the batch size and `size(Xs,2)`
as the default number of decoding steps `maxT`. Each step overwrites
`m.state.context`, `m.state.prediction` and `m.state.decoding`. Returns the vector
of per-step predictions.
"""
function fbdt(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encoding once, which are also values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H) from the unpermuted encoding, one matrix per time slice
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # reorder the values to B × D × T; permuting the existing Hs (instead of calling
   # m.listen again) avoids a second encoder pass and keeps keys and values consistent
   Hs = permutedims(Hs, [3,1,2])
   # compute initial decoder state for a batch; `.+ zeros(...)` gives it batch_size columns
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies: diagonal of ϕ(sᵢ)ᵀψ(hᵤ), i.e. one dot product per batch column
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights as a B × T matrix, normalized along T (dims=2)
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      # weights reshaped to B × 1 × T, weighted sum over dimension 3 (time) → B × D, transposed to D × B
      m.state.context = permutedims(dropdims(sum(reshape(αᵢs, batch_size, 1, :) .* Hs; dims=3); dims=3))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   return ŷs
end

"""
    gfdtb(m, Xs, θ)

Gradient of `sum(sum(fdtb(m, Xs)))` with respect to the implicit parameters `θ`
(e.g. `Flux.params(m)`).
"""
function gfdtb(m, Xs, θ)
   return gradient(θ) do
      sum(sum(fdtb(m, Xs)))
   end
end
"""
    gfdbt(m, Xs, θ)

Gradient of `sum(sum(fdbt(m, Xs)))` with respect to the implicit parameters `θ`
(e.g. `Flux.params(m)`).
"""
function gfdbt(m, Xs, θ)
   return gradient(θ) do
      sum(sum(fdbt(m, Xs)))
   end
end
"""
    gftdb(m, Xs, θ)

Gradient of `sum(sum(ftdb(m, Xs)))` with respect to the implicit parameters `θ`
(e.g. `Flux.params(m)`).
"""
function gftdb(m, Xs, θ)
   return gradient(θ) do
      sum(sum(ftdb(m, Xs)))
   end
end
"""
    gfbtd(m, Xs, θ)

Gradient of `sum(sum(fbtd(m, Xs)))` with respect to the implicit parameters `θ`
(e.g. `Flux.params(m)`).
"""
function gfbtd(m, Xs, θ)
   return gradient(θ) do
      sum(sum(fbtd(m, Xs)))
   end
end
"""
    gftbd(m, Xs, θ)

Gradient of `sum(sum(ftbd(m, Xs)))` with respect to the implicit parameters `θ`
(e.g. `Flux.params(m)`).
"""
function gftbd(m, Xs, θ)
   return gradient(θ) do
      sum(sum(ftbd(m, Xs)))
   end
end
"""
    gfbdt(m, Xs, θ)

Gradient of `sum(sum(fbdt(m, Xs)))` with respect to the implicit parameters `θ`
(e.g. `Flux.params(m)`).
"""
function gfbdt(m, Xs, θ)
   return gradient(θ) do
      sum(sum(fbdt(m, Xs)))
   end
end

A smallish neural net with the following dimensions was used:

# encoder sizes: input/output size of the `blstm` layer and output sizes of the `pblstms` layers
encoder_dims = (
   blstm       = (in = 39, out = 64),
   pblstms_out = (64, 64, 64)
)
attention_dim = 128
decoder_out_dims = (128, 64)
# NOTE(review): `out_dim` (presumably the size of the character alphabet) must be defined beforehand; it is not set in this snippet
m = LAS(encoder_dims, attention_dim, decoder_out_dims, out_dim)
θ = Flux.params(m)
using BenchmarkTools

Results for xs = last(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m);

julia> @benchmark fdtb($m, $Xs)
  memory estimate:  2.94 GiB
  allocs estimate:  306948
  minimum time:     3.294 s (11.02% GC)
  median time:      3.295 s (11.20% GC)
  mean time:        3.295 s (11.20% GC)
  maximum time:     3.296 s (11.39% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fdbt($m, $Xs)
  memory estimate:  2.94 GiB
  allocs estimate:  318347
  minimum time:     3.485 s (10.91% GC)
  median time:      3.514 s (10.76% GC)
  mean time:        3.514 s (10.76% GC)
  maximum time:     3.543 s (10.61% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftdb($m, $Xs)
  memory estimate:  3.54 GiB
  allocs estimate:  367033
  minimum time:     4.517 s (10.70% GC)
  median time:      4.523 s (10.57% GC)
  mean time:        4.523 s (10.57% GC)
  maximum time:     4.529 s (10.44% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbtd($m, $Xs)
  memory estimate:  3.55 GiB
  allocs estimate:  386793
  minimum time:     4.498 s (10.76% GC)
  median time:      4.501 s (10.51% GC)
  mean time:        4.501 s (10.51% GC)
  maximum time:     4.504 s (10.26% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftbd($m, $Xs)
  memory estimate:  3.55 GiB
  allocs estimate:  366273
  minimum time:     4.461 s (10.82% GC)
  median time:      4.469 s (10.95% GC)
  mean time:        4.469 s (10.95% GC)
  maximum time:     4.477 s (11.07% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbdt($m, $Xs)
  memory estimate:  3.55 GiB
  allocs estimate:  387553
  minimum time:     5.083 s (9.19% GC)
  median time:      5.083 s (9.19% GC)
  mean time:        5.083 s (9.19% GC)
  maximum time:     5.083 s (9.19% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdtb($m, $Xs, $θ)
  memory estimate:  17.50 GiB
  allocs estimate:  2662077
  minimum time:     30.478 s (64.75% GC)
  median time:      30.478 s (64.75% GC)
  mean time:        30.478 s (64.75% GC)
  maximum time:     30.478 s (64.75% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdbt($m, $Xs, $θ)
  memory estimate:  17.51 GiB
  allocs estimate:  2689455
  minimum time:     30.562 s (64.84% GC)
  median time:      30.562 s (64.84% GC)
  mean time:        30.562 s (64.84% GC)
  maximum time:     30.562 s (64.84% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftdb($m, $Xs, $θ)
  memory estimate:  18.29 GiB
  allocs estimate:  2968508
  minimum time:     28.648 s (57.85% GC)
  median time:      28.648 s (57.85% GC)
  mean time:        28.648 s (57.85% GC)
  maximum time:     28.648 s (57.85% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbtd($m, $Xs, $θ)
  memory estimate:  18.32 GiB
  allocs estimate:  3001183
  minimum time:     28.857 s (57.49% GC)
  median time:      28.857 s (57.49% GC)
  mean time:        28.857 s (57.49% GC)
  maximum time:     28.857 s (57.49% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftbd($m, $Xs, $θ)
  memory estimate:  18.32 GiB
  allocs estimate:  2964703
  minimum time:     28.671 s (57.99% GC)
  median time:      28.671 s (57.99% GC)
  mean time:        28.671 s (57.99% GC)
  maximum time:     28.671 s (57.99% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbdt($m, $Xs, $θ)
  memory estimate:  18.32 GiB
  allocs estimate:  3008029
  minimum time:     28.963 s (57.72% GC)
  median time:      28.963 s (57.72% GC)
  mean time:        28.963 s (57.72% GC)
  maximum time:     28.963 s (57.72% GC)
  samples:          1
  evals/sample:     1

Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m);

julia> @benchmark fdtb($m, $Xs)
  memory estimate:  1.72 GiB
  allocs estimate:  54890
  minimum time:     2.456 s (7.91% GC)
  median time:      2.509 s (9.16% GC)
  mean time:        2.504 s (8.79% GC)
  maximum time:     2.548 s (9.26% GC)
  samples:          3
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fdbt($m, $Xs)
  memory estimate:  1.72 GiB
  allocs estimate:  57409
  minimum time:     2.532 s (8.33% GC)
  median time:      2.604 s (8.87% GC)
  mean time:        2.604 s (8.87% GC)
  maximum time:     2.676 s (9.38% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftdb($m, $Xs)
  memory estimate:  2.28 GiB
  allocs estimate:  74143
  minimum time:     3.719 s (7.99% GC)
  median time:      3.754 s (8.27% GC)
  mean time:        3.754 s (8.27% GC)
  maximum time:     3.789 s (8.54% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbtd($m, $Xs)
  memory estimate:  2.29 GiB
  allocs estimate:  78511
  minimum time:     3.604 s (8.44% GC)
  median time:      3.623 s (8.55% GC)
  mean time:        3.623 s (8.55% GC)
  maximum time:     3.643 s (8.66% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftbd($m, $Xs)
  memory estimate:  2.29 GiB
  allocs estimate:  73975
  minimum time:     3.684 s (8.50% GC)
  median time:      3.710 s (8.67% GC)
  mean time:        3.710 s (8.67% GC)
  maximum time:     3.736 s (8.84% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbdt($m, $Xs)
  memory estimate:  2.29 GiB
  allocs estimate:  78679
  minimum time:     3.624 s (8.84% GC)
  median time:      3.636 s (8.83% GC)
  mean time:        3.636 s (8.83% GC)
  maximum time:     3.647 s (8.81% GC)
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdtb($m, $Xs, $θ)
  memory estimate:  5.75 GiB
  allocs estimate:  362721
  minimum time:     8.457 s (34.28% GC)
  median time:      8.457 s (34.28% GC)
  mean time:        8.457 s (34.28% GC)
  maximum time:     8.457 s (34.28% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdbt($m, $Xs, $θ)
  memory estimate:  5.75 GiB
  allocs estimate:  368787
  minimum time:     8.571 s (33.81% GC)
  median time:      8.571 s (33.81% GC)
  mean time:        8.571 s (33.81% GC)
  maximum time:     8.571 s (33.81% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftdb($m, $Xs, $θ)
  memory estimate:  6.48 GiB
  allocs estimate:  440883
  minimum time:     10.769 s (34.73% GC)
  median time:      10.769 s (34.73% GC)
  mean time:        10.769 s (34.73% GC)
  maximum time:     10.769 s (34.73% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbtd($m, $Xs, $θ)
  memory estimate:  6.50 GiB
  allocs estimate:  448102
  minimum time:     10.603 s (35.30% GC)
  median time:      10.603 s (35.30% GC)
  mean time:        10.603 s (35.30% GC)
  maximum time:     10.603 s (35.30% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftbd($m, $Xs, $θ)
  memory estimate:  6.50 GiB
  allocs estimate:  440038
  minimum time:     10.768 s (34.75% GC)
  median time:      10.768 s (34.75% GC)
  mean time:        10.768 s (34.75% GC)
  maximum time:     10.768 s (34.75% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbdt($m, $Xs, $θ)
  memory estimate:  6.50 GiB
  allocs estimate:  449619
  minimum time:     10.641 s (35.38% GC)
  median time:      10.641 s (35.38% GC)
  mean time:        10.641 s (35.38% GC)
  maximum time:     10.641 s (35.38% GC)
  samples:          1
  evals/sample:     1

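Conclusion: the D × T × B and D × B × T orderings seem to be the most efficient, although the differences between all versions are small. As T grows, the gradient computations become dominated by garbage collection (≈58–65 % GC for the `last(Xs_train)` batch vs ≈34 % for the `first(Xs_train)` batch), and the speed differences almost vanish. Caveat: as benchmarked, the T × D × B, B × T × D, T × B × D and B × D × T versions call `m.listen(Xs)` twice (once before computing the keys and again before permuting the values), whereas the D × T × B and D × B × T versions call it once. Part of their extra forward-pass time and memory therefore comes from the duplicated encoder pass rather than from the ordering itself.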

Results for computing the energy vector by taking the diagonal of the product of the query and key matrices (f0) vs. by their columnwise dot products (f1):

"""
    f0(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))

Forward pass (D × T × B ordering of `Hs = m.listen(Xs)`) that computes the
attention energies as the diagonal of the product of the transposed query
matrix and each key matrix.

`size(Xs,3)` is used as the batch size and `size(Xs,2)` as the default number of
decoding steps `maxT`. Each step overwrites `m.state.context`,
`m.state.prediction` and `m.state.decoding`. Returns the vector of per-step
predictions.
"""
function f0(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encoding, which are also values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H), one matrix per slice along dimension 2 (time)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute initial decoder state for a batch; `.+ zeros(...)` gives it batch_size columns
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies: diagonal of the full B × B product ϕ(sᵢ)ᵀψ(hᵤ)
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights as a T × B matrix, normalized along T (softmax's default dims=1)
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      # weights reshaped to 1 × T × B, weighted sum over dimension 2 (time) → D × B
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   return ŷs
end

"""
    f1(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))

Forward pass (D × T × B ordering of `Hs = m.listen(Xs)`) that computes the
attention energies as columnwise dot products between query and key columns,
using views over the batch axis.

`size(Xs,3)` is used as the batch size and `size(Xs,2)` as the default number of
decoding steps `maxT`. Each step overwrites `m.state.context`,
`m.state.prediction` and `m.state.decoding`. Returns the vector of per-step
predictions.
"""
function f1(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   batch_axis = axes(Xs,3)
   # compute input encoding, which are also values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H), one matrix per slice along dimension 2 (time)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute initial decoder state for a batch; `.+ zeros(...)` gives it batch_size columns
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ), split into one column view per batch element
      ϕsᵢ = view.(Ref(m.attention_ϕ(m.state.decoding)), :, batch_axis)
      # compute energies: for every key matrix, dot products of matching query/key columns
      Eᵢs = (ψh -> ϕsᵢ .⋅ view.(Ref(ψh), :, batch_axis)).(ψhs)
      # compute attention weights as a T × B matrix, normalized along T (softmax's default dims=1)
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      # weights reshaped to 1 × T × B, weighted sum over dimension 2 (time) → D × B
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   return ŷs
end

"""
    gf0(m, Xs, θ)

Gradient of `sum(sum(f0(m, Xs)))` with respect to the implicit parameters `θ`
(e.g. `Flux.params(m)`).
"""
function gf0(m, Xs, θ)
   return gradient(θ) do
      sum(sum(f0(m, Xs)))
   end
end
"""
    gf1(m, Xs, θ)

Gradient of `sum(sum(f1(m, Xs)))` with respect to the implicit parameters `θ`
(e.g. `Flux.params(m)`).
"""
function gf1(m, Xs, θ)
   return gradient(θ) do
      sum(sum(f1(m, Xs)))
   end
end

Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m)

julia> @benchmark f0($m, $Xs)
  memory estimate:  1.72 GiB
  allocs estimate:  54890
  minimum time:     2.461 s (7.70% GC)
  median time:      2.514 s (8.95% GC)
  mean time:        2.511 s (8.64% GC)
  maximum time:     2.556 s (9.25% GC)
  samples:          3
  evals/sample:     1

julia> reset!(m)

julia> @benchmark f1($m, $Xs)
  memory estimate:  1.40 GiB
  allocs estimate:  643228
  minimum time:     2.222 s (7.58% GC)
  median time:      2.251 s (8.09% GC)
  mean time:        2.244 s (7.92% GC)
  maximum time:     2.261 s (8.09% GC)
  samples:          3
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf0($m, $Xs, $θ)
  memory estimate:  5.75 GiB
  allocs estimate:  362721
  minimum time:     8.926 s (34.64% GC)
  median time:      8.926 s (34.64% GC)
  mean time:        8.926 s (34.64% GC)
  maximum time:     8.926 s (34.64% GC)
  samples:          1
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf1($m, $Xs, $θ)
  memory estimate:  97.39 GiB
  allocs estimate:  18548950
  minimum time:     95.674 s (28.28% GC)
  median time:      95.674 s (28.28% GC)
  mean time:        95.674 s (28.28% GC)
  maximum time:     95.674 s (28.28% GC)
  samples:          1
  evals/sample:     1

Replacing the views with `getindex` in the energy computation step only made things worse, increasing the total time to 110.982 seconds. Conclusion: columnwise dot products (f1), whether via `view` or `getindex`, are much slower than taking the diagonal of the product of the query and key matrices (f0) when computing gradients. The difference is more than 10× (95.674 s vs 8.926 s above), even though f1 is slightly faster in the forward pass alone (≈2.24 s vs ≈2.51 s mean).

Switching between

ŷs = map(1:maxT) do _

and

ŷs = broadcast(1:maxT) do _

did not gain any performance.

Changing αᵢs = softmax(hcat(Eᵢs...)') to

also with D × B × T ordering of input dimensions: