AzamatB / ListenAttendSpell.jl

Julia implementation of Listen, Attend and Spell model with Flux.jl
MIT License
1 stars 0 forks source link

@code_warntype outputs for functions participating in the forward pass #7

Open AzamatB opened 4 years ago

AzamatB commented 4 years ago
function flip(f, xs)
   rev_time = reverse(eachindex(xs))
   return getindex.(Ref(f.(getindex.(Ref(xs), rev_time))), rev_time)
   # the same as
   # flipped_xs = Buffer(xs)
   # @inbounds for t ∈ rev_time
   #    flipped_xs[t] = f(xs[t])
   # end
   # return copy(flipped_xs)
   # but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
end

returns Any

julia> @code_warntype flip(m, xs)
Variables
  #self#::Core.Compiler.Const(flip, false)
  f::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
  xs::Array{Array{Float32,2},1}
  rev_time::StepRange{Int64,Int64}

Body::Any
1 ─ %1 = Main.eachindex(xs)::Base.OneTo{Int64}
│        (rev_time = Main.reverse(%1))
│   %3 = Main.Ref(xs)::Base.RefValue{Array{Array{Float32,2},1}}
│   %4 = Base.broadcasted(Main.getindex, %3, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│   %5 = Base.broadcasted(f, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}}}
│   %6 = Base.materialize(%5)::Any
│   %7 = Main.Ref(%6)::Base.RefValue{_A} where _A
│   %8 = Base.broadcasted(Main.getindex, %7, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),_A} where _A<:Tuple
│   %9 = Base.materialize(%8)::Any
└──      return %9

alleviated by

function flip(f, xs::T) where T
   rev_time = reverse(eachindex(xs))
   return getindex.(Ref(
      f.(getindex.(Ref(xs), rev_time))::T
   ), rev_time)
   # the same as
   # flipped_xs = Buffer(xs)
   # @inbounds for t ∈ rev_time
   #    flipped_xs[t] = f(xs[t])
   # end
   # return copy(flipped_xs)
   # but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
end
julia> @code_warntype flip(m, xs)
Variables
  #self#::Core.Compiler.Const(flip, false)
  f::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
  xs::Array{Array{Float32,2},1}
  rev_time::StepRange{Int64,Int64}

Body::Array{Array{Float32,2},1}
1 ─ %1  = Main.eachindex(xs)::Base.OneTo{Int64}
│         (rev_time = Main.reverse(%1))
│   %3  = Main.Ref(xs)::Base.RefValue{Array{Array{Float32,2},1}}
│   %4  = Base.broadcasted(Main.getindex, %3, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│   %5  = Base.broadcasted(f, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}}}
│   %6  = Base.materialize(%5)::Any
│   %7  = Core.typeassert(%6, $(Expr(:static_parameter, 1)))::Array{Array{Float32,2},1}
│   %8  = Main.Ref(%7)::Base.RefValue{Array{Array{Float32,2},1}}
│   %9  = Base.broadcasted(Main.getindex, %8, rev_time)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(getindex),Tuple{Base.RefValue{Array{Array{Float32,2},1}},StepRange{Int64,Int64}}}
│   %10 = Base.materialize(%9)::Array{Array{Float32,2},1}
└──       return %10
AzamatB commented 4 years ago
function (m::BLSTM)(xs::AbstractVector{<:DenseVecOrMat})
   vcat.(m.forward.(xs), flip(m.backward, xs))
end

returns Any

julia> @code_warntype m(xs)
Variables
  m::BLSTM{Array{Float32,2},Array{Float32,1}}
  xs::Array{Array{Float32,2},1}

Body::Any
1 ─ %1 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %2 = Base.broadcasted(%1, xs)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│   %3 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %4 = Main.flip(%3, xs)::Array{Array{Float32,2},1}
│   %5 = Base.broadcasted(Main.vcat, %2, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}},Array{Array{Float32,2},1}}}
│   %6 = Base.materialize(%5)::Any
└──      return %6

alleviated by

function (m::BLSTM)(xs::VM) where VM <: AbstractVector{<:DenseVecOrMat}
   vcat.(m.forward.(xs), flip(m.backward, xs))::VM
end
julia> @code_warntype m(xs)
Variables
  m::BLSTM{Array{Float32,2},Array{Float32,1}}
  xs::Array{Array{Float32,2},1}

Body::Array{Array{Float32,2},1}
1 ─ %1 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %2 = Base.broadcasted(%1, xs)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│   %3 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %4 = Main.flip(%3, xs)::Array{Array{Float32,2},1}
│   %5 = Base.broadcasted(Main.vcat, %2, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}},Array{Array{Float32,2},1}}}
│   %6 = Base.materialize(%5)::Any
│   %7 = Core.typeassert(%6, $(Expr(:static_parameter, 1)))::Array{Array{Float32,2},1}
└──      return %7
AzamatB commented 4 years ago
function (m::BLSTM)(Xs::DenseArray{<:Real,3})
   # preallocate output buffer
   Ys = Buffer(Xs, 2m.dim_out, size(Xs,2), size(Xs,3))
   axisYs₁ = axes(Ys, 1)
   time    = axes(Ys, 2)
   rev_time = reverse(time)
   @inbounds begin
      # get forward and backward slice indices
      slice_f = axisYs₁[1:m.dim_out]
      slice_b = axisYs₁[(m.dim_out+1):end]
      # bidirectional run step
      setindex!.(Ref(Ys),  m.forward.(view.(Ref(Xs), :, time, :)), Ref(slice_f), time, :)
      setindex!.(Ref(Ys), m.backward.(view.(Ref(Xs), :, rev_time, :)), Ref(slice_b), rev_time, :)
      # the same as
      # @views for (t_f, t_b) ∈ zip(time, rev_time)
      #    Ys[slice_f, t_f, :] =  m.forward(Xs[:, t_f, :])
      #    Ys[slice_b, t_b, :] = m.backward(Xs[:, t_b, :])
      # end
      # but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
   end
   return copy(Ys)
