Open mashu opened 3 months ago
Just a thought: maybe these values need clamping in apply_attn_mask?
I think the unusual thing here is that the mask (size 16×1×1×32) is constant for whole batches, and thus it's trying to set every value to -Inf before the softmax. That's not illegal according to help:
mask: Input array broadcastable to size (kv_len, q_len, nheads, batch_size). The mask
is applied to the attention scores just before the softmax.
but it is unusual. Are you sure this is what you want, rather than e.g. running a smaller batch of randomly selected items?
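To make the broadcasting in that docstring concrete, here is a minimal plain-Julia sketch (no Flux needed) of a per-sequence padding mask of size (kv_len, 1, 1, batch_size) applied to a score array of size (kv_len, q_len, nheads, batch_size). The names `pad`, `tokens`, `keep` and `scores` are made up for illustration, and the sketch assumes that `true` in the mask means "keep this score", which is what the output of make_causal_mask suggests.

```julia
# Hypothetical toy data: 3 positions × 2 sequences, with token 1 used as padding.
pad = Int32(1)
tokens = Int32[7 2; 5 9; 1 4]          # first sequence is padded at position 3

# Assumed convention: true = attend to this key, false = mask it out.
keep = tokens .!= pad

# Reshape (kv_len, batch) to (kv_len, 1, 1, batch) so it broadcasts over q_len and nheads.
mask = reshape(keep, size(tokens, 1), 1, 1, size(tokens, 2))

# Stand-in attention scores of size (kv_len, q_len, nheads, batch).
scores = zeros(Float32, 3, 3, 2, 2)
masked = ifelse.(mask, scores, typemin(Float32))

# Sanity check: does any sequence have zero keys left to attend to?
fully_masked = any(iszero, sum(mask; dims=1))
```

If that convention holds, a mask built as `x .== 1` keeps only the padding positions, so any sequence without padding ends up with every key masked. That may be worth checking against the reproducer.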
Slightly shorter reproducer, and then a case with a more orthodox shape mask, I think:
julia> struct MaskAttention{A<:MultiHeadAttention, M<:AbstractArray}
att::A
mask::M
end
julia> (m::MaskAttention)(x::AbstractArray) = first(m.att(x; m.mask))
julia> Flux.@layer :expand MaskAttention
julia> x = map(f->rand(Int32.(2:10), rand(8:16)), 1:32);
julia> x = reduce(hcat, rpad.(x, maximum(length.(x)), 1))
16×32 Matrix{Int32}:
7 7 7 6 9 5 2 8 9 7 3 4 … 9 2 6 4 8 2 9 9 4 5 3 3
2 5 3 10 7 7 9 7 9 6 2 3 7 7 7 10 8 7 7 8 7 10 10 7
6 3 5 5 3 4 2 4 9 9 2 7 9 4 8 4 9 6 3 9 5 4 2 3
julia> mask = permutedims(repeat((x .== 1), outer = [1, 1, 1, 1]), (1, 4, 3, 2))
16×1×1×32 BitArray{4}:
[:, :, 1, 1] =
0
0
0
julia> model = MaskAttention(MultiHeadAttention(16), mask)
MaskAttention(
MultiHeadAttention(16; nheads=8), # 1_024 parameters
Bool[0; 0; … ; 1; 1;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 0; 1;;;; … ;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 0; 1], # 512 parameters
) # Total: 5 arrays, 1_536 parameters, 4.547 KiB.
julia> xx = randn32(16, 16, 32);
julia> model(xx) |> summary
"16×16×32 Array{Float32, 3}"
julia> model(xx) |> sum
NaN32
julia> findall(isnan, model(xx))
1280-element Vector{CartesianIndex{3}}:
CartesianIndex(1, 1, 11)
CartesianIndex(2, 1, 11)
CartesianIndex(3, 1, 11)
CartesianIndex(4, 1, 11)
CartesianIndex(5, 1, 11)
julia> loss, grads = Flux.withgradient(model) do m
sum(abs2, m(xx))
end
(val = NaN32, grad = ((att = (nheads = nothing, q_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), k_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), v_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing)), mask = nothing),))
julia> mask2 = rand(Bool, 16, 16);
julia> model2 = MaskAttention(MultiHeadAttention(16), mask2);
julia> model2(xx) |> sum
15.281105f0
julia> loss, grads = Flux.withgradient(model2) do m
sum(abs2, m(xx))
end
(val = 2290.0107f0, grad = ((att = (nheads = nothing, q_proj = (weight = Float32[-12.446103 -4.3954763 … -23.305235 3.0878963; 17.259354 -6.7767124 … 25.798717 6.8329597; … ; -22.414658 -17.097672 … 93.836235 -9.206725; 29.134241 13.785215 … -41.797527 -23.159046], bias = nothing, σ = nothing), k_proj = (weight = Float32[10.827733 18.880678 … -35.907875 23.034206; 8.75228 23.103594 … 13.421799 -14.956886; … ; -71.6004 -13.324369 … -35.224113 61.402447; 27.011333 124.142815 … -21.26666 -63.877186], bias = nothing, σ = nothing), v_proj = (weight = Float32[34.734734 -24.169163 … 102.391365 -69.705055; -1.5541999 37.55378 … -69.58273 39.782215; … ; 5.4440746 -176.59694 … 171.69466 161.58585; -139.37288 181.93517 … -213.38739 -58.174618], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[-18.867039 0.40507406 … 92.7197 -36.512943; -30.138624 -14.419439 … -92.25858 -63.47702; … ; -89.09866 92.45964 … -212.48007 164.08275; 35.601555 -31.360823 … 128.91348 -104.323494], bias = nothing, σ = nothing)), mask = nothing),))
So I do want to vary this mask per batch, because the sequences sampled into the next batch vary in length, and so the padding varies too. This minimal example is just one batch to show the issue. I tried clamping before the softmax and the NaNs are gone. The idea is to mask out the padding tokens from attention in the encoder. If that's unusual, am I doing something wrong? I have three different kinds of masks: a padding mask in the encoder (this MWE), a causal mask in the decoder, and a padding mask in the loss function, which affects only the target sequence in the decoder.
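For context on why clamping helps, the following plain-Julia sketch shows what happens when every score in a softmax column is replaced by -Inf. `naive_softmax` is a hand-written illustration, not NNlib's implementation, and the fill value `-1f9` is an arbitrary large negative number chosen for the example.

```julia
# Numerically stabilised softmax over one vector, written out by hand for illustration.
function naive_softmax(v::AbstractVector{<:Real})
    m = maximum(v)
    e = exp.(v .- m)
    return e ./ sum(e)
end

scores = Float32[0.3, -1.2, 2.0, 0.5]
mask = falses(4)                       # assumed convention: false = masked, so nothing is kept

# Filling with typemin gives -Inf everywhere; -Inf - (-Inf) is NaN inside the softmax.
masked_inf = ifelse.(mask, scores, typemin(Float32))
p_inf = naive_softmax(masked_inf)

# A large finite fill value (hypothetical choice) gives a uniform distribution instead.
masked_finite = ifelse.(mask, scores, -1f9)
p_finite = naive_softmax(masked_finite)
```

With the finite fill value, a fully masked column yields uniform attention weights rather than NaN. That hides the symptom, but the query still attends to masked positions, so it is worth checking why the column was fully masked in the first place.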
Hi,
The background is that, in the encoder-decoder translation model from "Attention Is All You Need", I wanted to mask out the padding in the sentence passed to the encoder's MultiHeadAttention. I noticed that the value computed for the mask, -neginf, based on the logits eltype, might cause some issues and lead to NaN. A minimal example is provided here: https://github.com/mashu/NaNTracker.jl The result is
I hope it's not an issue with my understanding of what the mask should look like, but to be honest the Flux documentation could use a couple of examples for this particular use case, in addition to just make_causal_mask.