mlverse / torch

R Interface to Torch
https://torch.mlverse.org

fix `key_padding_mask` bug in `nnf_multi_head_attention_forward` #1208

Closed: MaximilianPi closed this pull request 2 weeks ago

MaximilianPi commented 2 weeks ago

Fixes the `key_padding_mask` bug in `nnf_multi_head_attention_forward` (#1205). With this change, float masks are added to `attn_output_weights`, while boolean masks fill the positions of `attn_output_weights` where the mask is `TRUE` with `-Inf`.
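To illustrate the two mask semantics described above, here is a minimal pure-Python sketch for a single row of attention scores. It is not the R torch implementation. The helper `apply_key_padding_mask` and the toy scores are hypothetical, and masking is assumed to happen before the softmax, as in standard multi-head attention. A boolean mask sets padded keys to `-inf`, while a float mask is added to the scores. Both produce the same attention weights when the float mask holds `0` for kept keys and `-inf` for padded ones.

```python
import math

def softmax(row):
    # Numerically stable softmax; exp(-inf) evaluates to 0.0
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def apply_key_padding_mask(scores, mask):
    """Hypothetical helper: apply a key padding mask to one row of scores."""
    if all(isinstance(v, bool) for v in mask):
        # Boolean mask: positions marked True are filled with -inf
        return [-math.inf if pad else s for s, pad in zip(scores, mask)]
    # Float mask: mask values are added to the scores
    return [s + v for s, v in zip(scores, mask)]

scores = [1.0, 2.0, 3.0]              # toy attention scores for three keys
bool_mask = [False, False, True]      # ignore the last key
float_mask = [0.0, 0.0, -math.inf]    # additive equivalent of bool_mask

w_bool = softmax(apply_key_padding_mask(scores, bool_mask))
w_float = softmax(apply_key_padding_mask(scores, float_mask))
print(w_bool)
print(w_float)
```

Because a float mask is added rather than used as a selector, a padded position only loses all of its attention weight if its mask value is `-inf`. A value such as `1.0` would merely shift the score.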

dfalbel commented 2 weeks ago

Thanks @MaximilianPi !