@Doraemonzzz Can you elaborate a bit? I'm double-checking the _build_decay_mask function now, and AFAIK it's behaving as expected. All of the upper-triangular elements are zero, so the output does not depend on future tokens. All elements on the diagonal are 1, as you described, and the lower-triangular elements are in the range (0, 1).
import torch
from yet_another_retnet.retention import _build_decay_mask
mask = _build_decay_mask(
    num_heads=1,
    query_length=4,
    key_length=4,
    dtype=torch.float32,
)
print(mask)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9688, 1.0000, 0.0000, 0.0000],
         [0.9385, 0.9688, 1.0000, 0.0000],
         [0.9091, 0.9385, 0.9688, 1.0000]]])
It's possible that I misunderstood the issue you're describing. Please feel free to elaborate, or correct anything that I mischaracterized above.
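For readers following along, the printed values are consistent with each entry being a single decay rate raised to the distance between the query and key positions. The sketch below rebuilds the 4x4 mask in plain Python under that assumption. `decay_mask` is a hypothetical helper, not the library's `_build_decay_mask`, and `gamma = 0.96875` is simply read off the printed output.

```python
def decay_mask(gamma: float, n: int) -> list[list[float]]:
    """Hypothetical reference: gamma ** (i - j) on and below the diagonal, 0 above it."""
    return [
        [gamma ** (i - j) if i >= j else 0.0 for j in range(n)]
        for i in range(n)
    ]


gamma = 0.96875  # assumption: taken from the printed mask entry at row 1, column 0
mask = decay_mask(gamma, 4)

for i, row in enumerate(mask):
    for j, value in enumerate(row):
        if j > i:
            assert value == 0.0        # future positions contribute nothing
        elif j == i:
            assert value == 1.0        # gamma ** 0
        else:
            assert 0.0 < value < 1.0   # decays with distance
    print(" ".join(f"{v:.4f}" for v in row))
```

If the three properties hold, the loop finishes without raising and prints one row per query position.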
@fkodom Sorry for the late reply. You can use the following code to reproduce the bug. I suspect that the reason is the range of representation of bf16:
import torch
from yet_another_retnet.retention import _build_decay_mask
mask = _build_decay_mask(
    num_heads=8,
    query_length=4,
    key_length=4,
    dtype=torch.bfloat16,
)
print(mask)
Output:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9688, 1.0000, 0.0000, 0.0000],
         [0.9375, 0.9688, 1.0000, 0.0000],
         [0.9102, 0.9375, 0.9688, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9805, 1.0000, 0.0000, 0.0000],
         [0.9609, 0.9805, 1.0000, 0.0000],
         [0.9414, 0.9609, 0.9805, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9844, 1.0000, 0.0000, 0.0000],
         [0.9688, 0.9844, 1.0000, 0.0000],
         [0.9531, 0.9688, 0.9844, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9922, 1.0000, 0.0000, 0.0000],
         [0.9844, 0.9922, 1.0000, 0.0000],
         [0.9766, 0.9844, 0.9922, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9922, 1.0000, 0.0000, 0.0000],
         [0.9844, 0.9922, 1.0000, 0.0000],
         [0.9766, 0.9844, 0.9922, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9961, 1.0000, 0.0000, 0.0000],
         [0.9922, 0.9961, 1.0000, 0.0000],
         [0.9883, 0.9922, 0.9961, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9961, 1.0000, 0.0000, 0.0000],
         [0.9922, 0.9961, 1.0000, 0.0000],
         [0.9883, 0.9922, 0.9961, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000]]], dtype=torch.bfloat16)
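The all-ones block for the last head is what one would expect if that head's decay rate rounds to exactly 1 in bfloat16. The sketch below illustrates this in plain Python, assuming the last decay rate is close to 1 - 1/512 (a value chosen to be consistent with the printed output, not confirmed in this thread). `round_to_bfloat16` is a hypothetical helper that emulates round-to-nearest-even conversion of a float32 to bfloat16 for ordinary finite values only.

```python
import struct


def round_to_bfloat16(x: float) -> float:
    """Hypothetical helper: round a finite float to the nearest bfloat16 value."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, then keep only the top 16 bits.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def round_to_float16(x: float) -> float:
    """Round-trip through IEEE half precision using struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


gamma = 1 - 1 / 512  # assumed decay rate for the last head
inf = float("inf")

gamma_bf16 = round_to_bfloat16(gamma)
gamma_fp16 = round_to_float16(gamma)

print(gamma_bf16)          # 1.0: the decay rate is lost to rounding
print(gamma_fp16)          # 0.998046875: float16 keeps it
print(gamma_bf16 ** inf)   # 1.0: "future" entries become 1 instead of 0
print(gamma_fp16 ** inf)   # 0.0: future entries are masked as intended
```

Under these assumptions the culprit is precision rather than range: bfloat16 carries 8 bits of significand precision against 11 for float16, so values just below 1 can collapse to 1.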
@Doraemonzzz Interesting. You're right, the problem appears for bfloat16 and bfloat32 values, but not for float16 or float32. I didn't realize the range of bf16/bf32 could cause that problem. 🤷
Great find! I just pushed/released a fix (link), following your example from earlier.
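The pushed fix itself is not shown in this thread. As a rough illustration only, one way to make a decay mask robust to this kind of rounding is to avoid infinite exponents entirely: compute the powers on finite distances in float32, zero the upper triangle with an explicit causal mask, and cast at the end. `build_decay_mask_sketch` and the example decay rates are assumptions for illustration, not the library's code.

```python
import torch


def build_decay_mask_sketch(
    decay_gammas: torch.Tensor,  # shape (num_heads,), values in (0, 1]
    query_length: int,
    key_length: int,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Hypothetical sketch, not the library's _build_decay_mask."""
    device = decay_gammas.device
    query_pos = torch.arange(query_length, device=device).unsqueeze(-1)  # (q, 1)
    key_pos = torch.arange(key_length, device=device).unsqueeze(0)       # (1, k)
    causal = query_pos >= key_pos  # (q, k), True where attention is allowed
    # Clamp so every exponent is finite; future positions are zeroed below.
    distance = (query_pos - key_pos).clamp(min=0).to(torch.float32)
    gammas = decay_gammas.to(torch.float32).view(-1, 1, 1)
    mask = gammas ** distance  # (num_heads, q, k), computed in float32
    # Zero future positions explicitly, so a gamma that rounds to 1 cannot leak.
    mask = mask.masked_fill(~causal, 0.0)
    return mask.to(dtype)


# Assumed decay rates; the second one rounds to 1.0 in bfloat16.
gammas = torch.tensor([1 - 1 / 32, 1 - 1 / 512])
mask = build_decay_mask_sketch(gammas, 4, 4, dtype=torch.bfloat16)
print(mask)
```

With this layout the upper triangle is zero by construction in every dtype, although lower-triangular entries can still lose precision after the final cast.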
Thank you for your implementation, but I have encountered a bug when using the code. There is a major problem in the function _build_decay_mask, where the last element of decay_gammas is set to 1. This causes all elements of (decay_gammas**distance)[-1] to be 1 (since 1 ** float("inf") = 1), which can lead to information leakage. Here is a suggested modification: