nanowell / Differential-Transformer-PyTorch

PyTorch implementation of the Differential-Transformer architecture for sequence modeling, specifically tailored as a decoder-only model similar to large language models (LLMs). The architecture incorporates a novel Differential Attention mechanism, Multi-Head structure, RMSNorm, and SwiGLU.
MIT License

the shape of q,k,v #3

Open zziC7 opened 1 month ago

zziC7 commented 1 month ago

Hello, I noticed that in your code, q, k, and v are projected as self.W_q = nn.Linear(d_model, 2 * self.d_head * num_heads, bias=False)

However, in another repository I found that q, k, v are computed as: self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) (code from this link)

This shape difference carries through to the subsequent differential attention calculations. So I wonder which version matches the method in the paper, or are the two just different ways of writing the same thing?

Thanks.

nanowell commented 1 month ago

Hello,

The difference you've noticed stems from two different parameterizations of the same mathematical formulation.

In the paper's notation (Equation 1), we have: [Q₁; Q₂] = XW^Q, where W^Q ∈ ℝ^(d_model × 2d)

self.W_q = nn.Linear(d_model, 2 * self.d_head * num_heads, bias=False)

This directly implements the paper's formulation by creating a projection matrix that outputs concatenated Q₁ and Q₂ in one operation.
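For concreteness, here is a minimal sketch (illustrative sizes, not the exact code of either repository) showing that the two projections have the same output width and split into Q₁ and Q₂ the same way, provided the other repo's head dimension is defined as embed_dim // num_heads // 2:

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not taken from either repo's config).
d_model   = 512                         # model / embedding dimension
num_heads = 8                           # number of differential-attention heads
d_head    = d_model // num_heads // 2   # halved per-head dim, so 2 * d_head * num_heads == d_model

x = torch.randn(2, 16, d_model)         # (batch, seq_len, d_model)

# Parameterization A (this repo): output 2 * d_head * num_heads explicitly.
W_q_a = nn.Linear(d_model, 2 * d_head * num_heads, bias=False)

# Parameterization B (the other repo): output embed_dim, which is the same
# number of features when d_head = embed_dim // num_heads // 2.
W_q_b = nn.Linear(d_model, d_model, bias=False)

assert 2 * d_head * num_heads == d_model   # identical projection width

# Either way, the projection is reshaped and split into Q1 and Q2 per head.
q = W_q_a(x)                               # (batch, seq, 2 * num_heads * d_head)
q = q.view(2, 16, num_heads, 2, d_head)    # expose the Q1/Q2 axis
q1, q2 = q.unbind(dim=3)                   # each: (batch, seq, num_heads, d_head)
print(q1.shape, q2.shape)
```

So the two implementations learn the same set of parameters; they only differ in whether the "2" factor is written out in the Linear layer's output size or folded into embed_dim and recovered by the reshape.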