Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License

Dimensions of forward_recurrent #36

Closed · Qiu30 closed this 7 months ago

Qiu30 commented 7 months ago

In the MultiScaleRetention class, the comment says that `s_n_1s` has dimensions (batch_size, heads, head_size, head_size), while in SimpleRetention, `s_n_1` is defined as `s_n_1s[i]`. However, the comment there says that `s_n_1` has dimensions (batch_size, hidden_size, v_dim). Could you clarify this?

DinoMan commented 7 months ago

@Qiu30 I just had a closer look at the code (and `tests.py`), and the thing to note is that `s_n_1s` is a list for MultiScaleRetention. What the comment means is that the list as a whole covers (batch_size, heads, head_size, head_size): it has `heads` elements, each a tensor of shape (batch_size, head_size, head_size). For RetNet, the state is a list of lists, with each innermost element being a tensor of that same shape. To summarize:

- Retention --> `s_n_1`: (batch x head_dim x head_dim)
- MultiScaleRetention --> `s_n_1s`: a list with `num_heads` elements (tensors), each with shape (batch x head_dim x head_dim)
- RetNet --> `s_n_1s`: a list with `num_layers` elements; each element is itself a list of tensors with shape (batch x head_dim x head_dim)
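
For concreteness, here is a minimal sketch of those three state layouts in PyTorch. The concrete sizes and list construction below are illustrative assumptions, not code taken from the repo:

```python
import torch

# Illustrative sizes only (assumptions, not the repo's defaults)
batch_size, heads, num_layers, head_size = 2, 4, 3, 16

# Retention: a single state tensor
s_n_1 = torch.zeros(batch_size, head_size, head_size)

# MultiScaleRetention: one state tensor per head
s_n_1s = [torch.zeros(batch_size, head_size, head_size) for _ in range(heads)]

# RetNet: one such per-head list per layer
s_n_1s_layers = [
    [torch.zeros(batch_size, head_size, head_size) for _ in range(heads)]
    for _ in range(num_layers)
]
```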

I hope this helps.

Jamie-Stirling commented 7 months ago

Hi all, thanks very much for raising this and identifying the issue. I'll update the comments when I get time.

Qiu30 commented 7 months ago

@DinoMan @Jamie-Stirling Thank you for your replies. My understanding matches yours, but I have a question: what is the initial value of `s_n_1`? I searched the paper and did not find the initial value specified.

Jamie-Stirling commented 7 months ago

Hi, in the code I initialize this to zeros; however, this detail is not mentioned in the paper. I'm not sure of the impact of the choice of initial value on training, but setting it to zeros ensures that only the keys and values computed from the first token affect the state at t=1, akin to a transformer.
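
To sketch why zeros behave that way: the step below follows the paper's recurrent form S_n = γ·S_{n−1} + kₙᵀvₙ, but the function name and shapes are illustrative, not this repo's API:

```python
import torch

def retention_step(q_n, k_n, v_n, s_n_1, gamma):
    # q_n, k_n: (batch, 1, head_size); v_n: (batch, 1, v_dim)
    s_n = gamma * s_n_1 + k_n.transpose(-1, -2) @ v_n  # (batch, head_size, v_dim)
    out = q_n @ s_n                                    # (batch, 1, v_dim)
    return out, s_n

batch, head_size, v_dim, gamma = 2, 16, 16, 0.96
# Zero init: at n=1 the state reduces to k_1^T v_1, so only the first
# token's key/value contribute, as described above.
s_0 = torch.zeros(batch, head_size, v_dim)
q1, k1 = torch.randn(batch, 1, head_size), torch.randn(batch, 1, head_size)
v1 = torch.randn(batch, 1, v_dim)
out1, s_1 = retention_step(q1, k1, v1, s_0, gamma)
```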

I would say that setting a nonzero constant or trainable value for the initial state is analogous to introducing a bias term, and so may affect the way the RetNet trains. That said, I'm not an expert, so it may be best to ask the authors of the original paper to make sure.
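
If anyone wants to experiment with that, one hypothetical way to make the initial state trainable (not something this repo does) would be:

```python
import torch
import torch.nn as nn

class TrainableInitialState(nn.Module):
    """Hypothetical: learn s_0 instead of using zeros (not this repo's approach)."""

    def __init__(self, head_size, v_dim):
        super().__init__()
        # One learned initial state, analogous to a bias term
        self.s_0 = nn.Parameter(torch.zeros(head_size, v_dim))

    def initial_state(self, batch_size):
        # Broadcast the shared parameter across the batch:
        # returns (batch_size, head_size, v_dim)
        return self.s_0.unsqueeze(0).expand(batch_size, -1, -1)
```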

Qiu30 commented 7 months ago

@Jamie-Stirling I understand, thanks for your reply!