phlippe / uvadlc_notebooks

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
https://uvadlc-notebooks.readthedocs.io/en/latest/
MIT License

Multihead Attention #64

Closed ofermeshi closed 1 year ago

ofermeshi commented 1 year ago

It seems like the implementation of MultiheadAttention is not consistent with the "Multi-Head Attention" figure. In particular, the projection self.qkv_proj = nn.Dense(3*self.embed_dim,...) should actually be self.qkv_proj = nn.Dense(3*self.embed_dim*self.num_heads,...). Am I missing something?

[This would also require changing the line values = values.reshape(batch_size, seq_length, self.embed_dim) to values = values.reshape(batch_size, seq_length, -1).] Thanks.

phlippe commented 1 year ago

Hi, in Multi-Head Attention, it is common to split the embedding dimension over the heads. Therefore, each head uses a key dimension of embed_dim // num_heads, and the full output after applying $W^{O}$ is again embed_dim.
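To make the splitting concrete, here is a minimal sketch of how a single qkv_proj of size 3*embed_dim can be divided across heads, assuming a Flax setup similar to the notebook's (the class and variable names are illustrative, not the notebook's exact code):

```python
# Minimal sketch (not the notebook's exact implementation): one qkv projection of
# size 3 * embed_dim, with the embedding dimension split across the heads.
import jax.numpy as jnp
import flax.linen as nn


class MultiheadAttentionSketch(nn.Module):
    embed_dim: int   # total embedding dimension
    num_heads: int   # must evenly divide embed_dim

    @nn.compact
    def __call__(self, x):
        batch_size, seq_length, _ = x.shape
        head_dim = self.embed_dim // self.num_heads

        # One projection produces Q, K, V stacked along the last axis: 3 * embed_dim.
        qkv = nn.Dense(3 * self.embed_dim)(x)
        # Split embed_dim across heads instead of giving each head the full embed_dim.
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * head_dim)
        qkv = qkv.transpose(0, 2, 1, 3)            # [batch, heads, seq_len, 3*head_dim]
        q, k, v = jnp.split(qkv, 3, axis=-1)       # each [batch, heads, seq_len, head_dim]

        # Scaled dot-product attention per head.
        attn_logits = jnp.matmul(q, k.swapaxes(-2, -1)) / jnp.sqrt(head_dim)
        attention = nn.softmax(attn_logits, axis=-1)
        values = jnp.matmul(attention, v)          # [batch, heads, seq_len, head_dim]

        # Concatenating the heads restores the full embed_dim, so reshaping back to
        # (batch, seq_len, embed_dim) is correct as written in the notebook.
        values = values.transpose(0, 2, 1, 3).reshape(batch_size, seq_length, self.embed_dim)
        return nn.Dense(self.embed_dim)(values)    # output projection W^O
```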

ofermeshi commented 1 year ago

Thanks for clarifying. It may be worth adding a comment about that in the text, since from the figure it looks like each head handles all input dimensions. Also, this requires that num_heads is a divisor of embed_dim.
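As a small illustration of the divisibility requirement (the numbers here are hypothetical, not from the notebook):

```python
embed_dim, num_heads = 256, 8
# Each head gets an equal slice of the embedding, so num_heads must divide embed_dim.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads  # 32 dimensions per head
```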