fkodom / yet-another-retnet

A simple but robust PyTorch implementation of RetNet from "Retentive Network: A Successor to Transformer for Large Language Models" (https://arxiv.org/pdf/2307.08621.pdf)

Some issues regarding _build_decay_mask. #24

Closed · Doraemonzzz closed this issue 10 months ago

Doraemonzzz commented 10 months ago

Thank you for your implementation, but I have encountered a bug when using the code. There is a major problem in the function _build_decay_mask: the last element of decay_gammas can end up equal to exactly 1. When that happens, every element of (decay_gammas ** distance)[-1] is 1 (since 1 ** float("inf") == 1), including the upper-triangular positions, which leaks information from future tokens. Here is a suggested modification (a short demonstration follows the snippet):

    # Mark the upper-triangular positions, so that only *past* keys can affect
    # the current query.  Rather than setting those distances to infinity and
    # relying on gamma ** inf == 0 (which fails when a gamma equals exactly 1),
    # zero the masked positions out explicitly *after* computing the decay.
    distance_mask = torch.ones_like(distance, dtype=torch.bool).triu_(diagonal=1)

    distance = rearrange(distance, "n s -> () n s")
    decay_gammas = rearrange(decay_gammas, "h -> h () ()")

    decay_mask = decay_gammas**distance
    decay_mask = decay_mask.masked_fill_(distance_mask, 0)
    return decay_mask
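
To make the failure mode concrete, here is a minimal sketch showing that masking after the power stays causal even in the worst case, where a gamma has rounded up to exactly 1:

import torch

# Worst case: a decay gamma that has rounded to exactly 1.
gamma = torch.tensor(1.0)
distance = torch.arange(4).unsqueeze(-1) - torch.arange(4).unsqueeze(0)
distance_mask = torch.ones_like(distance, dtype=torch.bool).triu_(diagonal=1)
# gamma ** anything == 1 here, but the explicit fill still zeroes the future.
decay_mask = (gamma ** distance.float()).masked_fill_(distance_mask, 0)
print(decay_mask)  # upper triangle is 0; diagonal and lower triangle are 1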
fkodom commented 10 months ago

@Doraemonzzz Can you elaborate a bit? I'm double-checking the _build_decay_mask function now, and AFAIK it's behaving as expected. All of the upper-triangular elements are zero, so that the output does not depend on future tokens. All elements on the diagonal are 1, as you described, and lower-triangular elements are in the range (0, 1).

import torch

from yet_another_retnet.retention import _build_decay_mask

mask = _build_decay_mask(
    num_heads=1,
    query_length=4,
    key_length=4,
    dtype=torch.float32,
)
print(mask)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9688, 1.0000, 0.0000, 0.0000],
         [0.9385, 0.9688, 1.0000, 0.0000],
         [0.9091, 0.9385, 0.9688, 1.0000]]])

It's possible that I misunderstood the issue you're describing. Please feel free to elaborate, or correct anything that I mischaracterized above.

Doraemonzzz commented 10 months ago

@fkodom Sorry for the late reply. You can use the following code to reproduce the bug. I suspect the cause is the limited precision of bf16: the smallest decay gamma rounds to exactly 1 in bfloat16.

import torch

from yet_another_retnet.retention import _build_decay_mask

mask = _build_decay_mask(
    num_heads=8,
    query_length=4,
    key_length=4,
    dtype=torch.bfloat16,
)
print(mask)

Output:

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9688, 1.0000, 0.0000, 0.0000],
         [0.9375, 0.9688, 1.0000, 0.0000],
         [0.9102, 0.9375, 0.9688, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9805, 1.0000, 0.0000, 0.0000],
         [0.9609, 0.9805, 1.0000, 0.0000],
         [0.9414, 0.9609, 0.9805, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9844, 1.0000, 0.0000, 0.0000],
         [0.9688, 0.9844, 1.0000, 0.0000],
         [0.9531, 0.9688, 0.9844, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9922, 1.0000, 0.0000, 0.0000],
         [0.9844, 0.9922, 1.0000, 0.0000],
         [0.9766, 0.9844, 0.9922, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9922, 1.0000, 0.0000, 0.0000],
         [0.9844, 0.9922, 1.0000, 0.0000],
         [0.9766, 0.9844, 0.9922, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9961, 1.0000, 0.0000, 0.0000],
         [0.9922, 0.9961, 1.0000, 0.0000],
         [0.9883, 0.9922, 0.9961, 1.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9961, 1.0000, 0.0000, 0.0000],
         [0.9922, 0.9961, 1.0000, 0.0000],
         [0.9883, 0.9922, 0.9961, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000]]], dtype=torch.bfloat16)
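
A quick programmatic check (a sketch reusing the same call as above) confirms that only the last head's mask leaks future positions:

import torch

from yet_another_retnet.retention import _build_decay_mask

mask = _build_decay_mask(
    num_heads=8, query_length=4, key_length=4, dtype=torch.bfloat16
)
# Every head's upper triangle should be all zeros; head 7 (the all-ones mask) fails.
leaks = mask.triu(diagonal=1).ne(0).flatten(1).any(dim=1)
print(leaks)
# tensor([False, False, False, False, False, False, False,  True])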
fkodom commented 10 months ago

@Doraemonzzz Interesting. You're right, the problem appears for bfloat16, but not for float16 or float32. I didn't realize bf16's reduced mantissa precision could round a gamma all the way to 1. 🤷
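
For anyone hitting this later, the rounding is easy to check directly. Assuming the gamma schedule 1 - exp(linspace(log(1/32), log(1/512), num_heads)), which matches the per-head values printed above, the smallest decay gamma is 1 - 1/512:

import torch

gamma = 1 - 1 / 512  # smallest decay gamma under the assumed schedule
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, torch.tensor(gamma, dtype=dtype).item())
# torch.float32 0.998046875
# torch.float16 0.998046875
# torch.bfloat16 1.0

bfloat16 keeps float32's exponent range but has only 7 mantissa bits, so 1 - 1/512 rounds to exactly 1 while float16 (10 mantissa bits) represents it exactly.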

Great find! I just pushed/released a fix (link), following your example from earlier.
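
For reference, a minimal standalone sketch of the fixed construction (hypothetical; the repo's actual signature and gamma schedule may differ, and the schedule below is the one inferred from the printed values earlier in this thread):

import math

import torch
from einops import rearrange


def build_decay_mask_fixed(
    num_heads: int,
    query_length: int,
    key_length: int,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Assumed schedule: gamma = 1 - exp(linspace(log(1/32), log(1/512), num_heads))
    decay_gammas = 1.0 - torch.exp(
        torch.linspace(math.log(1 / 32), math.log(1 / 512), num_heads, dtype=dtype)
    )
    query_pos = torch.arange(query_length, dtype=dtype)
    key_pos = torch.arange(key_length, dtype=dtype)
    distance = rearrange(query_pos, "n -> n ()") - rearrange(key_pos, "s -> () s")
    distance_mask = torch.ones_like(distance, dtype=torch.bool).triu_(diagonal=1)

    distance = rearrange(distance, "n s -> () n s")
    decay_gammas = rearrange(decay_gammas, "h -> h () ()")
    # Mask *after* the power, so a gamma that rounds to 1 cannot leak future keys.
    decay_mask = decay_gammas**distance
    return decay_mask.masked_fill_(distance_mask, 0)

With this version, build_decay_mask_fixed(8, 4, 4, dtype=torch.bfloat16) produces a zero upper triangle for every head, including the head whose gamma rounds to 1.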