end

type inference looks good

julia> @code_warntype m(Xs)
Variables
  m::BLSTM{Array{Float32,2},Array{Float32,1}}
  Xs::Array{Float32,3}
  val::Any
  Ys::Buffer{Float32,Array{Float32,3}}
  axisYs₁::Base.OneTo{Int64}
  time::Base.OneTo{Int64}
  rev_time::StepRange{Int64,Int64}
  slice_f::UnitRange{Int64}
  slice_b::UnitRange{Int64}

Body::Array{Float32,3}
1 ─ %1  = Base.getproperty(m, :dim_out)::Int64
│   %2  = (2 * %1)::Int64
│   %3  = Main.size(Xs, 2)::Int64
│   %4  = Main.size(Xs, 3)::Int64
│         (Ys = Main.Buffer(Xs, %2, %3, %4))
│         (axisYs₁ = Main.axes(Ys, 1))
│         (time = Main.axes(Ys, 2))
│         (rev_time = Main.reverse(time))
│         $(Expr(:inbounds, true))
│   %10 = axisYs₁::Base.OneTo{Int64}
│   %11 = Base.getproperty(m, :dim_out)::Int64
│   %12 = (1:%11)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│         (slice_f = Base.getindex(%10, %12))
│   %14 = axisYs₁::Base.OneTo{Int64}
│   %15 = Base.getproperty(m, :dim_out)::Int64
│   %16 = (%15 + 1)::Int64
│   %17 = Base.lastindex(axisYs₁)::Int64
│   %18 = (%16:%17)::UnitRange{Int64}
│         (slice_b = Base.getindex(%14, %18))
│   %20 = Main.Ref(Ys)::Base.RefValue{Buffer{Float32,Array{Float32,3}}}
│   %21 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %22 = Main.Ref(Xs)::Base.RefValue{Array{Float32,3}}
│   %23 = time::Base.OneTo{Int64}
│   %24 = Base.broadcasted(Main.view, %22, Main.:(:), %23, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}
│   %25 = Base.broadcasted(%21, %24)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}}}
│   %26 = Main.Ref(slice_f)::Base.RefValue{UnitRange{Int64}}
│   %27 = time::Base.OneTo{Int64}
│   %28 = Base.broadcasted(Main.setindex!, %20, %25, %26, %27, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(setindex!),Tuple{Base.RefValue{Buffer{Float32,Array{Float32,3}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}}},Base.RefValue{UnitRange{Int64}},Base.OneTo{Int64},Base.RefValue{Colon}}}
│         Base.materialize(%28)
│   %30 = Main.Ref(Ys)::Base.RefValue{Buffer{Float32,Array{Float32,3}}}
│   %31 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %32 = Main.Ref(Xs)::Base.RefValue{Array{Float32,3}}
│   %33 = rev_time::StepRange{Int64,Int64}
│   %34 = Base.broadcasted(Main.view, %32, Main.:(:), %33, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}
│   %35 = Base.broadcasted(%31, %34)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}}}
│   %36 = Main.Ref(slice_b)::Base.RefValue{UnitRange{Int64}}
│   %37 = rev_time::StepRange{Int64,Int64}
│   %38 = Base.broadcasted(Main.setindex!, %30, %35, %36, %37, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(setindex!),Tuple{Base.RefValue{Buffer{Float32,Array{Float32,3}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}}},Base.RefValue{UnitRange{Int64}},StepRange{Int64,Int64},Base.RefValue{Colon}}}
│         (val = Base.materialize(%38))
│         $(Expr(:inbounds, :pop))
│         val
│   %42 = Main.copy(Ys)::Array{Float32,3}
└──       return %42
AzamatB commented 4 years ago
function (m::PBLSTM)(xs::AbstractVector{<:DenseVecOrMat})
   # reduce time duration by half by restacking consecutive pairs of input along the time dimension
   evenidxs = (firstindex(xs)+1):2:lastindex(xs)
   x̄s = (i -> [xs[i-1]; xs[i]]).(evenidxs)
   # counterintuitively the gradient of the following version is not much faster (on par in fact),
   # even though it is implemented via broadcasting
   # x̄s = vcat.(getindex.(Ref(xs), 1:2:lastindex(xs)), getindex.(Ref(xs), 2:2:lastindex(xs)))
   # x̄s = @views @inbounds(vcat.(xs[1:2:end], xs[2:2:end]))
   # x̄s = vcat.(xs[1:2:end], xs[2:2:end])
   # bidirectional run step
   return vcat.(m.forward.(x̄s), flip(m.backward, x̄s))
end

returns Any

julia> @code_warntype m(xs)
Variables
  m::PBLSTM{Array{Float32,2},Array{Float32,1}}
  xs::Array{Array{Float32,2},1}
  #95::var"#95#96"{Array{Array{Float32,2},1}}
  x̄s::Array{Array{Float32,2},1}

Body::Any
1 ─ %1  = Main.:(var"#95#96")::Core.Compiler.Const(var"#95#96", false)
│   %2  = Core.typeof(xs)::Core.Compiler.Const(Array{Array{Float32,2},1}, false)
│   %3  = Core.apply_type(%1, %2)::Core.Compiler.Const(var"#95#96"{Array{Array{Float32,2},1}}, false)
│         (#95 = %new(%3, xs))
│   %5  = #95::var"#95#96"{Array{Array{Float32,2},1}}
│   %6  = Main.lastindex(xs)::Int64
│   %7  = (2:2:%6)::Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])
│   %8  = Base.Generator(%5, %7)::Core.Compiler.PartialStruct(Base.Generator{StepRange{Int64,Int64},var"#95#96"{Array{Array{Float32,2},1}}}, Any[var"#95#96"{Array{Array{Float32,2},1}}, Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])])
│         (x̄s = Base.collect(%8))
│   %10 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %11 = Base.broadcasted(%10, x̄s)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│   %12 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %13 = Main.flip(%12, x̄s)::Array{Array{Float32,2},1}
│   %14 = Base.broadcasted(Main.vcat, %11, %13)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}},Array{Array{Float32,2},1}}}
│   %15 = Base.materialize(%14)::Any
└──       return %15

