mlverse / torch

R Interface to Torch
https://torch.mlverse.org

bug in `nnf_multi_head_attention_forward` #1205

Closed: MaximilianPi closed this issue 1 week ago

MaximilianPi commented 3 weeks ago

Hi @dfalbel,

I think there is a small bug in `nnf_multi_head_attention_forward` related to the key padding mask. From the `key_padding_mask` documentation:

"(N, S) where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the position with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of True will be ignored while the position with the value of False will be unchanged."

Both masks (`attn_mask` and `key_padding_mask`) can be passed either as boolean tensors or as float tensors (where the positions to be masked are set to `-Inf`).
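The two conventions are meant to be interchangeable: filling masked scores with `-Inf` and adding a float mask that holds `-Inf` at the same positions (and 0 elsewhere) should give identical softmax weights. A minimal sketch on made-up scores for one query over four keys; `bool_to_additive_mask` is a hypothetical helper for illustration, not part of torch.

```r
library(torch)

# Hypothetical helper (not part of torch): converts a boolean mask, where
# TRUE marks positions to ignore, into an additive float mask holding -Inf
# at those positions and 0 elsewhere.
bool_to_additive_mask <- function(mask) {
  torch_zeros_like(mask, dtype = torch_float())$masked_fill(mask, -Inf)
}

# Made-up attention scores for one query over four keys.
scores <- torch_tensor(c(0.5, 1.0, -0.3, 2.0))
bool_mask <- torch_tensor(c(FALSE, TRUE, FALSE, TRUE))
float_mask <- bool_to_additive_mask(bool_mask)

# Boolean convention: fill masked scores with -Inf.
p_bool <- nnf_softmax(scores$masked_fill(bool_mask, -Inf), dim = 1)
# Float convention: add the mask to the scores.
p_float <- nnf_softmax(scores + float_mask, dim = 1)
```

In both cases the masked keys end up with exactly zero attention weight, which is why only the dispatch on dtype differs between the two branches.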

Currently, only the attention mask (`attn_mask`) is checked for its dtype:

if (!is.null(attn_mask)) {
  if (attn_mask$dtype == torch_bool()) {
    attn_output_weights$masked_fill_(attn_mask, -Inf)
  } else {
    attn_output_weights <- attn_output_weights + attn_mask
  }
}

The key padding mask, however, is not checked; it is always passed to `masked_fill()`:

if (!is.null(key_padding_mask)) {
  attn_output_weights <- attn_output_weights$view(c(bsz, num_heads, tgt_len, src_len))
  attn_output_weights <- attn_output_weights$masked_fill(
    key_padding_mask$unsqueeze(2)$unsqueeze(3),
    -Inf
  )
  attn_output_weights <- attn_output_weights$view(c(
    bsz * num_heads,
    tgt_len,
    src_len
  ))
}

Thus, float key padding masks are ignored:

library(torch)
torch_manual_seed(1)

# batch of 10 sequences, 5 tokens each, embedding size 2
src <- torch_randn(10, 5, 2)
satt <- nn_multihead_attention(embed_dim = 2L, batch_first = TRUE, num_heads = 1L)
satt(src, src, src)[[1]]$sum()

# boolean key padding mask: TRUE = ignore this key
key_padding_mask <- torch::distr_bernoulli(0.1)$sample(c(10, 5))$squeeze()$bool()
satt(src, src, src, key_padding_mask = key_padding_mask)[[1]]$sum()

# float key padding mask: -Inf = ignore this key, 0 = keep it
key_padding_mask <- torch::distr_bernoulli(0.1)$sample(c(10, 5))$squeeze()
key_padding_mask[key_padding_mask == 1] <- -Inf
satt(src, src, src, key_padding_mask = key_padding_mask)[[1]]$sum()
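Whatever its dtype, the (N, S) key padding mask has to be reshaped before it is combined with the attention weights, which are viewed as (bsz, num_heads, tgt_len, src_len) at that point. A minimal sketch using the repro's sizes (10 sequences of length 5, one head); the tensors here are zero placeholders, not real attention weights. Broadcasting aligns trailing dimensions, so a raw (N, S) mask would line N up with `tgt_len`: here that fails because 10 != 5, and when N happens to equal `tgt_len` it would silently broadcast along the wrong axis.

```r
library(torch)

# Sizes taken from the repro: 10 sequences of length 5, one head.
bsz <- 10; num_heads <- 1; tgt_len <- 5; src_len <- 5

# Placeholder attention weights, shaped as after the first `$view()`.
attn_output_weights <- torch_zeros(bsz, num_heads, tgt_len, src_len)
float_mask <- torch_zeros(bsz, src_len)  # (N, S); -Inf would mark ignored keys

# A raw (N, S) mask lines N up with tgt_len; 10 != 5, so this errors.
naive <- try(attn_output_weights + float_mask, silent = TRUE)

# Unsqueezing twice (R torch dims are 1-based) gives (N, 1, 1, S), which
# broadcasts over the head and target dimensions as intended.
expanded <- float_mask$unsqueeze(2)$unsqueeze(3)
combined <- attn_output_weights + expanded
```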

So the code (https://github.com/mlverse/torch/blob/d7c49776f331167733c96fb150143cf6f103c005/R/nnf-activation.R#L729C1-L740C4) should probably be changed to the following:

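if (!is.null(key_padding_mask)) {
  attn_output_weights <- attn_output_weights$view(c(bsz, num_heads, tgt_len, src_len))
  # reshape the (N, S) mask to (N, 1, 1, S) so it broadcasts over heads and target positions
  expanded_mask <- key_padding_mask$unsqueeze(2)$unsqueeze(3)
  if (key_padding_mask$dtype == torch_bool()) {
    # boolean mask: TRUE marks keys to ignore
    attn_output_weights <- attn_output_weights$masked_fill(expanded_mask, -Inf)
  } else {
    # float mask: added directly (-Inf for ignored keys, 0 otherwise)
    attn_output_weights <- attn_output_weights + expanded_mask
  }
  attn_output_weights <- attn_output_weights$view(c(
    bsz * num_heads,
    tgt_len,
    src_len
  ))
}
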
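One way to sanity-check the patched branch outside the package is to copy it into a standalone function and confirm that a boolean mask and the equivalent float mask produce the same attention weights. A minimal sketch, assuming the patch logic shown in this issue; `apply_key_padding_mask` is a hypothetical test wrapper, not a torch function, and the scores are random placeholders.

```r
library(torch)

# Hypothetical standalone copy of the patched key padding branch,
# wrapped in a function so both mask dtypes can be compared directly.
apply_key_padding_mask <- function(attn_output_weights, key_padding_mask,
                                   bsz, num_heads, tgt_len, src_len) {
  w <- attn_output_weights$view(c(bsz, num_heads, tgt_len, src_len))
  # (N, S) -> (N, 1, 1, S) so the mask broadcasts over heads and targets
  expanded_mask <- key_padding_mask$unsqueeze(2)$unsqueeze(3)
  if (key_padding_mask$dtype == torch_bool()) {
    w <- w$masked_fill(expanded_mask, -Inf)
  } else {
    w <- w + expanded_mask
  }
  w$view(c(bsz * num_heads, tgt_len, src_len))
}

torch_manual_seed(1)
bsz <- 3; num_heads <- 2; tgt_len <- 4; src_len <- 4
scores <- torch_randn(bsz * num_heads, tgt_len, src_len)

# Mark the last key of every sequence as padding, in both conventions.
bool_mask <- torch_tensor(cbind(matrix(FALSE, bsz, src_len - 1), TRUE))
float_mask <- torch_zeros(bsz, src_len)$masked_fill(bool_mask, -Inf)

p_bool <- nnf_softmax(
  apply_key_padding_mask(scores, bool_mask, bsz, num_heads, tgt_len, src_len),
  dim = 3
)
p_float <- nnf_softmax(
  apply_key_padding_mask(scores, float_mask, bsz, num_heads, tgt_len, src_len),
  dim = 3
)
```

If the patch behaves as intended, `p_bool` and `p_float` match and the padded key receives zero weight for every query. Note that a row whose keys are all masked would produce `NaN` after the softmax in either convention.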
cregouby commented 3 weeks ago

Hello @MaximilianPi,

Good catch! Would you like to open a PR? (In addition to your patch, please edit the DESCRIPTION file to add your name as a contributor, and add a line to NEWS mentioning the fix.)

MaximilianPi commented 3 weeks ago

Yes, I can do that.