chengchingwen / Transformers.jl

Julia Implementation of Transformer models
MIT License

Type unstable functions #8

Closed. janfrancu closed this issue 4 years ago.

janfrancu commented 4 years ago

The forward step of Transformer is type unstable. Running the example from the docs:

using Transformers

m = Transformer(512, 8, 64, 2048) # define a Transformer block with 8 heads and 64 neurons per head
x = randn(512, 30, 3) # fake data of length 30 with a batch size of 3

y = m(x)

and checking it with @code_warntype produces:

julia> @code_warntype m(x,nothing)
Variables
  t::Transformer
  x::Array{Float64,3}
  mask::Core.Compiler.Const(nothing, false)
  a::Any
  insize::Any
  res_a::Any
  pwffn::AbstractArray{T,2} where T
  res_pwffn::Any

Body::Any
1 ─       Core.NewvarNode(:(insize))
│         Core.NewvarNode(:(pwffn))
│         Core.NewvarNode(:(res_pwffn))
│   %4  = (:mask,)::Core.Compiler.Const((:mask,), false)
│   %5  = Core.apply_type(Core.NamedTuple, %4)::Core.Compiler.Const(NamedTuple{(:mask,),T} where T<:Tuple, false)
│   %6  = Core.tuple(mask)::Core.Compiler.Const((nothing,), false)
│   %7  = (%5)(%6)::Core.Compiler.Const((mask = nothing,), false)
│   %8  = Base.getproperty(t, :mh)::Transformers.Basic.MultiheadAttention
│   %9  = Core.kwfunc(%8)::Core.Compiler.Const(Core.var"#Any##kw"(), false)
│   %10 = Base.getproperty(t, :mh)::Transformers.Basic.MultiheadAttention
│         (a = (%9)(%7, %10, x, x, x))
│   %12 = Base.getproperty(t, :drop)::Flux.Dropout
│         (a = (%12)(a))
│   %14 = Base.broadcasted(Transformers.Basic.:+, x, a)::Any
│         (res_a = Base.materialize(%14))
│   %16 = ($(Expr(:static_parameter, 2)) == 3)::Core.Compiler.Const(true, false)
└──       goto #3 if not %16
2 ─       (insize = Transformers.Basic.size(res_a))
│   %19 = res_a::Any
│   %20 = Base.getindex(insize, 1)::Any
└──       (res_a = Transformers.Basic.reshape(%19, %20, Transformers.Basic.:(:)))
3 ┄ %22 = Base.getproperty(t, :mhn)::Flux.LayerNorm
│         (res_a = (%22)(res_a))
│   %24 = Base.getproperty(t, :pw)::Transformers.Basic.PwFFN
│         (pwffn = (%24)(res_a))
│   %26 = Base.getproperty(t, :drop)::Flux.Dropout
│         (pwffn = (%26)(pwffn))
│   %28 = Base.broadcasted(Transformers.Basic.:+, res_a, pwffn)::Any
│         (res_pwffn = Base.materialize(%28))
│   %30 = Base.getproperty(t, :pwn)::Flux.LayerNorm
│         (res_pwffn = (%30)(res_pwffn))
│   %32 = ($(Expr(:static_parameter, 2)) == 3)::Core.Compiler.Const(true, false)
└──       goto #5 if not %32
4 ─ %34 = Core.tuple(res_pwffn, Transformers.Basic.:(:))::Core.Compiler.PartialStruct(Tuple{Any,Colon}, Any[Any, Core.Compiler.Const(Colon(), false)])
│   %35 = Base.tail::Core.Compiler.Const(Base.tail, false)
│   %36 = (%35)(insize)::Union{Tuple, NamedTuple}
└──       (res_pwffn = Core._apply_iterate(Base.iterate, Transformers.Basic.reshape, %34, %36))
5 ┄       return res_pwffn

The source of the instability is probably the multihead attention, but I have not been able to distill it any further. I am using the latest tagged version of Transformers (0.1.3) on Julia 1.4.1.
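One way to narrow it down is to check the sub-layers in isolation. Below is a rough sketch of that, assuming (based on the @code_warntype output above) that m.mh is the MultiheadAttention layer and that it accepts the same array for all three inputs with the default mask:

using Transformers, Test, InteractiveUtils

m = Transformer(512, 8, 64, 2048)
x = randn(512, 30, 3)

# Inspect the attention sub-layer on its own; if its return type is not
# concretely inferred here, the whole forward pass is type unstable too.
@code_warntype m.mh(x, x, x)

# @inferred throws whenever the return type cannot be inferred to a
# concrete type, which makes it handy as a regression test.
@inferred m.mh(x, x, x)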

chengchingwen commented 4 years ago

I can reproduce this result on Julia 1.4.2 with the master branch. It does look like there are some problems with type inference for multihead attention. I will take some time to fix this.
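For context on what typically causes this: the usual culprit is a struct field (or a container inside one) whose type is abstract or missing a type parameter, so the compiler only knows it as Any. A minimal, hypothetical illustration in plain Julia (not the actual Transformers.jl code):

# Untyped field: the compiler sees `l.w` as Any, so `l.w * x` and the
# return value of the forward call are inferred as Any.
struct UnstableLayer
    w
end
(l::UnstableLayer)(x) = l.w * x

# Parameterized field: the concrete type of `w` is part of the struct's
# type, so the multiplication and the return type infer concretely.
struct StableLayer{T<:AbstractMatrix}
    w::T
end
(l::StableLayer)(x) = l.w * x

Running @code_warntype on UnstableLayer(randn(4, 4))(randn(4)) flags the result as Any, while the StableLayer version infers Vector{Float64}.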

Thanks for reporting it!

chengchingwen commented 4 years ago

Should be fixed in the new release (v0.1.7)
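For anyone who wants to double-check after upgrading, here is a quick sketch reusing the example from the original report (it relies only on Test.@inferred, not on any Transformers-specific API):

using Pkg; Pkg.update("Transformers") # should pull in v0.1.7 or later

using Transformers, Test

m = Transformer(512, 8, 64, 2048)
x = randn(512, 30, 3)

# On v0.1.3 this threw, since the forward pass was inferred as Any;
# with the fix in place it should return y without an error.
y = @inferred m(x)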