alleviated by

function (m::PBLSTM)(xs::VM) where VM <: AbstractVector{<:DenseVecOrMat}
   # reduce time duration by half by restacking consecutive pairs of input along the time dimension
   evenidxs = (firstindex(xs)+1):2:lastindex(xs)
   x̄s = (i -> [xs[i-1]; xs[i]]).(evenidxs)
   # counterintuitively the gradient of the following version is not much faster (on par in fact),
   # even though it is implemented via broadcasting
   # x̄s = vcat.(getindex.(Ref(xs), 1:2:lastindex(xs)), getindex.(Ref(xs), 2:2:lastindex(xs)))
   # x̄s = @views @inbounds(vcat.(xs[1:2:end], xs[2:2:end]))
   # x̄s = vcat.(xs[1:2:end], xs[2:2:end])
   # bidirectional run step
   return vcat.(m.forward.(x̄s), flip(m.backward, x̄s))::VM
end
julia> @code_warntype m(xs)
Variables
  m::PBLSTM{Array{Float32,2},Array{Float32,1}}
  xs::Array{Array{Float32,2},1}
  #154::var"#154#155"{Array{Array{Float32,2},1}}
  evenidxs::StepRange{Int64,Int64}
  x̄s::Array{Array{Float32,2},1}

Body::Array{Array{Float32,2},1}
1 ─ %1  = Main.firstindex(xs)::Core.Compiler.Const(1, false)
│   %2  = (%1 + 1)::Core.Compiler.Const(2, false)
│   %3  = Main.lastindex(xs)::Int64
│         (evenidxs = %2:2:%3)
│   %5  = Main.:(var"#154#155")::Core.Compiler.Const(var"#154#155", false)
│   %6  = Core.typeof(xs)::Core.Compiler.Const(Array{Array{Float32,2},1}, false)
│   %7  = Core.apply_type(%5, %6)::Core.Compiler.Const(var"#154#155"{Array{Array{Float32,2},1}}, false)
│         (#154 = %new(%7, xs))
│   %9  = #154::var"#154#155"{Array{Array{Float32,2},1}}
│   %10 = Base.broadcasted(%9, evenidxs::Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64]))::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,var"#154#155"{Array{Array{Float32,2},1}},Tuple{StepRange{Int64,Int64}}}, Any[var"#154#155"{Array{Array{Float32,2},1}}, Core.Compiler.PartialStruct(Tuple{StepRange{Int64,Int64}}, Any[Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])]), Core.Compiler.Const(nothing, false)])
│         (x̄s = Base.materialize(%10))
│   %12 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %13 = Base.broadcasted(%12, x̄s)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│   %14 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %15 = Main.flip(%14, x̄s)::Array{Array{Float32,2},1}
│   %16 = Base.broadcasted(Main.vcat, %13, %15)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}},Array{Array{Float32,2},1}}}
│   %17 = Base.materialize(%16)::Any
│   %18 = Core.typeassert(%17, $(Expr(:static_parameter, 1)))::Array{Array{Float32,2},1}
└──       return %18

on CuArrays the inference still fails:

Variables
  m::PBLSTM{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}
  xs::Array{CuArray{Float32,2,Nothing},1}
  #47::var"#47#48"{Array{CuArray{Float32,2,Nothing},1}}
  evenidxs::StepRange{Int64,Int64}
  x̄s::Any

Body::Array{CuArray{Float32,2,Nothing},1}
1 ─ %1  = Main.firstindex(xs)::Core.Compiler.Const(1, false)
│   %2  = (%1 + 1)::Core.Compiler.Const(2, false)
│   %3  = Main.lastindex(xs)::Int64
│         (evenidxs = %2:2:%3)
│   %5  = Main.:(var"#47#48")::Core.Compiler.Const(var"#47#48", false)
│   %6  = Core.typeof(xs)::Core.Compiler.Const(Array{CuArray{Float32,2,Nothing},1}, false)
│   %7  = Core.apply_type(%5, %6)::Core.Compiler.Const(var"#47#48"{Array{CuArray{Float32,2,Nothing},1}}, false)
│         (#47 = %new(%7, xs))
│   %9  = #47::var"#47#48"{Array{CuArray{Float32,2,Nothing},1}}
│   %10 = Base.broadcasted(%9, evenidxs::Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64]))::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,var"#47#48"{Array{CuArray{Float32,2,Nothing},1}},Tuple{StepRange{Int64,Int64}}}, Any[var"#47#48"{Array{CuArray{Float32,2,Nothing},1}}, Core.Compiler.PartialStruct(Tuple{StepRange{Int64,Int64}}, Any[Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])]), Core.Compiler.Const(nothing, false)])
│         (x̄s = Base.materialize(%10))
│   %12 = Base.getproperty(m, :forward)::Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}}
│   %13 = Base.broadcasted(%12, x̄s)::Any
│   %14 = Base.getproperty(m, :backward)::Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}}
│   %15 = Main.flip(%14, x̄s)::Any
│   %16 = Base.broadcasted(Main.vcat, %13, %15)::Any
│   %17 = Base.materialize(%16)::Any
│   %18 = Core.typeassert(%17, $(Expr(:static_parameter, 1)))::Array{CuArray{Float32,2,Nothing},1}
└──       return %18

