Issue opened by AzamatB, 4 years ago.
# Forward pass of a bidirectional LSTM over a sequence `xs` of per-time-step
# slices (vectors or matrices): run `m.forward` over `xs` in order, run
# `m.backward` over the reversed sequence via `flip`, and feature-wise
# concatenate the two hidden states at each time step.
# NOTE(review): as the `@code_warntype` dump below shows, the broadcasted
# `vcat` fails to infer here, so this method's return type is `Any`.
function (m::BLSTM)(xs::AbstractVector{<:DenseVecOrMat})
vcat.(m.forward.(xs), flip(m.backward, xs))
end
returns Any
julia> @code_warntype m(xs)
Variables
m::BLSTM{Array{Float32,2},Array{Float32,1}}
xs::Array{Array{Float32,2},1}
Body::Any
1 ─ %1 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %2 = Base.broadcasted(%1, xs)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│ %3 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %4 = Main.flip(%3, xs)::Array{Array{Float32,2},1}
│ %5 = Base.broadcasted(Main.vcat, %2, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}},Array{Array{Float32,2},1}}}
│ %6 = Base.materialize(%5)::Any
└── return %6
alleviated by
# Same bidirectional forward pass, but the concrete input vector type is
# captured in the static parameter `VM` and used as a `typeassert` on the
# result, which gives the compiler a concrete return type (see dump below).
# NOTE(review): this assumes the output container/eltype matches the input's
# (e.g. Vector{Matrix{Float32}} in, same out) — the assert throws otherwise.
function (m::BLSTM)(xs::VM) where VM <: AbstractVector{<:DenseVecOrMat}
vcat.(m.forward.(xs), flip(m.backward, xs))::VM
end
julia> @code_warntype m(xs)
Variables
m::BLSTM{Array{Float32,2},Array{Float32,1}}
xs::Array{Array{Float32,2},1}
Body::Array{Array{Float32,2},1}
1 ─ %1 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %2 = Base.broadcasted(%1, xs)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│ %3 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %4 = Main.flip(%3, xs)::Array{Array{Float32,2},1}
│ %5 = Base.broadcasted(Main.vcat, %2, %4)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}},Array{Array{Float32,2},1}}}
│ %6 = Base.materialize(%5)::Any
│ %7 = Core.typeassert(%6, $(Expr(:static_parameter, 1)))::Array{Array{Float32,2},1}
└── return %7
# Forward pass over a dense 3-D tensor `Xs` with dims (features, time, batch).
# Forward-RNN outputs are written into the first `m.dim_out` rows of the
# output and backward-RNN outputs (computed over reversed time) into the
# remaining rows, using a (Zygote) `Buffer` so the in-place writes stay
# differentiable. Returns a (2*dim_out, time, batch) array via `copy(Ys)`.
function (m::BLSTM)(Xs::DenseArray{<:Real,3})
# preallocate output buffer
Ys = Buffer(Xs, 2m.dim_out, size(Xs,2), size(Xs,3))
axisYs₁ = axes(Ys, 1)
time = axes(Ys, 2)
rev_time = reverse(time)
@inbounds begin
# get forward and backward slice indices
slice_f = axisYs₁[1:m.dim_out]
slice_b = axisYs₁[(m.dim_out+1):end]
# bidirectional run step
# `Ref(...)` shields Ys, Xs and the slice ranges from broadcasting,
# so only the time axis is iterated over.
setindex!.(Ref(Ys), m.forward.(view.(Ref(Xs), :, time, :)), Ref(slice_f), time, :)
setindex!.(Ref(Ys), m.backward.(view.(Ref(Xs), :, rev_time, :)), Ref(slice_b), rev_time, :)
# the same as
# @views for (t_f, t_b) ∈ zip(time, rev_time)
# Ys[slice_f, t_f, :] = m.forward(Xs[:, t_f, :])
# Ys[slice_b, t_b, :] = m.backward(Xs[:, t_b, :])
# end
# but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
end
return copy(Ys)
end
type inference looks good
julia> @code_warntype m(Xs)
Variables
m::BLSTM{Array{Float32,2},Array{Float32,1}}
Xs::Array{Float32,3}
val::Any
Ys::Buffer{Float32,Array{Float32,3}}
axisYs₁::Base.OneTo{Int64}
time::Base.OneTo{Int64}
rev_time::StepRange{Int64,Int64}
slice_f::UnitRange{Int64}
slice_b::UnitRange{Int64}
Body::Array{Float32,3}
1 ─ %1 = Base.getproperty(m, :dim_out)::Int64
│ %2 = (2 * %1)::Int64
│ %3 = Main.size(Xs, 2)::Int64
│ %4 = Main.size(Xs, 3)::Int64
│ (Ys = Main.Buffer(Xs, %2, %3, %4))
│ (axisYs₁ = Main.axes(Ys, 1))
│ (time = Main.axes(Ys, 2))
│ (rev_time = Main.reverse(time))
│ $(Expr(:inbounds, true))
│ %10 = axisYs₁::Base.OneTo{Int64}
│ %11 = Base.getproperty(m, :dim_out)::Int64
│ %12 = (1:%11)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│ (slice_f = Base.getindex(%10, %12))
│ %14 = axisYs₁::Base.OneTo{Int64}
│ %15 = Base.getproperty(m, :dim_out)::Int64
│ %16 = (%15 + 1)::Int64
│ %17 = Base.lastindex(axisYs₁)::Int64
│ %18 = (%16:%17)::UnitRange{Int64}
│ (slice_b = Base.getindex(%14, %18))
│ %20 = Main.Ref(Ys)::Base.RefValue{Buffer{Float32,Array{Float32,3}}}
│ %21 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %22 = Main.Ref(Xs)::Base.RefValue{Array{Float32,3}}
│ %23 = time::Base.OneTo{Int64}
│ %24 = Base.broadcasted(Main.view, %22, Main.:(:), %23, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}
│ %25 = Base.broadcasted(%21, %24)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}}}
│ %26 = Main.Ref(slice_f)::Base.RefValue{UnitRange{Int64}}
│ %27 = time::Base.OneTo{Int64}
│ %28 = Base.broadcasted(Main.setindex!, %20, %25, %26, %27, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(setindex!),Tuple{Base.RefValue{Buffer{Float32,Array{Float32,3}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}}},Base.RefValue{UnitRange{Int64}},Base.OneTo{Int64},Base.RefValue{Colon}}}
│ Base.materialize(%28)
│ %30 = Main.Ref(Ys)::Base.RefValue{Buffer{Float32,Array{Float32,3}}}
│ %31 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %32 = Main.Ref(Xs)::Base.RefValue{Array{Float32,3}}
│ %33 = rev_time::StepRange{Int64,Int64}
│ %34 = Base.broadcasted(Main.view, %32, Main.:(:), %33, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}
│ %35 = Base.broadcasted(%31, %34)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}}}
│ %36 = Main.Ref(slice_b)::Base.RefValue{UnitRange{Int64}}
│ %37 = rev_time::StepRange{Int64,Int64}
│ %38 = Base.broadcasted(Main.setindex!, %30, %35, %36, %37, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(setindex!),Tuple{Base.RefValue{Buffer{Float32,Array{Float32,3}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}}},Base.RefValue{UnitRange{Int64}},StepRange{Int64,Int64},Base.RefValue{Colon}}}
│ (val = Base.materialize(%38))
│ $(Expr(:inbounds, :pop))
│ val
│ %42 = Main.copy(Ys)::Array{Float32,3}
└── return %42
# Pyramidal BLSTM forward pass: halve the time resolution by vertically
# stacking each consecutive pair of time slices, then run the bidirectional
# step (forward RNN + flipped backward RNN, feature-wise concatenated).
# NOTE(review): for odd-length `xs` the last element is silently dropped,
# since `evenidxs` stops at the last even offset — confirm this is intended.
# NOTE(review): as the dump below shows, the broadcasted `vcat` fails to
# infer, so this method returns `Any`.
function (m::PBLSTM)(xs::AbstractVector{<:DenseVecOrMat})
# reduce time duration by half by restacking consecutive pairs of input along the time dimension
evenidxs = (firstindex(xs)+1):2:lastindex(xs)
x̄s = (i -> [xs[i-1]; xs[i]]).(evenidxs)
# counterintuitively the gradient of the following version is not much faster (on par in fact),
# even though it is implemented via broadcasting
# x̄s = vcat.(getindex.(Ref(xs), 1:2:lastindex(xs)), getindex.(Ref(xs), 2:2:lastindex(xs)))
# x̄s = @views @inbounds(vcat.(xs[1:2:end], xs[2:2:end]))
# x̄s = vcat.(xs[1:2:end], xs[2:2:end])
# bidirectional run step
return vcat.(m.forward.(x̄s), flip(m.backward, x̄s))
end
returns Any
julia> @code_warntype m(xs)
Variables
m::PBLSTM{Array{Float32,2},Array{Float32,1}}
xs::Array{Array{Float32,2},1}
#95::var"#95#96"{Array{Array{Float32,2},1}}
x̄s::Array{Array{Float32,2},1}
Body::Any
1 ─ %1 = Main.:(var"#95#96")::Core.Compiler.Const(var"#95#96", false)
│ %2 = Core.typeof(xs)::Core.Compiler.Const(Array{Array{Float32,2},1}, false)
│ %3 = Core.apply_type(%1, %2)::Core.Compiler.Const(var"#95#96"{Array{Array{Float32,2},1}}, false)
│ (#95 = %new(%3, xs))
│ %5 = #95::var"#95#96"{Array{Array{Float32,2},1}}
│ %6 = Main.lastindex(xs)::Int64
│ %7 = (2:2:%6)::Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])
│ %8 = Base.Generator(%5, %7)::Core.Compiler.PartialStruct(Base.Generator{StepRange{Int64,Int64},var"#95#96"{Array{Array{Float32,2},1}}}, Any[var"#95#96"{Array{Array{Float32,2},1}}, Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])])
│ (x̄s = Base.collect(%8))
│ %10 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %11 = Base.broadcasted(%10, x̄s)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│ %12 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %13 = Main.flip(%12, x̄s)::Array{Array{Float32,2},1}
│ %14 = Base.broadcasted(Main.vcat, %11, %13)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}},Array{Array{Float32,2},1}}}
│ %15 = Base.materialize(%14)::Any
└── return %15
alleviated by
# Pyramidal BLSTM forward pass with the input vector type captured as static
# parameter `VM`; the `::VM` typeassert on the return value restores a
# concrete inferred return type on CPU arrays (see dump below).
function (m::PBLSTM)(xs::VM) where VM <: AbstractVector{<:DenseVecOrMat}
# reduce time duration by half by restacking consecutive pairs of input along the time dimension
evenidxs = (firstindex(xs)+1):2:lastindex(xs)
x̄s = (i -> [xs[i-1]; xs[i]]).(evenidxs)
# counterintuitively the gradient of the following version is not much faster (on par in fact),
# even though it is implemented via broadcasting
# x̄s = vcat.(getindex.(Ref(xs), 1:2:lastindex(xs)), getindex.(Ref(xs), 2:2:lastindex(xs)))
# x̄s = @views @inbounds(vcat.(xs[1:2:end], xs[2:2:end]))
# x̄s = vcat.(xs[1:2:end], xs[2:2:end])
# bidirectional run step
return vcat.(m.forward.(x̄s), flip(m.backward, x̄s))::VM
end
julia> @code_warntype m(xs)
Variables
m::PBLSTM{Array{Float32,2},Array{Float32,1}}
xs::Array{Array{Float32,2},1}
#154::var"#154#155"{Array{Array{Float32,2},1}}
evenidxs::StepRange{Int64,Int64}
x̄s::Array{Array{Float32,2},1}
Body::Array{Array{Float32,2},1}
1 ─ %1 = Main.firstindex(xs)::Core.Compiler.Const(1, false)
│ %2 = (%1 + 1)::Core.Compiler.Const(2, false)
│ %3 = Main.lastindex(xs)::Int64
│ (evenidxs = %2:2:%3)
│ %5 = Main.:(var"#154#155")::Core.Compiler.Const(var"#154#155", false)
│ %6 = Core.typeof(xs)::Core.Compiler.Const(Array{Array{Float32,2},1}, false)
│ %7 = Core.apply_type(%5, %6)::Core.Compiler.Const(var"#154#155"{Array{Array{Float32,2},1}}, false)
│ (#154 = %new(%7, xs))
│ %9 = #154::var"#154#155"{Array{Array{Float32,2},1}}
│ %10 = Base.broadcasted(%9, evenidxs::Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64]))::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,var"#154#155"{Array{Array{Float32,2},1}},Tuple{StepRange{Int64,Int64}}}, Any[var"#154#155"{Array{Array{Float32,2},1}}, Core.Compiler.PartialStruct(Tuple{StepRange{Int64,Int64}}, Any[Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])]), Core.Compiler.Const(nothing, false)])
│ (x̄s = Base.materialize(%10))
│ %12 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %13 = Base.broadcasted(%12, x̄s)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│ %14 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %15 = Main.flip(%14, x̄s)::Array{Array{Float32,2},1}
│ %16 = Base.broadcasted(Main.vcat, %13, %15)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}},Array{Array{Float32,2},1}}}
│ %17 = Base.materialize(%16)::Any
│ %18 = Core.typeassert(%17, $(Expr(:static_parameter, 1)))::Array{Array{Float32,2},1}
└── return %18
on CuArrays the inference still fails:
Variables
m::PBLSTM{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}
xs::Array{CuArray{Float32,2,Nothing},1}
#47::var"#47#48"{Array{CuArray{Float32,2,Nothing},1}}
evenidxs::StepRange{Int64,Int64}
x̄s::Any
Body::Array{CuArray{Float32,2,Nothing},1}
1 ─ %1 = Main.firstindex(xs)::Core.Compiler.Const(1, false)
│ %2 = (%1 + 1)::Core.Compiler.Const(2, false)
│ %3 = Main.lastindex(xs)::Int64
│ (evenidxs = %2:2:%3)
│ %5 = Main.:(var"#47#48")::Core.Compiler.Const(var"#47#48", false)
│ %6 = Core.typeof(xs)::Core.Compiler.Const(Array{CuArray{Float32,2,Nothing},1}, false)
│ %7 = Core.apply_type(%5, %6)::Core.Compiler.Const(var"#47#48"{Array{CuArray{Float32,2,Nothing},1}}, false)
│ (#47 = %new(%7, xs))
│ %9 = #47::var"#47#48"{Array{CuArray{Float32,2,Nothing},1}}
│ %10 = Base.broadcasted(%9, evenidxs::Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64]))::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,var"#47#48"{Array{CuArray{Float32,2,Nothing},1}},Tuple{StepRange{Int64,Int64}}}, Any[var"#47#48"{Array{CuArray{Float32,2,Nothing},1}}, Core.Compiler.PartialStruct(Tuple{StepRange{Int64,Int64}}, Any[Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])]), Core.Compiler.Const(nothing, false)])
│ (x̄s = Base.materialize(%10))
│ %12 = Base.getproperty(m, :forward)::Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}}
│ %13 = Base.broadcasted(%12, x̄s)::Any
│ %14 = Base.getproperty(m, :backward)::Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}}
│ %15 = Main.flip(%14, x̄s)::Any
│ %16 = Base.broadcasted(Main.vcat, %13, %15)::Any
│ %17 = Base.materialize(%16)::Any
│ %18 = Core.typeassert(%17, $(Expr(:static_parameter, 1)))::Array{CuArray{Float32,2,Nothing},1}
└── return %18
alleviated by
# Pyramidal BLSTM forward pass, additionally asserting the type of the
# intermediate `x̄s` (not just the return value). On CuArray inputs the
# restacking broadcast infers as `Any`, which poisons everything downstream;
# the `::VM` assert on `x̄s` re-anchors inference there (see dump below).
function (m::PBLSTM)(xs::VM) where VM <: AbstractVector{<:DenseVecOrMat}
# reduce time duration by half by restacking consecutive pairs of input along the time dimension
evenidxs = (firstindex(xs)+1):2:lastindex(xs)
x̄s = (i -> [xs[i-1]; xs[i]]).(evenidxs)::VM
# counterintuitively the gradient of the following version is not much faster (on par in fact),
# even though it is implemented via broadcasting
# x̄s = vcat.(getindex.(Ref(xs), 1:2:lastindex(xs)), getindex.(Ref(xs), 2:2:lastindex(xs)))
# x̄s = @views @inbounds(vcat.(xs[1:2:end], xs[2:2:end]))
# x̄s = vcat.(xs[1:2:end], xs[2:2:end])
# bidirectional run step
return vcat.(m.forward.(x̄s), flip(m.backward, x̄s))::VM
end
which produces
Variables
m::PBLSTM{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}
xs::Array{CuArray{Float32,2,Nothing},1}
#93::var"#93#94"{Array{CuArray{Float32,2,Nothing},1}}
evenidxs::StepRange{Int64,Int64}
x̄s::Array{CuArray{Float32,2,Nothing},1}
Body::Array{CuArray{Float32,2,Nothing},1}
1 ─ %1 = Main.firstindex(xs)::Core.Compiler.Const(1, false)
│ %2 = (%1 + 1)::Core.Compiler.Const(2, false)
│ %3 = Main.lastindex(xs)::Int64
│ (evenidxs = %2:2:%3)
│ %5 = Main.:(var"#93#94")::Core.Compiler.Const(var"#93#94", false)
│ %6 = Core.typeof(xs)::Core.Compiler.Const(Array{CuArray{Float32,2,Nothing},1}, false)
│ %7 = Core.apply_type(%5, %6)::Core.Compiler.Const(var"#93#94"{Array{CuArray{Float32,2,Nothing},1}}, false)
│ (#93 = %new(%7, xs))
│ %9 = #93::var"#93#94"{Array{CuArray{Float32,2,Nothing},1}}
│ %10 = Base.broadcasted(%9, evenidxs::Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64]))::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,var"#93#94"{Array{CuArray{Float32,2,Nothing},1}},Tuple{StepRange{Int64,Int64}}}, Any[var"#93#94"{Array{CuArray{Float32,2,Nothing},1}}, Core.Compiler.PartialStruct(Tuple{StepRange{Int64,Int64}}, Any[Core.Compiler.PartialStruct(StepRange{Int64,Int64}, Any[Core.Compiler.Const(2, false), Core.Compiler.Const(2, false), Int64])]), Core.Compiler.Const(nothing, false)])
│ %11 = Base.materialize(%10)::Any
│ (x̄s = Core.typeassert(%11, $(Expr(:static_parameter, 1))))
│ %13 = Base.getproperty(m, :forward)::Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}}
│ %14 = Base.broadcasted(%13, x̄s)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}},Tuple{Array{CuArray{Float32,2,Nothing},1}}}
│ %15 = Base.getproperty(m, :backward)::Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}}
│ %16 = Main.flip(%15, x̄s)::Array{CuArray{Float32,2,Nothing},1}
│ %17 = Base.broadcasted(Main.vcat, %14, %16)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(vcat),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}},Tuple{Array{CuArray{Float32,2,Nothing},1}}},Array{CuArray{Float32,2,Nothing},1}}}
│ %18 = Base.materialize(%17)::Any
│ %19 = Core.typeassert(%18, $(Expr(:static_parameter, 1)))::Array{CuArray{Float32,2,Nothing},1}
└── return %19
# Pyramidal BLSTM forward pass over a dense 3-D tensor `Xs` with dims
# (features D, time T, batch B). Time resolution is halved by reinterpreting
# the tensor as (2D, T÷2, B) — column-major layout makes this stack each
# consecutive pair of time steps along the feature axis without copying.
# NOTE(review): this assumes T is even; for odd T the `reshape` throws a
# DimensionMismatch (2D·T½·B ≠ D·T·B) — confirm callers guarantee even T.
# Returns a (2*dim_out, T÷2, B) array via `copy(Ys)`.
function (m::PBLSTM)(Xs::DenseArray{<:Real,3})
D, T, B = size(Xs)
T½ = T÷2
# reduce time duration by half by restacking consecutive pairs of input along the time dimension
X̄s = reshape(Xs, 2D, T½, B)
# preallocate output buffer
Ys = Buffer(Xs, 2m.dim_out, T½, B)
axisYs₁ = axes(Ys, 1)
time = axes(Ys, 2)
rev_time = reverse(time)
# get forward and backward slice indices
slice_f = axisYs₁[1:m.dim_out]
slice_b = axisYs₁[(m.dim_out+1):end]
# bidirectional run step
# `Ref(...)` shields Ys, X̄s and the slice ranges from broadcasting,
# so only the time axis is iterated over.
setindex!.(Ref(Ys), m.forward.(view.(Ref(X̄s), :, time, :)), Ref(slice_f), time, :)
setindex!.(Ref(Ys), m.backward.(view.(Ref(X̄s), :, rev_time, :)), Ref(slice_b), rev_time, :)
# the same as
# @views for (t_f, t_b) ∈ zip(time, reverse(time))
# Ys[slice_f, t_f, :] = m.forward(X̄s[:, t_f, :])
# Ys[slice_b, t_b, :] = m.backward(X̄s[:, t_b, :])
# end
# but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
return copy(Ys)
end
type inference looks good
julia> @code_warntype m(Xs)
Variables
m::PBLSTM{Array{Float32,2},Array{Float32,1}}
Xs::Array{Float32,3}
D::Int64
T::Int64
@_5::Int64
B::Int64
T½::Int64
X̄s::Array{Float32,3}
Ys::Buffer{Float32,Array{Float32,3}}
axisYs₁::Base.OneTo{Int64}
time::Base.OneTo{Int64}
rev_time::StepRange{Int64,Int64}
slice_f::UnitRange{Int64}
slice_b::UnitRange{Int64}
Body::Array{Float32,3}
1 ─ %1 = Main.size(Xs)::Tuple{Int64,Int64,Int64}
│ %2 = Base.indexed_iterate(%1, 1)::Core.Compiler.PartialStruct(Tuple{Int64,Int64}, Any[Int64, Core.Compiler.Const(2, false)])
│ (D = Core.getfield(%2, 1))
│ (@_5 = Core.getfield(%2, 2))
│ %5 = Base.indexed_iterate(%1, 2, @_5::Core.Compiler.Const(2, false))::Core.Compiler.PartialStruct(Tuple{Int64,Int64}, Any[Int64, Core.Compiler.Const(3, false)])
│ (T = Core.getfield(%5, 1))
│ (@_5 = Core.getfield(%5, 2))
│ %8 = Base.indexed_iterate(%1, 3, @_5::Core.Compiler.Const(3, false))::Core.Compiler.PartialStruct(Tuple{Int64,Int64}, Any[Int64, Core.Compiler.Const(4, false)])
│ (B = Core.getfield(%8, 1))
│ (T½ = T ÷ 2)
│ %11 = (2 * D)::Int64
│ %12 = T½::Int64
│ (X̄s = Main.reshape(Xs, %11, %12, B))
│ %14 = Base.getproperty(m, :dim_out)::Int64
│ %15 = (2 * %14)::Int64
│ %16 = T½::Int64
│ (Ys = Main.Buffer(Xs, %15, %16, B))
│ (axisYs₁ = Main.axes(Ys, 1))
│ (time = Main.axes(Ys, 2))
│ (rev_time = Main.reverse(time))
│ %21 = axisYs₁::Base.OneTo{Int64}
│ %22 = Base.getproperty(m, :dim_out)::Int64
│ %23 = (1:%22)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│ (slice_f = Base.getindex(%21, %23))
│ %25 = axisYs₁::Base.OneTo{Int64}
│ %26 = Base.getproperty(m, :dim_out)::Int64
│ %27 = (%26 + 1)::Int64
│ %28 = Base.lastindex(axisYs₁)::Int64
│ %29 = (%27:%28)::UnitRange{Int64}
│ (slice_b = Base.getindex(%25, %29))
│ %31 = Main.Ref(Ys)::Base.RefValue{Buffer{Float32,Array{Float32,3}}}
│ %32 = Base.getproperty(m, :forward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %33 = Main.Ref(X̄s)::Base.RefValue{Array{Float32,3}}
│ %34 = time::Base.OneTo{Int64}
│ %35 = Base.broadcasted(Main.view, %33, Main.:(:), %34, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}
│ %36 = Base.broadcasted(%32, %35)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}}}
│ %37 = Main.Ref(slice_f)::Base.RefValue{UnitRange{Int64}}
│ %38 = time::Base.OneTo{Int64}
│ %39 = Base.broadcasted(Main.setindex!, %31, %36, %37, %38, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(setindex!),Tuple{Base.RefValue{Buffer{Float32,Array{Float32,3}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},Base.OneTo{Int64},Base.RefValue{Colon}}}}},Base.RefValue{UnitRange{Int64}},Base.OneTo{Int64},Base.RefValue{Colon}}}
│ Base.materialize(%39)
│ %41 = Main.Ref(Ys)::Base.RefValue{Buffer{Float32,Array{Float32,3}}}
│ %42 = Base.getproperty(m, :backward)::Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}
│ %43 = Main.Ref(X̄s)::Base.RefValue{Array{Float32,3}}
│ %44 = rev_time::StepRange{Int64,Int64}
│ %45 = Base.broadcasted(Main.view, %43, Main.:(:), %44, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}
│ %46 = Base.broadcasted(%42, %45)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}}}
│ %47 = Main.Ref(slice_b)::Base.RefValue{UnitRange{Int64}}
│ %48 = rev_time::StepRange{Int64,Int64}
│ %49 = Base.broadcasted(Main.setindex!, %41, %46, %47, %48, Main.:(:))::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(setindex!),Tuple{Base.RefValue{Buffer{Float32,Array{Float32,3}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(view),Tuple{Base.RefValue{Array{Float32,3}},Base.RefValue{Colon},StepRange{Int64,Int64},Base.RefValue{Colon}}}}},Base.RefValue{UnitRange{Int64}},StepRange{Int64,Int64},Base.RefValue{Colon}}}
│ Base.materialize(%49)
│ %51 = Main.copy(Ys)::Array{Float32,3}
└── return %51
# LAS decoder: given the encoder output tensor `Hs` (features × time × batch),
# roll the attentive decoder forward for `maxT` steps and collect the
# per-step predictions.
# NOTE(review): `context`, `decoding` and `prediction` are reassigned inside
# the `do`-closure below while also being captured by it; Julia lowers such
# captures to `Core.Box`, which destroys inference — hence the `Array{_A,1}
# where _A` return type in the dump that follows.
@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
batch_size = size(Hs, 3)
# initialize state for every sequence in a batch
context = repeat(m.state₀.context, 1, batch_size)
decoding = repeat(m.state₀.decoding, 1, batch_size)
prediction = repeat(m.state₀.prediction, 1, batch_size)
# precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
# passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
ψHs = reshape(m.attention_ψ(reshape(Hs, size(Hs,1), :)), size(m.attention_ψ.W, 1), :, batch_size)
# ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
# check: all(ψhs .≈ eachslice(ψHs; dims=2))
ŷs = map(1:maxT) do _
# compute decoder state
decoding = m.spell([decoding; prediction; context])
# compute query ϕ(sᵢ)
ϕsᵢ = m.attention_ϕ(decoding)
# compute energies via batch matrix multiplication
@ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
# check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
# compute attentions weights
αᵢs = softmax(Eᵢs)
# compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
@ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
# check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
# predict probability distribution over character alphabet
prediction = m.infer([decoding; context])
end
return ŷs
end
looks bad (a bunch of `Core.Box`es, and the return type fails to infer):
julia> @code_warntype decode(m, Hs, maxT)
Variables
#self#::Core.Compiler.Const(decode, false)
m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
Hs::Array{Float32,3}
maxT::Int64
#91::var"#91#92"{LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}},Array{Float32,3},Array{Float32,3}}
batch_size::Int64
context::Core.Box
decoding::Core.Box
prediction::Core.Box
ψHs::Array{Float32,3}
ŷs::Array{_A,1} where _A
Body::Array{_A,1} where _A
1 ─ nothing
│ (context = Core.Box())
│ (decoding = Core.Box())
│ (prediction = Core.Box())
│ (batch_size = Main.size(Hs, 3))
│ %6 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│ %7 = Base.getproperty(%6, :context)::Array{Float32,2}
│ %8 = Main.repeat(%7, 1, batch_size)::Array{Float32,2}
│ Core.setfield!(context, :contents, %8)
│ %10 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│ %11 = Base.getproperty(%10, :decoding)::Array{Float32,2}
│ %12 = Main.repeat(%11, 1, batch_size)::Array{Float32,2}
│ Core.setfield!(decoding, :contents, %12)
│ %14 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│ %15 = Base.getproperty(%14, :prediction)::Array{Float32,2}
│ %16 = Main.repeat(%15, 1, batch_size)::Array{Float32,2}
│ Core.setfield!(prediction, :contents, %16)
│ %18 = Base.getproperty(m, :attention_ψ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│ %19 = Main.size(Hs, 1)::Int64
│ %20 = Main.reshape(Hs, %19, Main.:(:))::Array{Float32,2}
│ %21 = (%18)(%20)::Array{Float32,2}
│ %22 = Base.getproperty(m, :attention_ψ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│ %23 = Base.getproperty(%22, :W)::Array{Float32,2}
│ %24 = Main.size(%23, 1)::Int64
│ (ψHs = Main.reshape(%21, %24, Main.:(:), batch_size))
│ %26 = Main.:(var"#91#92")::Core.Compiler.Const(var"#91#92", false)
│ %27 = Core.typeof(m)::Core.Compiler.Const(LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}, false)
│ %28 = Core.typeof(Hs)::Core.Compiler.Const(Array{Float32,3}, false)
│ %29 = Core.typeof(ψHs)::Core.Compiler.Const(Array{Float32,3}, false)
│ %30 = Core.apply_type(%26, %27, %28, %29)::Core.Compiler.Const(var"#91#92"{LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}},Array{Float32,3},Array{Float32,3}}, false)
│ %31 = context::Core.Box
│ %32 = decoding::Core.Box
│ %33 = prediction::Core.Box
│ (#91 = %new(%30, m, Hs, %31, %32, %33, ψHs))
│ %35 = #91::var"#91#92"{LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}},Array{Float32,3},Array{Float32,3}}
│ %36 = (1:maxT)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│ (ŷs = Main.map(%35, %36))
└── return ŷs
fixed by
# LAS decoder, inference-friendly version. Differences from the map-closure
# variant: the loop writes into a preallocated (Zygote) `Buffer` of `Vector{M}`
# instead of closing over mutable state, and every reassigned matrix is pinned
# to the concrete type `M` with a `typeassert`, so no `Core.Box` is created
# and the return type infers as `Vector{M}` (see dump below).
# NOTE(review): the `@ein` macro calls were replaced by explicit
# `einsum(EinCode{...}(), ...)` calls with `::M` asserts — presumably because
# the macro's output did not infer; verify against the OMEinsum version in use.
@inline function decode(m::LAS{M}, Hs::DenseArray{<:Real,3}, maxT::Integer) where M <: DenseMatrix
batch_size = size(Hs, 3)
# initialize state for every sequence in a batch
context = repeat(m.state₀.context, 1, batch_size)
decoding = repeat(m.state₀.decoding, 1, batch_size)
prediction = repeat(m.state₀.prediction, 1, batch_size)
# precompute keys ψ(H) by gluing the slices of Hs along the batch dimension into a single D×TB matrix, then
# passing it through the ψ dense layer in a single pass and then reshaping the result back into D′×T×B tensor
ψHs = reshape(m.attention_ψ(reshape(Hs, size(Hs,1), :)), size(m.attention_ψ.W, 1), :, batch_size)
# ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
# check: all(ψhs .≈ eachslice(ψHs; dims=2))
ŷs = Buffer(Vector{M}(undef, maxT), false)
@inbounds for t ∈ eachindex(ŷs)
# compute decoder state
decoding = m.spell([decoding; prediction; context])::M
# compute query ϕ(sᵢ)
ϕsᵢ = m.attention_ϕ(decoding)
# compute energies via batch matrix multiplication
# @ein Eᵢs[t,b] := ϕsᵢ[d,b] * ψHs[d,t,b]
Eᵢs = einsum(EinCode{((1,2), (1,3,2)), (3,2)}(), (ϕsᵢ, ψHs))::M
# check: Eᵢs ≈ reduce(hcat, diag.((ϕsᵢ',) .* ψhs))'
# compute attentions weights
αᵢs = softmax(Eᵢs)
# compute attended context using Einstein summation convention, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
# @ein context[d,b] := αᵢs[t,b] * Hs[d,t,b]
context = einsum(EinCode{((1,2), (3,1,2)), (3,2)}(), (αᵢs, Hs))::M
# check: context ≈ reduce(hcat, [sum(αᵢs[t,b] *Hs[:,t,b] for t ∈ axes(αᵢs, 1)) for b ∈ axes(αᵢs,2)])
# predict probability distribution over character alphabet
ŷs[t] = prediction = m.infer([decoding; context])
end
return copy(ŷs)
end
which produces
julia> @code_warntype decode(m, Hs, maxT)
Variables
#self#::Core.Compiler.Const(decode, false)
m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
Hs::Array{Float32,3}
maxT::Int64
val::Nothing
batch_size::Int64
context::Array{Float32,2}
decoding::Array{Float32,2}
prediction::Array{Float32,2}
ψHs::Array{Float32,3}
ŷs::Buffer{Array{Float32,2},Array{Array{Float32,2},1}}
@_12::Union{Nothing, Tuple{Int64,Int64}}
t::Int64
ϕsᵢ::Array{Float32,2}
Eᵢs::Array{Float32,2}
αᵢs::Array{Float32,2}
Body::Array{Array{Float32,2},1}
1 ─ nothing
│ Core.NewvarNode(:(val))
│ (batch_size = Main.size(Hs, 3))
│ %4 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│ %5 = Base.getproperty(%4, :context)::Array{Float32,2}
│ (context = Main.repeat(%5, 1, batch_size))
│ %7 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│ %8 = Base.getproperty(%7, :decoding)::Array{Float32,2}
│ (decoding = Main.repeat(%8, 1, batch_size))
│ %10 = Base.getproperty(m, :state₀)::State₀{Array{Float32,2}}
│ %11 = Base.getproperty(%10, :prediction)::Array{Float32,2}
│ (prediction = Main.repeat(%11, 1, batch_size))
│ %13 = Base.getproperty(m, :attention_ψ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│ %14 = Main.size(Hs, 1)::Int64
│ %15 = Main.reshape(Hs, %14, Main.:(:))::Array{Float32,2}
│ %16 = (%13)(%15)::Array{Float32,2}
│ %17 = Base.getproperty(m, :attention_ψ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│ %18 = Base.getproperty(%17, :W)::Array{Float32,2}
│ %19 = Main.size(%18, 1)::Int64
│ (ψHs = Main.reshape(%16, %19, Main.:(:), batch_size))
│ %21 = Core.apply_type(Main.Vector, $(Expr(:static_parameter, 1)))::Core.Compiler.Const(Array{Array{Float32,2},1}, false)
│ %22 = (%21)(Main.undef, maxT)::Array{Array{Float32,2},1}
│ (ŷs = Main.Buffer(%22, false))
│ $(Expr(:inbounds, true))
│ %25 = Main.eachindex(ŷs)::Base.OneTo{Int64}
│ (@_12 = Base.iterate(%25))
│ %27 = (@_12 === nothing)::Bool
│ %28 = Base.not_int(%27)::Bool
└── goto #4 if not %28
2 ┄ %30 = @_12::Tuple{Int64,Int64}::Tuple{Int64,Int64}
│ (t = Core.getfield(%30, 1))
│ %32 = Core.getfield(%30, 2)::Int64
│ %33 = Base.getproperty(m, :spell)::Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}}
│ %34 = Base.vcat(decoding, prediction, context)::Array{Float32,2}
│ %35 = (%33)(%34)::Any
│ (decoding = Core.typeassert(%35, $(Expr(:static_parameter, 1))))
│ %37 = Base.getproperty(m, :attention_ϕ)::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}
│ (ϕsᵢ = (%37)(decoding))
│ %39 = Core.tuple(1, 2)::Core.Compiler.Const((1, 2), false)
│ %40 = Core.tuple(1, 3, 2)::Core.Compiler.Const((1, 3, 2), false)
│ %41 = Core.tuple(%39, %40)::Core.Compiler.Const(((1, 2), (1, 3, 2)), false)
│ %42 = Core.tuple(3, 2)::Core.Compiler.Const((3, 2), false)
│ %43 = Core.apply_type(Main.EinCode, %41, %42)::Core.Compiler.Const(EinCode{((1, 2), (1, 3, 2)),(3, 2)}, false)
│ %44 = (%43)()::Core.Compiler.Const(EinCode{((1, 2), (1, 3, 2)),(3, 2)}(), false)
│ %45 = Core.tuple(ϕsᵢ, ψHs)::Tuple{Array{Float32,2},Array{Float32,3}}
│ %46 = Main.einsum(%44, %45)::Any
│ (Eᵢs = Core.typeassert(%46, $(Expr(:static_parameter, 1))))
│ (αᵢs = Main.softmax(Eᵢs))
│ %49 = Core.tuple(1, 2)::Core.Compiler.Const((1, 2), false)
│ %50 = Core.tuple(3, 1, 2)::Core.Compiler.Const((3, 1, 2), false)
│ %51 = Core.tuple(%49, %50)::Core.Compiler.Const(((1, 2), (3, 1, 2)), false)
│ %52 = Core.tuple(3, 2)::Core.Compiler.Const((3, 2), false)
│ %53 = Core.apply_type(Main.EinCode, %51, %52)::Core.Compiler.Const(EinCode{((1, 2), (3, 1, 2)),(3, 2)}, false)
│ %54 = (%53)()::Core.Compiler.Const(EinCode{((1, 2), (3, 1, 2)),(3, 2)}(), false)
│ %55 = Core.tuple(αᵢs, Hs)::Tuple{Array{Float32,2},Array{Float32,3}}
│ %56 = Main.einsum(%54, %55)::Any
│ (context = Core.typeassert(%56, $(Expr(:static_parameter, 1))))
│ %58 = Base.getproperty(m, :infer)::Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}
│ %59 = Base.vcat(decoding, context)::Array{Float32,2}
│ %60 = (%58)(%59)::Array{Float32,2}
│ (prediction = %60)
│ Base.setindex!(ŷs, %60, t)
│ (@_12 = Base.iterate(%25, %32))
│ %64 = (@_12 === nothing)::Bool
│ %65 = Base.not_int(%64)::Bool
└── goto #4 if not %65
3 ─ goto #2
4 ┄ (val = nothing)
│ $(Expr(:inbounds, :pop))
│ val
│ %71 = Main.copy(ŷs)::Array{Array{Float32,2},1}
└── return %71
the new version does not result in performance regressions (in fact it is slightly faster, as expected):
julia> reset!(m)
julia> @benchmark c0($m, $Hs, $maxT)
BenchmarkTools.Trial:
memory estimate: 220.92 MiB
allocs estimate: 174739
--------------
minimum time: 304.113 ms (7.01% GC)
median time: 310.506 ms (7.23% GC)
mean time: 313.922 ms (8.58% GC)
maximum time: 329.119 ms (13.29% GC)
--------------
samples: 16
evals/sample: 1
julia> reset!(m)
julia> @benchmark c1($m, $Hs, $maxT)
BenchmarkTools.Trial:
memory estimate: 220.92 MiB
allocs estimate: 174733
--------------
minimum time: 302.992 ms (6.77% GC)
median time: 309.734 ms (7.81% GC)
mean time: 313.066 ms (8.87% GC)
maximum time: 340.022 ms (14.22% GC)
--------------
samples: 16
evals/sample: 1
julia> reset!(m)
julia> @benchmark gc0($θ, $m, $Hs, $maxT)
BenchmarkTools.Trial:
memory estimate: 926.07 MiB
allocs estimate: 766920
--------------
minimum time: 1.075 s (9.87% GC)
median time: 1.140 s (10.48% GC)
mean time: 1.215 s (18.03% GC)
maximum time: 1.420 s (29.09% GC)
--------------
samples: 5
evals/sample: 1
julia> reset!(m)
julia> @benchmark gc1($θ, $m, $Hs, $maxT)
BenchmarkTools.Trial:
memory estimate: 923.93 MiB
allocs estimate: 702862
--------------
minimum time: 1.014 s (10.56% GC)
median time: 1.097 s (14.64% GC)
mean time: 1.168 s (19.04% GC)
maximum time: 1.514 s (37.04% GC)
--------------
samples: 5
evals/sample: 1
function (m::LAS)(xs::AbstractVector{<:DenseMatrix}, maxT::Integer = length(xs))
    # Encoder (listener) pass: produces the values for the attention layer,
    # as a T-length sequence of D×B matrices.
    encodings = m.listen(xs)
    D, B = size(first(encodings))
    # Stack the sequence into a D×T×B tensor: concatenating along dim 1 gives a
    # single DT×B matrix, which reshape then splits back into D×T×B.
    Hs = reshape(reduce(vcat, encodings), D, :, B)
    # Attend-and-spell decoding over the encoded input.
    return decode(m, Hs, maxT)
