Closed · ofermeshi closed this issue 1 year ago
Hi, in Multi-Head Attention it is common to split the embedding dimension over the heads. Each head therefore uses a key/query/value dimension of `embed_dim // num_heads`, and after the heads are concatenated and $W^{O}$ is applied, the full output is again of size `embed_dim`.
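For concreteness, here is a minimal sketch of that splitting in Flax. This is an illustrative reimplementation under assumptions, not the tutorial's exact code; the `scaled_dot_product` helper and the default initializers are placeholders:

```python
import math
import jax.numpy as jnp
import flax.linen as nn


def scaled_dot_product(q, k, v):
    # q, k, v: [batch, num_heads, seq_len, head_dim]
    d_k = q.shape[-1]
    attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(d_k)
    attention = nn.softmax(attn_logits, axis=-1)
    return jnp.matmul(attention, v), attention


class MultiheadAttention(nn.Module):
    embed_dim: int   # total embedding dimension, shared across the heads
    num_heads: int

    def setup(self):
        # One joint projection for Q, K and V: 3*embed_dim outputs in total,
        # not 3*embed_dim*num_heads -- the heads split the embed_dim budget.
        self.qkv_proj = nn.Dense(3 * self.embed_dim)
        self.o_proj = nn.Dense(self.embed_dim)  # W^O, maps back to embed_dim

    def __call__(self, x):
        batch_size, seq_length, _ = x.shape
        qkv = self.qkv_proj(x)                                  # [B, T, 3*embed_dim]
        # Split the feature axis over the heads: each head gets
        # 3 * (embed_dim // num_heads) features.
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, -1)
        qkv = qkv.transpose(0, 2, 1, 3)                         # [B, heads, T, 3*head_dim]
        q, k, v = jnp.array_split(qkv, 3, axis=-1)              # each [B, heads, T, head_dim]
        values, attention = scaled_dot_product(q, k, v)
        values = values.transpose(0, 2, 1, 3)                   # [B, T, heads, head_dim]
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        return self.o_proj(values), attention
```

With `embed_dim=512` and `num_heads=8`, for example, each head works on 64 dimensions, and the concatenated output fed into `o_proj` is again 512-dimensional.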
Thanks for clarifying. It might be worth adding a comment on that in the text, since from the figure it looks like each head handles all input dimensions. Also, note that this requires `num_heads` to be a divisor of `embed_dim` (see the small check sketched below).
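A tiny illustration of that divisibility requirement; the numbers are examples only and the assert is a hypothetical guard, not claimed to be in the tutorial's code:

```python
# Example sizes only; the assert is a hypothetical guard one could add in setup().
embed_dim, num_heads = 512, 8
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads   # 64 dimensions per head
print(head_dim)
```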
It seems like the implementation of MultiheadAttention is not consistent with the "Multi-Head Attention" figure. In particular, the projection `self.qkv_proj = nn.Dense(3*self.embed_dim, ...)` should, it seems, actually be `self.qkv_proj = nn.Dense(3*self.embed_dim*self.num_heads, ...)`. Am I missing something?

[This would also require changing the line `values = values.reshape(batch_size, seq_length, self.embed_dim)` to `values = values.reshape(batch_size, seq_length, -1)`.] Thanks.
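For reference, a toy walkthrough of the shapes (the sizes are assumptions picked for readability) showing why the existing `3*self.embed_dim` projection and `reshape(..., self.embed_dim)` are consistent with the figure once the heads split the embedding dimension:

```python
import jax.numpy as jnp

# Assumed toy sizes, chosen only to keep the printed shapes small.
batch, seq, embed_dim, num_heads = 2, 4, 8, 2
head_dim = embed_dim // num_heads                        # 4

qkv = jnp.zeros((batch, seq, 3 * embed_dim))             # stand-in for the qkv_proj output
qkv = qkv.reshape(batch, seq, num_heads, 3 * head_dim)   # split embed_dim over the heads
qkv = qkv.transpose(0, 2, 1, 3)                          # [batch, heads, seq, 3*head_dim]
q, k, v = jnp.array_split(qkv, 3, axis=-1)
print(q.shape)                                           # (2, 2, 4, 4): each head sees head_dim, not embed_dim

values = jnp.zeros((batch, num_heads, seq, head_dim))    # stand-in for the per-head attention output
values = values.transpose(0, 2, 1, 3).reshape(batch, seq, embed_dim)
print(values.shape)                                      # (2, 4, 8): back to embed_dim, so the existing reshape works
```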