alleviated by

function (m::PBLSTM)(xs::VM) where VM <: AbstractVector{<:DenseVecOrMat}
   # reduce time duration by half by restacking consecutive pairs of input along the time dimension
   evenidxs = (firstindex(xs)+1):2:lastindex(xs)
   x̄s = (i -> [xs[i-1]; xs[i]]).(evenidxs)::VM
   # counterintuitively the gradient of the following version is not much faster (on par in fact),
   # even though it is implemented via broadcasting
   # x̄s = vcat.(getindex.(Ref(xs), 1:2:lastindex(xs)), getindex.(Ref(xs), 2:2:lastindex(xs)))
   # x̄s = @views @inbounds(vcat.(xs[1:2:end], xs[2:2:end]))
   # x̄s = vcat.(xs[1:2:end], xs[2:2:end])
   # bidirectional run step
   return vcat.(m.forward.(x̄s), flip(m.backward, x̄s))::VM
end

which produces

Variables
  m::PBLSTM{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}
  xs::Array{CuArray{Float32,2,Nothing},1}
  #93::var"#93#94"{Array{CuArray{Float32,2,Nothing},1}}
  evenidxs::StepRange{Int64,Int64}
  x̄s::Array{CuArray{Float32,2,Nothing},1}

Body::Array{CuArray{Float32,2,Nothing},1}
1 ─ %1  = Main.firstindex(xs)::Core.Compiler.Const(1, false)
│   %2  = (%1 + 1)::Core.Compiler.Const(2, false)
│   %3  = Main.lastindex(xs)::Int64
│         (evenidxs = %2:2:%3)
│   %5  = Main.:(var"#93#94")::Core.Compiler.Const(var"#93#94", false)
│   %6  = Core.typeof(xs)::Core.Compiler.Const(Array{CuArray{Float32,2,Nothing},1}, false)
│   %7  = Core.apply_type(%5, %6)::Core.Compiler.Const(var"#93#94"{Array{CuArray{Float32,2,Nothing},1}}, false)
│         (#93 = %new(%7, xs))
│   %9  = #93::var"#93#94"{Array{CuArray{Float32,2,Nothing},1}}
│   %10 = Base.broadcasted(%9, evenidxs::Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64]))::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,var"#93#94"{Array{CuArray{Float32,2,Nothing},1}},Tuple{StepRange{Int64,Int64}}}, Any[var"#93#94"{Array{CuArray{Float32,2,Nothing},1}}, Core.Compiler.PartialStruct(Tuple{StepRange{Int64,Int64}}, Any[Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])]), Core.Compiler.Const(nothing, false)])
│   %11 = Base.materialize(%10)::Any
│         (x̄s = Core.typeassert(%11, $(Expr(:static_parameter, 1))))
│   %13 = Base.getproperty(m, :forward)::Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}}
│   %14 = Base.broadcasted(%13, x̄s)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}},Tuple{Array{CuArray{Float32,2,Nothing},1}}}
│   %15 = Base.getproperty(m, :backward)::Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}}
│   %16 = Main.flip(%15, x̄s)::Array{CuArray{Float32,2,Nothing},1}
│   %17 = Base.broadcasted(Main.vcat, %14, %16)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}},Tuple{Array{CuArray{Float32,2,Nothing},1}}},Array{CuArray{Float32,2,Nothing},1}}}
│   %18 = Base.materialize(%17)::Any
│   %19 = Core.typeassert(%18, $(Expr(:static_parameter, 1)))::Array{CuArray{Float32,2,Nothing},1}
└──       return %19
AzamatB commented 4 years ago
function (m::PBLSTM)(Xs::DenseArray{<:Real,3})
   D, T, B = size(Xs)
   T½ = T÷2
   # reduce time duration by half by restacking consecutive pairs of input along the time dimension
   X̄s = reshape(Xs, 2D, T½, B)
   # preallocate output buffer
   Ys = Buffer(Xs, 2m.dim_out, T½, B)
   axisYs₁ = axes(Ys, 1)
   time    = axes(Ys, 2)
   rev_time = reverse(time)
   # get forward and backward slice indices
   slice_f = axisYs₁[1:m.dim_out]
   slice_b = axisYs₁[(m.dim_out+1):end]
   # bidirectional run step
   setindex!.(Ref(Ys),  m.forward.(view.(Ref(X̄s), :, time, :)), Ref(slice_f), time, :)
   setindex!.(Ref(Ys), m.backward.(view.(Ref(X̄s), :, rev_time, :)), Ref(slice_b), rev_time, :)
   # the same as
   # @views for (t_f, t_b) ∈ zip(time, reverse(time))
   #    Ys[slice_f, t_f, :] =  m.forward(X̄s[:, t_f, :])
   #    Ys[slice_b, t_b, :] = m.backward(X̄s[:, t_b, :])
   # end
   # but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
   return copy(Ys)
end

type inference looks good

julia> @code_warntype m(Xs)
Variables
  m::PBLSTM{Array{Float32,2},Array{Float32,1}}
  Xs::Array{Float32,3}
  D::Int64
  T::Int64
  @_5::Int64
  B::Int64
  T½::Int64
  X̄s::Array{Float32,3}
  Ys::Buffer{Float32,Array{Float32,3}}
  axisYs₁::Base.OneTo{Int64}
  time::Base.OneTo{Int64}
  rev_time::StepRange{Int64,Int64}
  slice_f::UnitRange{Int64}
  slice_b::UnitRange{Int64}