end
infers completely under the fixed `decode` function
julia> @code_warntype m(xs)
Variables
m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
xs::Array{Array{Float32,2},1}
Body::Array{Array{Float32,2},1}
1 ─ %1 = Main.length(xs)::Int64
│ %2 = (m)(xs, %1)::Array{Array{Float32,2},1}
└── return %2
function (m::LAS)(Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs,2))
    # Encoder (listener) pass over the batched input tensor; the result also
    # serves as the values for the attention layer.
    encodings = m.listen(Xs)
    # Attend-and-spell decoding.
    return decode(m, encodings, maxT)
end
infers completely under the fixed decode function
julia> @code_warntype m(Xs)
Variables
m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
Xs::Array{Float32,3}
Body::Array{Array{Float32,2},1}
1 ─ %1 = Main.size(Xs, 2)::Int64
│ %2 = (m)(Xs, %1)::Array{Array{Float32,2},1}
└── return %2
"""
    pad(xs, multiplicity) -> Vector

Extend the sequence `xs` with zero vectors so its length is a multiple of
`multiplicity`. Returns a new vector; `xs` is not modified.

The padding element is `zero(first(xs))`, i.e. a zero vector with the same
length and eltype as the sequence elements. All padding slots alias the same
zero vector (as in the original broadcast form), which is safe as long as
callers treat padding as read-only.
"""
function pad(xs::DenseVector{<:DenseVector}, multiplicity::Integer)
    T = length(xs)
    # Nothing to pad (and no element to take `zero` of) for an empty sequence.
    T == 0 && return copy(xs)
    # Round T up to the nearest multiple of `multiplicity`. `cld` is ceiling
    # division in integer arithmetic — no Float64 round-trip as with
    # `ceil(Int, T / multiplicity)`. NOTE: the original line was missing the
    # `*` between the `ceil` call and `multiplicity` (a syntax error; the
    # accompanying @code_warntype dump shows the intended `%3 * multiplicity`).
    newT = cld(T, multiplicity) * multiplicity
    z = zero(first(xs))
    return [xs; fill(z, newT - T)]
