Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License

Passing Attention Masks #24

Open leffff opened 1 year ago

leffff commented 1 year ago

Hi! Is there a way to pass an attention mask, like in the transformers library, or a src_key_padding_mask like in nn.Transformer, so that the model doesn't "pay attention" to padding?

leffff commented 1 year ago

Also, how do you recommend pooling the output embeddings into a single vector? For example, BERT uses a [CLS] token that aggregates information from the whole sequence. As I understand it, here the last vector in the sequence encodes that information (like in RNNs).

Jamie-Stirling commented 1 year ago

Hi!

Thanks for your interest in this implementation. Please see the original authors' paper for more information; since I'm not an author, I'm best placed to answer implementation-specific questions. That said, I can have a go at answering yours.

Regarding padding: if your padding is placed after the input tokens (right padding), there's no need to mask the retention mechanism itself, since information only flows forwards anyway. You'll probably want to mask the padded positions out of the loss during training, though.
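A minimal sketch of that loss masking for right-padded sequences, assuming a causal-LM setup; `pad_id` and the logits/targets shapes are hypothetical, not part of this repo:

```python
import torch
import torch.nn.functional as F

pad_id = 0  # hypothetical padding token id


def masked_lm_loss(logits, targets, pad_id=pad_id):
    """Cross-entropy over a right-padded batch.

    logits:  (batch, seq_len, vocab_size) from the model
    targets: (batch, seq_len) token ids, padded with pad_id
    Positions holding pad_id contribute nothing to the loss.
    """
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=pad_id,
    )
```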

Regarding an embedding of the entire sequence: the recurrent state S (in the recurrent formulation) and R (in the chunk-wise one) should share a large amount of mutual information with the preceding tokens, so they may serve as a useful vector (or rather matrix) representation of a sequence. That remains to be investigated further, though.
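A rough sketch of that idea for a single unpadded sequence. The `forward_recurrent(x_n, prev_states, n)` call and the zero-initialised state list mirror this repo's recurrent usage, but treat the exact names and shapes as assumptions rather than a guaranteed API:

```python
import torch


def sequence_state(retnet, x, init_states):
    """Return the recurrent states reached after the last token of a single,
    unpadded sequence x of shape (1, length, hidden_dim).

    forward_recurrent(x_n, prev_states, n) and the structure of init_states
    (zero tensors, one per layer/head) are assumed from this repo's recurrent
    example; check the actual signatures before relying on this.
    """
    states = init_states
    for n in range(x.size(1)):
        # advance the recurrent state one token at a time
        _, states = retnet.forward_recurrent(x[:, n:n + 1, :], states, n + 1)
    # the final states summarise the whole sequence (matrix-valued, not a single vector)
    return states
```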

leffff commented 1 year ago

Thanks for the answer, now it's clear to me! I'll just take the last non-PAD token!
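In case it helps anyone later, a minimal sketch of that pooling on the parallel forward output, with lengths computed from a hypothetical pad_id:

```python
import torch


def last_token_pool(hidden, lengths):
    """Select the hidden state of the last non-pad token for each sequence.

    hidden:  (batch, seq_len, hidden_dim) output of the parallel forward pass
    lengths: (batch,) number of real (non-pad) tokens per sequence
    """
    idx = (lengths - 1).clamp(min=0).long()                  # index of each sequence's last real token
    idx = idx.view(-1, 1, 1).expand(-1, 1, hidden.size(-1))
    return hidden.gather(1, idx).squeeze(1)                  # (batch, hidden_dim)


# e.g. lengths = (token_ids != pad_id).sum(dim=1) for right-padded token_ids
```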