Body::Array{Float32,3}
1 ─ %1  = Main.size(Xs)::Tuple{Int64,Int64,Int64}
│   %2  = Base.indexed_iterate(%1, 1)::Core.Compiler.PartialStruct(Tuple{Int64,Int64}, Any[Int64, Core.Compiler.Const(2, false)])
│         (D = Core.getfield(%2, 1))
│         (@_5 = Core.getfield(%2, 2))
│   %5  = Base.indexed_iterate(%1, 2, @_5::Core.Compiler.Const(2, false))::Core.Compiler.PartialStruct(Tuple{Int64,Int64}, Any[Int64, Core.Compiler.Const(3, false)])
│         (T = Core.getfield(%5, 1))
│         (@_5 = Core.getfield(%5, 2))
│   %8  = Base.indexed_iterate(%1, 3, @_5::Core.Compiler.Const(3, false))::Core.Compiler.PartialStruct(Tuple{Int64,Int64}, Any[Int64, Core.Compiler.Const(4, false)])
│         (B = Core.getfield(%8, 1))
│         (T½ = T ÷ 2)
│   %11 = (2 * D)::Int64
│   %12 = T½::Int64
│         (X̄s = Main.reshape(Xs, %11, %12, B))
│   %14 = Base.getproperty(m, :dim_out)::Int64
│   %15 = (2 * %14)::Int64
│   %16 = T½::Int64
│         (Ys = Main.Buffer(Xs, %15, %16, B))
│         (axisYs₁ = Main.axes(Ys, 1))
│         (time = Main.axes(Ys, 2))
│         (rev_time = Main.reverse(time))
│   %21 = axisYs₁::Base.OneTo{Int64}
│   %22 = Base.getproperty(m, :dim_out)::Int64
│   %23 = (1:%22)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│         (slice_f = Base.getindex(%21, %23))
│   %25 = axisYs₁::Base.OneTo{Int64}
│   %26 = Base.getproperty(m, :dim_out)::Int64
│   %27 = (%26 + 1)::Int64
│   %28 = Base.lastindex(axisYs₁)::Int64
│   %29 = (%27:%28)::UnitRange{Int64}
│         (slice_b = Base.getindex(%25, %29))
│   %31 = Main.Ref(Ys)::Base.RefValue{Buffer{Float32,Array{Float32,3}}}
│   %32 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %33 = Main.Ref(X̄s)::Base.RefValue{Array{Float32,3}}
│   %34 = time::Base.OneTo{Int64}
│   %35 = Base.broadcasted(Main.view, %33, Main.:(:), %34, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}
│   %36 = Base.broadcasted(%32, %35)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}}}
│   %37 = Main.Ref(slice_f)::Base.RefValue{UnitRange{Int64}}
│   %38 = time::Base.OneTo{Int64}
│   %39 = Base.broadcasted(Main.setindex!, %31, %36, %37, %38, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(setindex!),Tuple{Base.RefValue{Buffer{Float32,Array{Float32,3}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}}},Base.RefValue{UnitRange{Int64}},Base.OneTo{Int64},Base.RefValue{Colon}}}
│         Base.materialize(%39)
│   %41 = Main.Ref(Ys)::Base.RefValue{Buffer{Float32,Array{Float32,3}}}
│   %42 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│   %43 = Main.Ref(X̄s)::Base.RefValue{Array{Float32,3}}
│   %44 = rev_time::StepRange{Int64,Int64}
│   %45 = Base.broadcasted(Main.view, %43, Main.:(:), %44, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}
│   %46 = Base.broadcasted(%42, %45)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}}}
│   %47 = Main.Ref(slice_b)::Base.RefValue{UnitRange{Int64}}
│   %48 = rev_time::StepRange{Int64,Int64}
│   %49 = Base.broadcasted(Main.setindex!, %41, %46, %47, %48, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(setindex!),Tuple{Base.RefValue{Buffer{Float32,Array{Float32,3}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}}},Base.RefValue{UnitRange{Int64}},StepRange{Int64,Int64},Base.RefValue{Colon}}}
│         Base.materialize(%49)
│   %51 = Main.copy(Ys)::Array{Float32,3}
└──       return %51
AzamatB commented 4 years ago
@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
   batch_size = size(Hs, 3)
   # initialize state for every sequence in a batch
   context    = repeat(m.state₀.context,    1, batch_size)
   decoding   = repeat(m.state₀.decoding,   1, batch_size)
   prediction = repeat(m.state₀.prediction, 1, batch_size)
   # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
   # passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
   ψHs = reshape(m.attention_ψ(reshape(Hs, size(Hs,1), :)), size(m.attention_ψ.W, 1), :, batch_size)
   # ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # check: all(ψhs .≈ eachslice(ψHs; dims=2))
   ŷs = map(1:maxT) do _
      # compute decoder state
      decoding = m.spell([decoding; prediction; context])
      # compute query ϕ(sᵢ)
      ϕsᵢ = m.attention_ϕ(decoding)
      # compute energies via batch matrix multiplication
      @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
      # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
      # compute attentions weights
      αᵢs = softmax(Eᵢs)
      # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
      # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
      # predict probability distribution over character alphabet
      prediction = m.infer([decoding; context])
   end
   return ŷs
end

looks bad (bunch of Core.Boxes and return argument is failed to infer):

julia> @code_warntype decode(m, Hs, maxT)
Variables
  #self#::Core.Compiler.Const(decode, false)
  m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
  Hs::Array{Float32,3}
  maxT::Int64
  #91::var"#91#92"{LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}},Array{Float32,3},Array{Float32,3}}
  batch_size::Int64
  context::Core.Box
  decoding::Core.Box
  prediction::Core.Box
  ψHs::Array{Float32,3}
  ŷs::Array{_A,1} where _A