end
infers completely
julia> @code_warntype pad(x, 7)
Variables
#self#::Core.Compiler.Const(pad, false)
xs::Array{Array{Float32,1},1}
multiplicity::Int64
#63::var"#63#64"{Array{Float32,1}}
T::Int64
newT::Int64
z::Array{Float32,1}
Body::Array{Array{Float32,1},1}
1 ─ (T = Main.length(xs))
│ %2 = (T / multiplicity)::Float64
│ %3 = Main.ceil(Main.Int, %2)::Int64
│ (newT = %3 * multiplicity)
│ %5 = Main.first(xs)::Array{Float32,1}
│ (z = Main.zero(%5))
│ %7 = Main.:(var"#63#64")::Core.Compiler.Const(var"#63#64", false)
│ %8 = Core.typeof(z)::Core.Compiler.Const(Array{Float32,1}, false)
│ %9 = Core.apply_type(%7, %8)::Core.Compiler.Const(var"#63#64"{Array{Float32,1}}, false)
│ (#63 = %new(%9, z))
│ %11 = #63::var"#63#64"{Array{Float32,1}}
│ %12 = (newT - T)::Int64
│ %13 = (1:%12)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│ %14 = Base.broadcasted(%11, %13)::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,var"#63#64"{Array{Float32,1}},Tuple{UnitRange{Int64}}}, Any[var"#63#64"{Array{Float32,1}}, Core.Compiler.PartialStruct(Tuple{UnitRange{Int64}}, Any[Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])]), Core.Compiler.Const(nothing, false)])
│ %15 = Base.materialize(%14)::Array{Array{Float32,1},1}
│ %16 = Base.vcat(xs, %15)::Array{Array{Float32,1},1}
└── return %16
function (m::LAS)(x::AbstractVector{<:DenseVector})
    # Single (unbatched) sequence: treat it as a batch of one.
    T = length(x)
    # Pad to a multiple of the model's time-squashing factor, glue the
    # D-length vectors into a D×T′ matrix, then lift to a D×T′×1 tensor.
    padded = pad(x, time_squashing_factor(m))
    X = reshape(reduce(hcat, padded), Val(3))
    # Run the batched model and strip the singleton batch dimension from
    # each per-step prediction.
    return dropdims.(m(X, T); dims=2)
