FluxML / NNlib.jl

Neural Network primitives with multiple backends

NaN with custom mask for MultiHeadAttention #572

Open mashu opened 3 months ago

mashu commented 3 months ago

Hi,

Background: in an encoder-decoder translation model based on "Attention Is All You Need", I want to mask out the padding in the sentences passed to the encoder's MultiHeadAttention. However, I noticed that the `neginf` value computed for the mask from the logits' eltype might cause problems and lead to NaN.

A minimal example is provided at https://github.com/mashu/NaNTracker.jl. The result is:

caused by: DomainError with Float32[-7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; … ; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7;;; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; … ; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7;;; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; … ; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7;;; … ;;; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; … ; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7;;; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; … ; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7;;; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; -2.38419f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; … ; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7]:
NaN on gradient input for layer: KeyPath(:mha, :out_proj)
Stacktrace:
  [1] (::Main.NaNTracker.var"#pb_check#2"{DebugWrapper{…}, Zygote.var"#ad_pullback#58"{…}})(Δ::Array{Float32, 3})
    @ Main.NaNTracker ~/NaNTracker.jl/src/NaNTracker.jl:26
  [2] ZBack
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
  [3] #_#334
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:129 [inlined]
  [4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [5] MultiHeadAttention
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:120 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [7] MultiHeadAttention
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:120 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [9] #_#332
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:115 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [11] MultiHeadAttention
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:115 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [13] #_#5
    @ ~/NaNTracker.jl/src/Example.jl:24 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float32, 3, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [15] EncoderOnly
    @ ~/NaNTracker.jl/src/Example.jl:22 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float32, 3, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [17] #14
    @ ~/NaNTracker.jl/src/Example.jl:47 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [20] withgradient(f::Function, args::EncoderOnly)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:213
 [21] testit()
    @ Main ~/NaNTracker.jl/src/Example.jl:46
 [22] with_logging(::Function)
    @ Main.NaNTracker ~/NaNTracker.jl/src/NaNTracker.jl:36
 [23] top-level scope
    @ ~/NaNTracker.jl/src/Example.jl:50

I hope this is not just a misunderstanding on my part of what the mask should look like. To be honest, though, the Flux documentation could use a couple of examples for this particular use case, in addition to make_causal_mask.

mashu commented 3 months ago

Just a thought: maybe these values need clamping in apply_attn_mask?
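To make the clamping idea concrete, here is a minimal sketch in plain Julia. It does not call NNlib or Flux: `naive_softmax`, `masked_scores`, and the `fill_value` keyword are hypothetical names, and it assumes the mask is applied by replacing masked logits with a fill value before a max-subtracted softmax. It compares an `-Inf` fill with a large finite one when every position is masked.

```julia
# Hypothetical helper, not NNlib's implementation: max-subtracted softmax over a vector.
function naive_softmax(v::AbstractVector{T}) where {T<:AbstractFloat}
    m = maximum(v)
    e = exp.(v .- m)        # if every entry is -Inf, this is exp(-Inf - (-Inf)) = exp(NaN)
    return e ./ sum(e)
end

# Assumed convention: `true` in `keep` keeps the logit, `false` replaces it with `fill_value`.
masked_scores(logits, keep; fill_value) = ifelse.(keep, logits, fill_value)

logits    = Float32[0.3, -1.2, 0.7, 2.0]
none_kept = falses(4)               # every key position masked out
some_kept = Bool[1, 1, 0, 0]        # last two positions are padding

# Fully masked, -Inf fill: every output is NaN.
p_inf    = naive_softmax(masked_scores(logits, none_kept; fill_value = typemin(Float32)))
# Fully masked, finite fill: the output degrades to a uniform distribution instead.
p_finite = naive_softmax(masked_scores(logits, none_kept; fill_value = -1f9))
# Partially masked, -Inf fill: masked positions get exactly zero weight, no NaN.
p_some   = naive_softmax(masked_scores(logits, some_kept; fill_value = typemin(Float32)))
```

In this sketch a finite fill removes the NaN, but a fully masked column then attends uniformly to everything, padding included. Whether that is acceptable, or whether a fully masked column points to a problem in the mask itself, is a separate question.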

mcabbott commented 3 months ago

I think the unusual thing here is that the mask (size 16×1×1×32) is constant for whole batches, and thus it's trying to set every value to -Inf before the softmax. That's not illegal according to help:

mask: Input array broadcastable to size (kv_len, q_len, nheads, batch_size). The mask
       is applied to the attention scores just before the softmax.

but it is unusual. Are you sure this is what you want, rather than e.g. running a smaller batch of randomly selected items?

Here is a slightly shorter reproducer, followed by a case with a more orthodox mask shape, I think:

julia> struct MaskAttention{A<:MultiHeadAttention, M<:AbstractArray}
           att::A
           mask::M
       end

julia> (m::MaskAttention)(x::AbstractArray) = first(m.att(x; m.mask))

julia> Flux.@layer :expand MaskAttention

julia> x = map(f->rand(Int32.(2:10), rand(8:16)), 1:32);

julia> x = reduce(hcat, rpad.(x, maximum(length.(x)), 1))
16×32 Matrix{Int32}:
  7   7   7   6  9   5   2   8   9   7   3   4  …   9   2   6   4   8   2   9   9  4   5   3   3
  2   5   3  10  7   7   9   7   9   6   2   3      7   7   7  10   8   7   7   8  7  10  10   7
  6   3   5   5  3   4   2   4   9   9   2   7      9   4   8   4   9   6   3   9  5   4   2   3

julia> mask = permutedims(repeat((x .== 1), outer = [1, 1, 1, 1]), (1, 4, 3, 2))
16×1×1×32 BitArray{4}:
[:, :, 1, 1] =
 0
 0
 0

julia> model = MaskAttention(MultiHeadAttention(16), mask)
MaskAttention(
  MultiHeadAttention(16; nheads=8),     # 1_024 parameters
  Bool[0; 0; … ; 1; 1;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 0; 1;;;; … ;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 0; 1],  # 512 parameters
)                   # Total: 5 arrays, 1_536 parameters, 4.547 KiB.

julia> xx = randn32(16, 16, 32);

julia> model(xx) |> summary
"16×16×32 Array{Float32, 3}"

julia> model(xx) |> sum
NaN32

julia> findall(isnan, model(xx))
1280-element Vector{CartesianIndex{3}}:
 CartesianIndex(1, 1, 11)
 CartesianIndex(2, 1, 11)
 CartesianIndex(3, 1, 11)
 CartesianIndex(4, 1, 11)
 CartesianIndex(5, 1, 11)

julia> loss, grads = Flux.withgradient(model) do m
         sum(abs2, m(xx))
       end
(val = NaN32, grad = ((att = (nheads = nothing, q_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), k_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), v_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing)), mask = nothing),))

julia> mask2 = rand(Bool, 16, 16);

julia> model2 = MaskAttention(MultiHeadAttention(16), mask2);

julia> model2(xx) |> sum
15.281105f0

julia> loss, grads = Flux.withgradient(model2) do m
         sum(abs2, m(xx))
       end
(val = 2290.0107f0, grad = ((att = (nheads = nothing, q_proj = (weight = Float32[-12.446103 -4.3954763 … -23.305235 3.0878963; 17.259354 -6.7767124 … 25.798717 6.8329597; … ; -22.414658 -17.097672 … 93.836235 -9.206725; 29.134241 13.785215 … -41.797527 -23.159046], bias = nothing, σ = nothing), k_proj = (weight = Float32[10.827733 18.880678 … -35.907875 23.034206; 8.75228 23.103594 … 13.421799 -14.956886; … ; -71.6004 -13.324369 … -35.224113 61.402447; 27.011333 124.142815 … -21.26666 -63.877186], bias = nothing, σ = nothing), v_proj = (weight = Float32[34.734734 -24.169163 … 102.391365 -69.705055; -1.5541999 37.55378 … -69.58273 39.782215; … ; 5.4440746 -176.59694 … 171.69466 161.58585; -139.37288 181.93517 … -213.38739 -58.174618], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[-18.867039 0.40507406 … 92.7197 -36.512943; -30.138624 -14.419439 … -92.25858 -63.47702; … ; -89.09866 92.45964 … -212.48007 164.08275; 35.601555 -31.360823 … 128.91348 -104.323494], bias = nothing, σ = nothing)), mask = nothing),))
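
One way to check whether a mask of this shape wipes out a whole column of scores is to look for batch items in which no key position is kept. The sketch below is plain Julia with made-up toy data. `fully_masked_items` is a hypothetical helper, and the sketch assumes that `true` means "attend to this position" and `false` means "set to -Inf", which is worth verifying against the NNlib docstring.

```julia
# Hypothetical helper: indices of batch items whose mask keeps no key position at all.
# Expects a Bool mask of size (kv_len, 1, 1, batch_size), like the 16×1×1×32 one above.
function fully_masked_items(mask::AbstractArray{Bool,4})
    nothing_kept = .!any(mask; dims = 1)      # size (1, 1, 1, batch_size)
    return findall(vec(nothing_kept))
end

pad = 1                                       # padding token id, as in `rpad.(x, ..., 1)`
tokens = [7 2 5;                              # toy batch: 4 positions × 3 sequences
          3 9 4;
          1 6 8;
          1 6 1]                              # sequence 2 has no padding

is_pad = reshape(tokens .== pad, 4, 1, 1, 3)  # true at padding positions, as in the reproducer
bad    = fully_masked_items(is_pad)           # items where every position would be masked
ok     = fully_masked_items(.!is_pad)         # same check with the mask inverted
```

Under the stated assumption, `bad` contains the sequence with no padding, because `tokens .== pad` is `false` everywhere for it, while the inverted mask leaves no item fully masked. That would be consistent with the NaNs in the transcript appearing only for some batch items, but it is a guess about the cause rather than a confirmed diagnosis.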

mashu commented 3 months ago

I do want this mask to vary per batch: the sequences sampled into each batch differ in length, so the padding differs too. This minimal example uses just one batch to show the issue. I tried clamping before the softmax and the NaNs are gone. The idea is to mask the padding tokens out of the encoder's attention. If that is unusual, am I doing something wrong? I have three different kinds of masks: a padding mask in the encoder (this MWE), a causal mask in the decoder, and a padding mask in the loss function, which affects only the target sequence in the decoder.
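
For reference, the three kinds of masks could be sketched roughly as follows in plain Julia. None of the helper names come from Flux or NNlib; they are hypothetical. The sketch assumes a keep-mask convention (`true` means the position is attended to) and a `(kv_len, q_len, nheads, batch_size)` layout as in the docstring quoted above.

```julia
# Hypothetical helpers for illustration; none of these names come from Flux or NNlib.

# Padding keep-mask shaped (kv_len, 1, 1, batch_size): true where the token is real.
padding_keep_mask(tokens::AbstractMatrix; pad) =
    reshape(tokens .!= pad, size(tokens, 1), 1, 1, size(tokens, 2))

# Causal keep-mask shaped (kv_len, q_len): query j may look at keys i ≤ j.
causal_keep_mask(n::Integer) = [i <= j for i in 1:n, j in 1:n]

# Mean of per-token losses over the non-padding positions only.
masked_mean(losses::AbstractMatrix, tokens::AbstractMatrix; pad) =
    sum(losses .* (tokens .!= pad)) / count(!=(pad), tokens)

tokens = [7 2;            # toy batch: 3 positions × 2 sequences, padding id 1
          3 9;
          1 6]

pad_mask       = padding_keep_mask(tokens; pad = 1)   # 3×1×1×2, rebuilt for every batch
causal_and_pad = causal_keep_mask(3) .& pad_mask      # broadcasts to 3×3×1×2
losses         = Float32[1 2; 3 4; 100 6]             # the 100 sits on a padding position
loss           = masked_mean(losses, tokens; pad = 1) # padding position is ignored
```

Because the padding mask is rebuilt from each batch's tokens, it naturally varies from batch to batch. With this convention every score column keeps at least one real key, as long as no sequence consists entirely of padding.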