Body::Array{_A,1} where _A
1 ─       nothing
│         (context = Core.Box())
│         (decoding = Core.Box())
│         (prediction = Core.Box())
│         (batch_size = Main.size(Hs, 3))
│   %6  = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│   %7  = Base.getproperty(%6, :context)::Array{Float32,2}
│   %8  = Main.repeat(%7, 1, batch_size)::Array{Float32,2}
│         Core.setfield!(context, :contents, %8)
│   %10 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│   %11 = Base.getproperty(%10, :decoding)::Array{Float32,2}
│   %12 = Main.repeat(%11, 1, batch_size)::Array{Float32,2}
│         Core.setfield!(decoding, :contents, %12)
│   %14 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│   %15 = Base.getproperty(%14, :prediction)::Array{Float32,2}
│   %16 = Main.repeat(%15, 1, batch_size)::Array{Float32,2}
│         Core.setfield!(prediction, :contents, %16)
│   %18 = Base.getproperty(m, :attention_ψ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│   %19 = Main.size(Hs, 1)::Int64
│   %20 = Main.reshape(Hs, %19, Main.:(:))::Array{Float32,2}
│   %21 = (%18)(%20)::Array{Float32,2}
│   %22 = Base.getproperty(m, :attention_ψ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│   %23 = Base.getproperty(%22, :W)::Array{Float32,2}
│   %24 = Main.size(%23, 1)::Int64
│         (ψHs = Main.reshape(%21, %24, Main.:(:), batch_size))
│   %26 = Main.:(var"#91#92")::Core.Compiler.Const(var"#91#92", false)
│   %27 = Core.typeof(m)::Core.Compiler.Const(LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}, false)
│   %28 = Core.typeof(Hs)::Core.Compiler.Const(Array{Float32,3}, false)
│   %29 = Core.typeof(ψHs)::Core.Compiler.Const(Array{Float32,3}, false)
│   %30 = Core.apply_type(%26, %27, %28, %29)::Core.Compiler.Const(var"#91#92"{LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}},Array{Float32,3},Array{Float32,3}}, false)
│   %31 = context::Core.Box
│   %32 = decoding::Core.Box
│   %33 = prediction::Core.Box
│         (#91 = %new(%30, m, Hs, %31, %32, %33, ψHs))
│   %35 = #91::var"#91#92"{LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}},Array{Float32,3},Array{Float32,3}}
│   %36 = (1:maxT)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│         (ŷs = Main.map(%35, %36))
└──       return ŷs

fixed by

@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
   batch_size = size(Hs, 3)
   # initialize state for every sequence in a batch
   context    = repeat(m.state₀.context,    1, batch_size)
   decoding   = repeat(m.state₀.decoding,   1, batch_size)
   prediction = repeat(m.state₀.prediction, 1, batch_size)
   # precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
   # passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
   ψHs = reshape(m.attention_ψ(reshape(Hs, size(Hs,1), :)), size(m.attention_ψ.W, 1), :, batch_size)
   # ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # check: all(ψhs .≈ eachslice(ψHs; dims=2))
   ŷs = Buffer(Vector{M}(undef, maxT), false)
   @inbounds for t ∈ eachindex(ŷs)
      # compute decoder state
      decoding = m.spell([decoding; prediction; context])::M
      # compute query ϕ(sᵢ)
      ϕsᵢ = m.attention_ϕ(decoding)
      # compute energies via batch matrix multiplication
      # @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
      Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
      # check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
      # compute attentions weights
      αᵢs = softmax(Eᵢs)
      # compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
      context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
      # check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
      # predict probability distribution over character alphabet
      ŷs[t] = prediction = m.infer([decoding; context])
   end
   return copy(ŷs)
end

which produces

julia> @code_warntype decode(m, Hs, maxT)
Variables
  #self#::Core.Compiler.Const(decode, false)
  m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
  Hs::Array{Float32,3}
  maxT::Int64
  val::Nothing
  batch_size::Int64
  context::Array{Float32,2}
  decoding::Array{Float32,2}
  prediction::Array{Float32,2}
  ψHs::Array{Float32,3}
  ŷs::Buffer{Array{Float32,2},Array{Array{Float32,2},1}}
  @_12::Union{Nothing, Tuple{Int64,Int64}}
  t::Int64
  ϕsᵢ::Array{Float32,2}
  Eᵢs::Array{Float32,2}
  αᵢs::Array{Float32,2}

Body::Array{Array{Float32,2},1}
1 ─       nothing
│         Core.NewvarNode(:(val))
│         (batch_size = Main.size(Hs, 3))
│   %4  = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│   %5  = Base.getproperty(%4, :context)::Array{Float32,2}
│         (context = Main.repeat(%5, 1, batch_size))
│   %7  = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│   %8  = Base.getproperty(%7, :decoding)::Array{Float32,2}
│         (decoding = Main.repeat(%8, 1, batch_size))
│   %10 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│   %11 = Base.getproperty(%10, :prediction)::Array{Float32,2}
│         (prediction = Main.repeat(%11, 1, batch_size))
│   %13 = Base.getproperty(m, :attention_ψ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│   %14 = Main.size(Hs, 1)::Int64
│   %15 = Main.reshape(Hs, %14, Main.:(:))::Array{Float32,2}
│   %16 = (%13)(%15)::Array{Float32,2}
│   %17 = Base.getproperty(m, :attention_ψ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│   %18 = Base.getproperty(%17, :W)::Array{Float32,2}
│   %19 = Main.size(%18, 1)::Int64
│         (ψHs = Main.reshape(%16, %19, Main.:(:), batch_size))
│   %21 = Core.apply_type(Main.Vector, $(Expr(:static_parameter, 1)))::Core.Compiler.Const(Array{Array{Float32,2},1}, false)
│   %22 = (%21)(Main.undef, maxT)::Array{Array{Float32,2},1}
│         (ŷs = Main.Buffer(%22, false))
│         $(Expr(:inbounds, true))
│   %25 = Main.eachindex(ŷs)::Base.OneTo{Int64}
│         (@_12 = Base.iterate(%25))
│   %27 = (@_12 === nothing)::Bool
│   %28 = Base.not_int(%27)::Bool
└──       goto #4 if not %28
2 ┄ %30 = @_12::Tuple{Int64,Int64}::Tuple{Int64,Int64}
│         (t = Core.getfield(%30, 1))
│   %32 = Core.getfield(%30, 2)::Int64
│   %33 = Base.getproperty(m, :spell)::Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}}
│   %34 = Base.vcat(decoding, prediction, context)::Array{Float32,2}
│   %35 = (%33)(%34)::Any
│         (decoding = Core.typeassert(%35, $(Expr(:static_parameter, 1))))
│   %37 = Base.getproperty(m, :attention_ϕ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│         (ϕsᵢ = (%37)(decoding))
│   %39 = Core.tuple(1, 2)::Core.Compiler.Const((1, 2), false)
│   %40 = Core.tuple(1, 3, 2)::Core.Compiler.Const((1, 3, 2), false)
│   %41 = Core.tuple(%39, %40)::Core.Compiler.Const(((1, 2), (1, 3, 2)), false)
│   %42 = Core.tuple(3, 2)::Core.Compiler.Const((3, 2), false)
│   %43 = Core.apply_type(Main.EinCode, %41, %42)::Core.Compiler.Const(EinCode{((1, 2), (1, 3, 2)),(3, 2)}, false)
│   %44 = (%43)()::Core.Compiler.Const(EinCode{((1, 2), (1, 3, 2)),(3, 2)}(), false)
│   %45 = Core.tuple(ϕsᵢ, ψHs)::Tuple{Array{Float32,2},Array{Float32,3}}
│   %46 = Main.einsum(%44, %45)::Any
│         (Eᵢs = Core.typeassert(%46, $(Expr(:static_parameter, 1))))
│         (αᵢs = Main.softmax(Eᵢs))
│   %49 = Core.tuple(1, 2)::Core.Compiler.Const((1, 2), false)
│   %50 = Core.tuple(3, 1, 2)::Core.Compiler.Const((3, 1, 2), false)
│   %51 = Core.tuple(%49, %50)::Core.Compiler.Const(((1, 2), (3, 1, 2)), false)
│   %52 = Core.tuple(3, 2)::Core.Compiler.Const((3, 2), false)
│   %53 = Core.apply_type(Main.EinCode, %51, %52)::Core.Compiler.Const(EinCode{((1, 2), (3, 1, 2)),(3, 2)}, false)
│   %54 = (%53)()::Core.Compiler.Const(EinCode{((1, 2), (3, 1, 2)),(3, 2)}(), false)
│   %55 = Core.tuple(αᵢs, Hs)::Tuple{Array{Float32,2},Array{Float32,3}}
│   %56 = Main.einsum(%54, %55)::Any
│         (context = Core.typeassert(%56, $(Expr(:static_parameter, 1))))
│   %58 = Base.getproperty(m, :infer)::Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}
│   %59 = Base.vcat(decoding, context)::Array{Float32,2}
│   %60 = (%58)(%59)::Array{Float32,2}
│         (prediction = %60)
│         Base.setindex!(ŷs, %60, t)
│         (@_12 = Base.iterate(%25, %32))
│   %64 = (@_12 === nothing)::Bool
│   %65 = Base.not_int(%64)::Bool
└──       goto #4 if not %65
3 ─       goto #2
4 ┄       (val = nothing)
│         $(Expr(:inbounds, :pop))
│         val
│   %71 = Main.copy(ŷs)::Array{Array{Float32,2},1}
└──       return %71

the new version, does not result in performance regressions (in fact slightly faster as expected):

julia> reset!(m)

julia> @benchmark c0($m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  220.92 MiB
  allocs estimate:  174739
  --------------
  minimum time:     304.113 ms (7.01% GC)
  median time:      310.506 ms (7.23% GC)
  mean time:        313.922 ms (8.58% GC)
  maximum time:     329.119 ms (13.29% GC)
  --------------
  samples:          16
  evals/sample:     1

julia> reset!(m)

julia> @benchmark c1($m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  220.92 MiB
  allocs estimate:  174733
  --------------
  minimum time:     302.992 ms (6.77% GC)
  median time:      309.734 ms (7.81% GC)
  mean time:        313.066 ms (8.87% GC)
  maximum time:     340.022 ms (14.22% GC)
  --------------
  samples:          16
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gc0($θ, $m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  926.07 MiB
  allocs estimate:  766920
  --------------
  minimum time:     1.075 s (9.87% GC)
  median time:      1.140 s (10.48% GC)
  mean time:        1.215 s (18.03% GC)
  maximum time:     1.420 s (29.09% GC)
  --------------
  samples:          5
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gc1($θ, $m, $Hs, $maxT)
BenchmarkTools.Trial: 
  memory estimate:  923.93 MiB
  allocs estimate:  702862
  --------------
  minimum time:     1.014 s (10.56% GC)
  median time:      1.097 s (14.64% GC)
  mean time:        1.168 s (19.04% GC)
  maximum time:     1.514 s (37.04% GC)
  --------------
  samples:          5
  evals/sample:     1
AzamatB commented 4 years ago
function (m::LAS)(xs::AbstractVector{<:DenseMatrix}, maxT::Integer = length(xs))
   # compute input encoding, which are also values for the attention layer
   hs = m.listen(xs)
   dim_out, batch_size = size(first(hs))
   # transform T-length sequence of D×B matrices into the D×T×B tensor by first conconcatenating matrices
   # along the 1st dimension and to get singe DT×B matrix and then reshaping it into D×T×B tensor
   Hs = reshape(reduce(vcat, hs), dim_out, :, batch_size)
   # perform attend and spell steps
   ŷs = decode(m, Hs, maxT)
   return ŷs
end

infers completely under the fixed decode function

julia> @code_warntype m(xs)
Variables
  m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
  xs::Array{Array{Float32,2},1}

Body::Array{Array{Float32,2},1}
1 ─ %1 = Main.length(xs)::Int64
│   %2 = (m)(xs, %1)::Array{Array{Float32,2},1}
└──      return %2
AzamatB commented 4 years ago
function (m::LAS)(Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))
   # compute input encoding, which are also values for the attention layer
   Hs = m.listen(Xs)
   # perform attend and spell steps
   ŷs = decode(m, Hs, maxT)
   return ŷs
end

infers completely under the fixed decode function

julia> @code_warntype m(Xs)
Variables
  m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
  Xs::Array{Float32,3}

Body::Array{Array{Float32,2},1}
1 ─ %1 = Main.size(Xs, 2)::Int64
│   %2 = (m)(Xs, %1)::Array{Array{Float32,2},1}
└──      return %2
AzamatB commented 4 years ago
function pad(xs::DenseVector{<:DenseVector}, multiplicity::Integer)
   T = length(xs)
   newT = ceil(Int, T / multiplicity)multiplicity
   z = zero(first(xs))
   return [xs; (_ -> z).(1:(newT - T))]
end

infers completely

julia> @code_warntype pad(x, 7)
Variables
  #self#::Core.Compiler.Const(pad, false)
  xs::Array{Array{Float32,1},1}
  multiplicity::Int64
  #63::var"#63#64"{Array{Float32,1}}
  T::Int64
  newT::Int64
  z::Array{Float32,1}

Body::Array{Array{Float32,1},1}
1 ─       (T = Main.length(xs))
│   %2  = (T / multiplicity)::Float64
│   %3  = Main.ceil(Main.Int, %2)::Int64
│         (newT = %3 * multiplicity)
│   %5  = Main.first(xs)::Array{Float32,1}
│         (z = Main.zero(%5))
│   %7  = Main.:(var"#63#64")::Core.Compiler.Const(var"#63#64", false)
│   %8  = Core.typeof(z)::Core.Compiler.Const(Array{Float32,1}, false)
│   %9  = Core.apply_type(%7, %8)::Core.Compiler.Const(var"#63#64"{Array{Float32,1}}, false)
│         (#63 = %new(%9, z))
│   %11 = #63::var"#63#64"{Array{Float32,1}}
│   %12 = (newT - T)::Int64
│   %13 = (1:%12)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│   %14 = Base.broadcasted(%11, %13)::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,var"#63#64"{Array{Float32,1}},Tuple{UnitRange{Int64}}}, Any[var"#63#64"{Array{Float32,1}}, Core.Compiler.PartialStruct(Tuple{UnitRange{Int64}}, Any[Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])]), Core.Compiler.Const(nothing, false)])
│   %15 = Base.materialize(%14)::Array{Array{Float32,1},1}
│   %16 = Base.vcat(xs, %15)::Array{Array{Float32,1},1}
└──       return %16
AzamatB commented 4 years ago
function (m::LAS)(x::AbstractVector{<:DenseVector})
   T = length(x)
   X = reshape(reduce(hcat, pad(x, time_squashing_factor(m))), Val(3))
   ŷs = dropdims.(m(X, T); dims=2)
   return ŷs
