Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License

Confusion about "the chunkwise recurrent representation of retention" #34

Open CHENHUI-X opened 11 months ago

CHENHUI-X commented 11 months ago

I have a question regarding the "Chunkwise Recurrent Representation of Retention." The original expression in the paper (equation (7)) is shown in the attached image.

In your implementation, the code looks like this:

    r_i = (K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1

The first part of this equation computes the KV matrix for the current chunk and applies a decay factor. My understanding (ignoring batch size) is that K and V for the current chunk both have shape (2, 3), i.e. the chunk contains 2 tokens, so the KV matrix K^T V has shape (3, 3). Based on your code, the rows of V are scaled by the last row of the D matrix (shape (2, 2)) before the product is taken. For example, if D is [[1, 0], [0.9, 1]], then D[-1].view(1, chunk_size, 1) becomes [[0.9], [1]], and in V * D[-1].view(1, chunk_size, 1) these values multiply the first and second rows of V to implement the decay.

However, when we take the inner product of the chunk's Q matrix with the first half of R_i, it seems that both q tokens within the Q matrix end up using the same decay factors. Is that correct? In other words, for attention within the same chunk, the second q token should intuitively be multiplied by a decay factor (0.9) when attending to the first v token, but when the first q token attends to the first v token, no such decay factor should be applied.

Additionally, for the second half of R_i, it seems that the entire chunk is treated as a whole: R_i_1 is decayed all at once, with the decay applied as many times as the chunk length (i.e. by gamma ** chunk_size).
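To make sure I am reading this correctly, here is a small sketch with toy numbers of my own (one chunk, chunk_size = 2, gamma = 0.9, head dimension 3, batch 1; none of these values come from the repository) that reproduces the D matrix and the [[0.9], [1]] scaling I described above:

    import torch

    # Toy values of my own, only to make the decay weights concrete.
    gamma, chunk_size, d = 0.9, 2, 3
    idx = torch.arange(chunk_size, dtype=torch.float)
    # D[i, j] = gamma^(i - j) for i >= j, else 0
    D = (gamma ** (idx[:, None] - idx[None, :])) * (idx[:, None] >= idx[None, :])
    print(D)                                          # [[1.0, 0.0], [0.9, 1.0]]
    print(D[-1].view(1, chunk_size, 1).squeeze(0))    # [[0.9], [1.0]] -> row-wise scale on V

    K = torch.randn(1, chunk_size, d)
    V = torch.randn(1, chunk_size, d)
    r_prev = torch.zeros(1, d, d)
    # State update as in the quoted code: the first row of V is decayed by 0.9,
    # the second by 1.0, and the whole previous state by gamma ** chunk_size.
    r = K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1)) + (gamma ** chunk_size) * r_prev
    print(r.shape)                                    # torch.Size([1, 3, 3])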

There's another question I have regarding the cross-chunk calculations.

        #e[i,j] = gamma ** (i+1)
        e = torch.zeros(batch, chunk_size, 1)

        for _i in range(chunk_size):
            e[:, _i, :] = self.gamma ** (_i + 1)

        cross_chunk = (Q @ r_i_1) * e
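Evaluated with toy values of my own (chunk_size = 3, gamma = 0.9; neither value comes from the repository), e works out to:

    import torch

    # Toy values of my own, just to see the decay vector e concretely.
    batch, chunk_size, gamma = 1, 3, 0.9
    e = torch.zeros(batch, chunk_size, 1)
    for _i in range(chunk_size):
        e[:, _i, :] = gamma ** (_i + 1)
    print(e.squeeze())   # tensor([0.9000, 0.8100, 0.7290])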

In the code, the variable 'e' also plays a decay role. However, based on the code, the result of (Q @ r_i_1) looks like [o1, o2, o3]^T, where each 'oi' is a row vector of dimension D. What I'd like to point out is that, according to your code, 'o1' receives the least decay and 'o3' the most. But intuitively, for the current Q, shouldn't the vector corresponding to 'o1' be the farthest from the q tokens within the current chunk? In other words, shouldn't 'o1' receive the greatest decay? So, should the code instead be:

        #e[i,j] = gamma ** (i+1)
        e = torch.zeros(batch, chunk_size, 1)

        for _i in range(chunk_size):
            # e[:, _i, :] = self.gamma ** (_i + 1)
            e[:, _i, :] = self.gamma ** (chunk_size - _i)

        cross_chunk = (Q @ r_i_1) * e

This is very confusing to me. Is there a more detailed derivation, or a clearer explanation, of how equation (7) in the original paper is obtained? In particular, regarding the exponents of the decay factors: is the result of this chunkwise calculation consistent with the result of the fully parallel computation? Could someone help me with this?
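For reference, here is a minimal sketch I put together to compare the two forms numerically (my own simplification: a single head, no group normalization, no xpos/rotation, toy shapes and values; none of the names below come from the repository). It implements the parallel form O = (Q K^T * D) V and the chunkwise formula quoted above:

    import torch

    def decay_mask(n, gamma):
        # D[i, j] = gamma^(i - j) for i >= j, else 0
        idx = torch.arange(n, dtype=torch.float)
        return (gamma ** (idx[:, None] - idx[None, :])) * (idx[:, None] >= idx[None, :])

    torch.manual_seed(0)
    T, d, B = 8, 4, 2                      # sequence length, head dim, chunk size
    gamma = 0.9
    Q, K, V = (torch.randn(T, d) for _ in range(3))

    # Fully parallel retention: O = (Q K^T * D) V
    O_parallel = (Q @ K.T * decay_mask(T, gamma)) @ V

    # Chunkwise retention, using the same decay exponents as the code quoted above.
    D = decay_mask(B, gamma)
    e = (gamma ** torch.arange(1, B + 1, dtype=torch.float)).view(B, 1)   # e[i] = gamma^(i+1)

    O_chunk = torch.zeros_like(O_parallel)
    r = torch.zeros(d, d)                  # r_{i-1}: cross-chunk state
    for s in range(0, T, B):
        Qc, Kc, Vc = Q[s:s+B], K[s:s+B], V[s:s+B]
        O_chunk[s:s+B] = (Qc @ Kc.T * D) @ Vc + (Qc @ r) * e   # inner-chunk + cross-chunk
        r = Kc.T @ (Vc * D[-1].view(B, 1)) + (gamma ** B) * r  # r_i = K^T (V * D[-1]) + gamma^B * r_{i-1}

    print(torch.allclose(O_parallel, O_chunk, atol=1e-5))

I would like to confirm whether the two agree under the decay exponents used in the repository, and, if they do, where my reasoning above goes wrong.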