end
infers completely
julia> @code_warntype m(x)
Variables
m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
x::Array{Array{Float32,1},1}
T::Int64
X::Array{Float32,3}
ŷs::Array{Array{Float32,1},1}
Body::Array{Array{Float32,1},1}
1 ─ (T = Main.length(x))
│ %2 = Main.time_squashing_factor(m)::Int64
│ %3 = Main.pad(x, %2)::Array{Array{Float32,1},1}
│ %4 = Main.reduce(Main.hcat, %3)::Array{Float32,2}
│ %5 = Main.Val(3)::Core.Compiler.Const(Val{3}(), true)
│ (X = Main.reshape(%4, %5))
│ %7 = (m)(X, T)::Array{Array{Float32,2},1}
│ %8 = Base.broadcasted_kwsyntax::Core.Compiler.Const(Base.Broadcast.broadcasted_kwsyntax, false)
│ %9 = (:dims,)::Core.Compiler.Const((:dims,), false)
│ %10 = Core.apply_type(Core.NamedTuple, %9)::Core.Compiler.Const(NamedTuple{(:dims,),T} where T<:Tuple, false)
│ %11 = Core.tuple(2)::Core.Compiler.Const((2,), false)
│ %12 = (%10)(%11)::NamedTuple{(:dims,),Tuple{Int64}}
│ %13 = Core.kwfunc(%8)::Core.Compiler.Const(Base.Broadcast.var"#kw##broadcasted_kwsyntax"(), false)
│ %14 = (%13)(%12, %8, Main.dropdims, %7)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Base.Broadcast.var"#31#32"{Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:dims,),Tuple{Int64}}},typeof(dropdims)},Tuple{Array{Array{Float32,2},1}}}
│ (ŷs = Base.materialize(%14))
└── return ŷs
"""
    loss(m, X, linidxs, maxT)

Negative log-likelihood of the targets: run the model on the batched input
`X`, then sum the log-probabilities at the target linear indices `linidxs`
and negate.
"""
function loss(m::LAS, X::DenseArray{<:Real,3}, linidxs::DenseVector{<:Integer}, maxT::Integer)
    Ŷs = m(X, maxT)
    return -sum(Ŷs[linidxs])
end
infers completely
julia> @code_warntype loss(m, Xs, linidxs, maxT)
Variables
#self#::Core.Compiler.Const(loss, false)
m::LAS{Array{Float32,2},Chain{Tuple{BLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}},PBLSTM{Array{Float32,2},Array{Float32,1}}}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Chain{Tuple{Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}},Recur{LSTMCell{Array{Float32,2},Array{Float32,1}}}}},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},typeof(logsoftmax)}}}
X::Array{Float32,3}
linidxs::Array{Int64,1}
maxT::Int64
Ŷs::Array{Float32,3}
l::Float32
Body::Float32
1 ─ (Ŷs = (m)(X, maxT))
│ %2 = Base.getindex(Ŷs, linidxs)::Array{Float32,1}
│ %3 = Main.sum(%2)::Float32
│ (l = -%3)
└── return l
returns
Any
alleviated by