end

infers completely

julia> @code_warntype m(x)
Variables
  m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
  x::Array{Array{Float32,1},1}
  T::Int64
  X::Array{Float32,3}
  ŷs::Array{Array{Float32,1},1}

Body::Array{Array{Float32,1},1}
1 ─       (T = Main.length(x))
│   %2  = Main.time_squashing_factor(m)::Int64
│   %3  = Main.pad(x, %2)::Array{Array{Float32,1},1}
│   %4  = Main.reduce(Main.hcat, %3)::Array{Float32,2}
│   %5  = Main.Val(3)::Core.Compiler.Const(Val{3}(), true)
│         (X = Main.reshape(%4, %5))
│   %7  = (m)(X, T)::Array{Array{Float32,2},1}
│   %8  = Base.broadcasted_kwsyntax::Core.Compiler.Const(Base.Broadcast.broadcasted_kwsyntax, false)
│   %9  = (:dims,)::Core.Compiler.Const((:dims,), false)
│   %10 = Core.apply_type(Core.NamedTuple, %9)::Core.Compiler.Const(NamedTuple{(:dims,),T} where T<:Tuple, false)
│   %11 = Core.tuple(2)::Core.Compiler.Const((2,), false)
│   %12 = (%10)(%11)::NamedTuple{(:dims,),Tuple{Int64}}
│   %13 = Core.kwfunc(%8)::Core.Compiler.Const(Base.Broadcast.var"#kw##broadcasted_kwsyntax"(), false)
│   %14 = (%13)(%12, %8, Main.dropdims, %7)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Base.Broadcast.var"#31#32"{Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:dims,),Tuple{Int64}}},typeof(dropdims)},Tuple{Array{Array{Float32,2},1}}}
│         (ŷs = Base.materialize(%14))
└──       return ŷs
AzamatB commented 4 years ago
function loss(m::LAS, X::DenseArray{<:Real,3}, linidxs::DenseVector{<:Integer}, maxT::Integer)
   Ŷs = m(X, maxT)
   l = -sum(Ŷs[linidxs])
   return l
end

infers completely

julia> @code_warntype loss(m, Xs, linidxs, maxT)
Variables
  #self#::Core.Compiler.Const(loss, false)
  m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
  X::Array{Float32,3}
  linidxs::Array{Int64,1}
  maxT::Int64
  Ŷs::Array{Float32,3}
  l::Float32

Body::Float32
1 ─      (Ŷs = (m)(X, maxT))
│   %2 = Base.getindex(Ŷs, linidxs)::Array{Float32,1}
│   %3 = Main.sum(%2)::Float32
│        (l = -%3)
└──      return l