microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

Multi-Scale Retention: Why include position embeddings explicitly? #48

Closed fkodom closed 1 year ago

fkodom commented 1 year ago

My question is about the RetNet paper, which leads to the implementation here...

Why include the positional embedding updates directly in the multi-scale retention layer, rather than just applying them to the RetNet inputs?

[Screenshots: the retention equations from the RetNet paper, showing the xPos terms applied to Q and K inside the retention layer]

IMO, this seems overly specific to the language modeling use case. Other applications of retention/attention should be free to use whatever positional embeddings they need/want.

The retention formulation is still self-consistent (i.e. the parallel, recurrent, and chunkwise forms remain equivalent) without explicitly including positional embeddings in the retention layer. See Equations (1) and (2):

[Screenshot: Equations (1) and (2) from the RetNet paper]

Instead of forcing positional embeddings into the retention formulation, we can just set A equal to the decay matrix D. The parallel/recurrent/chunkwise formulations are still equivalent, and we remove the hard-coded dependence on xPos embeddings in the retention layer.
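
For concreteness, here is a minimal sketch (plain PyTorch, not the torchscale code) of what I mean: with the rotation removed and only the decay kept, the parallel form $(QK^\intercal \odot D)V$ and the recurrent form $S_n = \gamma S_{n-1} + K_n^\intercal V_n$, $O_n = Q_n S_n$ still produce identical outputs. The sizes and $\gamma$ below are arbitrary.

```python
# Minimal sketch (plain PyTorch, not the torchscale code): retention without
# xPos, keeping only the decay matrix D. Shapes and gamma are illustrative.
import torch

torch.manual_seed(0)
seq_len, head_dim = 8, 16
gamma = 0.9
q = torch.randn(seq_len, head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)

# Parallel form: O = (Q K^T * D) V, with D[n, m] = gamma^(n - m) for n >= m, else 0.
idx = torch.arange(seq_len, dtype=torch.float32)
decay = gamma ** (idx[:, None] - idx[None, :])
D = decay * (idx[:, None] >= idx[None, :])
out_parallel = (q @ k.T * D) @ v

# Recurrent form: S_n = gamma * S_{n-1} + k_n^T v_n,  o_n = q_n S_n.
state = torch.zeros(head_dim, head_dim)
outs = []
for t in range(seq_len):
    state = gamma * state + k[t, :, None] @ v[t, None, :]
    outs.append(q[t] @ state)
out_recurrent = torch.stack(outs)

print(torch.allclose(out_parallel, out_recurrent, atol=1e-5))  # expected: True
```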

Conceptually, I'm thinking about how to apply RetNet to other data domains (images, heterogeneous graphs, etc.). In those cases, the xPos embeddings don't reflect the items' actual positions (a 2D position in an image, a generic position within a graph, etc.). Does it make sense to remove the explicit position embedding from the retention layer, or am I missing something?

sunyt32 commented 1 year ago

$e^{i\theta}$ works well for language modeling, so we set it as the default. We haven't evaluated other domains yet, and I agree that the rotation may not be the best option there. An optimization technique may also be needed: naively making the angles learnable parameters causes NaNs in the gradients. You can try adjusting them manually or explore a workable way to optimize them.
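
For reference, a rough sketch (not the torchscale implementation) of the two options: keep the angles fixed in a buffer, or re-parameterize a learnable angle into a bounded range as one thing to experiment with. The RoPE-style frequency schedule below is only an illustrative assumption.

```python
# Rough sketch (not the torchscale implementation): fixed rotation angles vs. a
# bounded re-parameterization one could try when making the angles learnable.
# The RoPE-style frequency schedule below is an assumption, not the repo's code.
import math
import torch
import torch.nn as nn

class RotationAngles(nn.Module):
    def __init__(self, head_dim: int, learnable: bool = False):
        super().__init__()
        theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.learnable = learnable
        if learnable:
            # Directly learning raw angles was reported above to give NaN gradients;
            # constraining them to (0, pi) via a sigmoid is one thing to experiment
            # with. At init this recovers the same schedule as the fixed buffer.
            self.raw = nn.Parameter(torch.logit(theta / math.pi))
        else:
            self.register_buffer("theta", theta)

    def forward(self) -> torch.Tensor:
        return math.pi * torch.sigmoid(self.raw) if self.learnable else self.theta
```

Whether a bounded parameterization actually avoids the NaN issue would have to be checked empirically.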

donglixp commented 1 year ago

It depends on how you interpret "position embeddings". For example, we can also add position embeddings (such as a "generic position within a graph") to the token embeddings, where the positions are treated as input attributes.
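
A small sketch of that idea (the names here are hypothetical, not part of torchscale): embed the positions as a separate attribute and add them to the token embeddings before the RetNet, so the retention layer itself needs no domain-specific position handling.

```python
# Illustrative sketch (hypothetical names, not part of torchscale): embed the
# positions as attributes and add them to the token embeddings before the model.
import torch
import torch.nn as nn

class TokensWithPositionAttributes(nn.Module):
    def __init__(self, vocab_size: int, num_positions: int, embed_dim: int):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, embed_dim)
        # "Position" here can be any discrete attribute: a 2D patch index, a
        # node's position within a graph, etc.
        self.pos_embed = nn.Embedding(num_positions, embed_dim)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        # tokens, positions: (batch, seq_len) -> (batch, seq_len, embed_dim)
        return self.tok_embed(tokens) + self.pos_embed(positions)
```

The retention layers downstream can then be used as-is, or with the in-layer rotation removed as discussed above.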

fkodom commented 1 year ago

Thanks! This is exactly what I was looking for. 😎