lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

Should Attention class compute separate keys & values across heads? #99

Closed: LWprogramming closed this issue 1 year ago

LWprogramming commented 1 year ago

Currently in the Attention class:

inner_dim = dim_head * heads
...
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)

This seems to learn a separate embedding -> query mapping per head, but the embedding -> key and embedding -> value mappings are shared across all heads, whereas the original attention paper projects keys and values independently per head as well (section 3.2.2, bottom of page 4).
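For concreteness, here is a minimal sketch (not the repository's code; the `dim`, `heads`, and `dim_head` values are made-up) contrasting the original per-head key/value projections with the shared layout in the snippet above:

```python
import torch
import torch.nn as nn

dim = dim_context = 512
heads, dim_head = 8, 64
inner_dim = dim_head * heads  # 512

# "Attention Is All You Need" layout: every head gets its own key and value projection
to_kv_per_head = nn.Linear(dim_context, inner_dim * 2, bias=False)

# layout in the snippet above: a single key and value shared by all query heads
to_kv_shared = nn.Linear(dim_context, dim_head * 2, bias=False)

context = torch.randn(1, 10, dim_context)
print(to_kv_per_head(context).shape)  # torch.Size([1, 10, 1024]) -> heads * dim_head keys + values
print(to_kv_shared(context).shape)    # torch.Size([1, 10, 128])  -> one dim_head key + one dim_head value
```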

lucidrains commented 1 year ago

@LWprogramming very observant! this is actually an updated technique, multi-query attention, from this paper

the technique was employed by both PaLM and AlphaCode. in other words, it will scale just fine, and it saves a ton of memory when doing decoding
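For anyone else landing here, a minimal sketch of the idea (the function and variable names below are illustrative, not the repository's): all query heads attend against a single shared key/value head, so the key/value tensors cached during decoding are `heads` times smaller.

```python
import torch
import torch.nn as nn
from einops import rearrange

def multi_query_attention(x, context, to_q, to_kv, heads, dim_head):
    # per-head queries, single shared key/value head
    q = rearrange(to_q(x), 'b n (h d) -> b h n d', h=heads)
    k, v = to_kv(context).chunk(2, dim=-1)  # each: (b, m, dim_head)

    # the single k/v head is broadcast across all query heads
    sim = torch.einsum('b h i d, b j d -> b h i j', q * dim_head ** -0.5, k)
    attn = sim.softmax(dim=-1)
    out = torch.einsum('b h i j, b j d -> b h i d', attn, v)
    return rearrange(out, 'b h n d -> b n (h d)')

dim, heads, dim_head = 512, 8, 64
to_q = nn.Linear(dim, heads * dim_head, bias=False)
to_kv = nn.Linear(dim, dim_head * 2, bias=False)

x = torch.randn(2, 16, dim)
out = multi_query_attention(x, x, to_q, to_kv, heads, dim_head)  # (2, 16, 512)
```

During autoregressive decoding, the cached keys and values per layer are of shape (b, n, dim_head) instead of (b, heads, n, dim_head), which is where the memory saving comes from.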

LWprogramming commented 1 year ago

Oh man, another paper for my reading list 😂

I'll make a PR after I finish reading through the rest of the code, adding comments for all these little optimizations that weren't around in the original paper

lucidrains commented 1 year ago

@LWprogramming haha yeah, the field is a science, so there's a lot of literature

biendltb commented 9 months ago

Hi @LWprogramming and @lucidrains,

Related to this topic, do you know why the self-attention over the context (i.e. the transformer encoder step) is omitted before the context is passed to the cross-attention?

In the implementation, I found that the context is passed directly to the cross-attention, without first running self-attention over it as in the original paper: https://github.com/lucidrains/audiolm-pytorch/blob/1a888d2f462384baf5dc8b4782f39a40f59593b7/audiolm_pytorch/audiolm_pytorch.py#L503
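To make the question concrete, a minimal sketch of the two wirings (hypothetical module names, using plain `nn.MultiheadAttention` rather than the repo's Attention class):

```python
import torch
import torch.nn as nn

dim = 512
cross_attn        = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
context_self_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

x       = torch.randn(1, 32, dim)   # decoder-side tokens
context = torch.randn(1, 10, dim)   # conditioning context

# (a) what the linked line appears to do: raw context straight into cross-attention
out_a, _ = cross_attn(x, context, context)

# (b) original encoder-decoder recipe: context first refined by its own self-attention
refined, _ = context_self_attn(context, context, context)
out_b, _   = cross_attn(x, refined, refined)
```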