🐛 Bug
The documentation's "equivalent code" doesn't match the input shapes it states. The doc gives:

```python
attn = query @ key.transpose(-2, -1)
```

with `query` and `key` of shape `[B, M, H, K]`, so this matmul does not produce the intended `[B, H, M, M]` attention matrix.
The equivalent code should be:
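Something along these lines (variable names are illustrative), which first moves the head dimension ahead of the sequence dimension so the matmul batches over heads:

```python
import torch

B, M, H, K = 2, 5, 4, 8  # batch, sequence length, heads, per-head dim
query = torch.randn(B, M, H, K)
key = torch.randn(B, M, H, K)

# [B, M, H, K] -> [B, H, M, K] so the matmul batches over (B, H)
q = query.transpose(1, 2)
k = key.transpose(1, 2)

# Contract over K; attention weights are [B, H, M, M]
attn = q @ k.transpose(-2, -1)
assert attn.shape == (B, H, M, M)
```

As written in the doc, `query @ key.transpose(-2, -1)` applied directly to `[B, M, H, K]` tensors would instead yield a `[B, M, H, H]` tensor.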
Also, it would be nice to have the `attn_bias` dimensions and requirements in the doc (like the multiple-of-8 contiguous initialization described in the error message you get otherwise):

> To use an `attn_bias` with a sequence length that is not a multiple of 8, you need to ensure memory is aligned by slicing a bigger tensor. Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:, :, :, :5]` instead of `torch.zeros([1, 1, 5, 5])`
> `smallkF` is not supported because:
> bias with non-zero stride not supported

(Since this is a documentation issue, I have not filled in the other parts of the description template.)
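The multiple-of-8 alignment trick from that error message can be checked directly; this is a minimal sketch of why the sliced tensor satisfies the requirement while a plain `[1, 1, 5, 5]` allocation does not:

```python
import torch

# Allocate the last dim rounded up to a multiple of 8, then slice back
# to the true sequence length.
attn_bias = torch.zeros([1, 1, 5, 8])[:, :, :, :5]
assert attn_bias.shape == (1, 1, 5, 5)

# Slicing keeps the parent tensor's strides, so each row is still laid
# out 8 elements apart in memory, as the kernel expects.
assert attn_bias.stride(-2) == 8

# A direct allocation has row stride 5, which triggers the error.
assert torch.zeros([1, 1, 5, 5]).stride(-2) == 5
```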