NolanoOrg / cformers

SoTA Transformers with C-backend for fast inference on your CPU.

Benchmark the effect of merging query and key matrices in transformers #3

Open Ayushk4 opened 1 year ago

Ayushk4 commented 1 year ago

For certain architectures (like GPT-J and LLaMA), it may be possible to replace the Query $Q$ and Key $K$ matrices with a single matrix, saving one of the seven or eight weight matrix multiplications in each transformer layer. I don't see an obvious way of doing this for GPT-NeoX and OPT.

Take a standard benchmark and run the model before and after merging the Query and Key matrices.
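Before a full-model run, a micro-benchmark of just the affected matrix multiplications can give a first signal. A minimal numpy sketch, assuming a single head with square $d \times d$ projections; `n`, `d`, and the random data are illustrative, and `QKMerge` is the merged matrix derived in the details below:

```python
import time
import numpy as np

rng = np.random.default_rng(0)
n, d = 2048, 256                 # sequence length, head dim (illustrative)
X = rng.standard_normal((n, d))  # rows are x_1 ... x_n
Q = rng.standard_normal((d, d))
K = rng.standard_normal((d, d))
QKMerge = Q.T @ K                # merged matrix, precomputed once offline

def bench(fn, reps=20):
    """Average wall-clock seconds per call over `reps` repetitions."""
    t0 = time.perf_counter()
    for _ in range(reps):
        fn()
    return (time.perf_counter() - t0) / reps

before = bench(lambda: (X @ Q.T) @ (X @ K.T).T)  # two projections + scores
after = bench(lambda: (X @ QKMerge) @ X.T)       # one projection + scores
print(f"before: {before*1e3:.2f} ms   after: {after*1e3:.2f} ms")
```

A proper benchmark would also check that perplexity/accuracy is unchanged, since the merge is only exact without position-dependent terms (see the rotary-embedding caveat below).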

---------- Following are the details ----------

$^\top$ denotes transpose.

Consider the input representation $X = \{x_1, \ldots, x_i, \ldots, x_j, \ldots, x_n\}$, with $q_i = Q x_i$ and $k_j = K x_j$.

$$\mathrm{score}_{i,j} = q_i^\top k_j = (Q x_i)^\top (K x_j) = (x_i^\top Q^\top)(K x_j) = x_i^\top Q^\top K x_j$$

Let $\mathrm{QKMerge} = Q^\top K$. Then

$$\mathrm{score}_{i,j} = x_i^\top \,\mathrm{QKMerge}\, x_j$$
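A quick numerical check of this identity (a minimal numpy sketch under the same single-head, square-projection simplification; all sizes and data are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 16                     # sequence length, head dimension (illustrative)
X = rng.standard_normal((n, d))  # rows are x_1 ... x_n
Q = rng.standard_normal((d, d))
K = rng.standard_normal((d, d))

# Baseline: two projections, then scores.
q = X @ Q.T                      # row i is q_i = Q x_i
k = X @ K.T                      # row j is k_j = K x_j
scores = q @ k.T                 # scores[i, j] = q_i^T k_j

# Merged: precompute QKMerge = Q^T K once, then a single projection.
QKMerge = Q.T @ K
scores_merged = (X @ QKMerge) @ X.T   # scores[i, j] = x_i^T QKMerge x_j

assert np.allclose(scores, scores_merged)
```

The saving comes from replacing the two projections $XQ^\top$ and $XK^\top$ with the single projection $X\,\mathrm{QKMerge}$ before the $n \times n$ score matmul.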

Ayushk4 commented 1 year ago

The above formula will have to be modified for rotary position embeddings: RoPE applies a position-dependent rotation to $q_i$ and $k_j$ before the dot product, so $Q^\top K$ can no longer be folded into a single fixed matrix.
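To make the caveat concrete: RoPE rotates $q_i$ by a block-diagonal rotation matrix $R_i$ (and $k_j$ by $R_j$), so a single fixed merged matrix no longer matches the scores; but since $R_i^\top R_j = R_{j-i}$, a merged matrix $Q^\top R_{j-i} K$ that depends only on the relative offset does. A sketch under the same illustrative single-head setup; the $10000$ base follows the standard RoPE formulation, and the pairing of adjacent dimensions here is a simplification:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 16
X = rng.standard_normal((n, d))
Q = rng.standard_normal((d, d))
K = rng.standard_normal((d, d))

def rope_rotation(pos, d):
    """Block-diagonal RoPE rotation matrix R(pos): each 2-D pair of
    dimensions is rotated by pos * theta_m."""
    R = np.zeros((d, d))
    for m in range(d // 2):
        theta = pos * (10000.0 ** (-2.0 * m / d))
        c, s = np.cos(theta), np.sin(theta)
        R[2*m:2*m+2, 2*m:2*m+2] = [[c, -s], [s, c]]
    return R

q = X @ Q.T
k = X @ K.T

# RoPE scores: rotate q_i by R(i) and k_j by R(j) before the dot product.
scores = np.array([[(rope_rotation(i, d) @ q[i]) @ (rope_rotation(j, d) @ k[j])
                    for j in range(n)] for i in range(n)])

# A single fixed merged matrix no longer reproduces them ...
QKMerge = Q.T @ K
assert not np.allclose(scores, (X @ QKMerge) @ X.T)

# ... but since R(i)^T R(j) = R(j - i), a merged matrix per relative
# offset does: score_{i,j} = x_i^T (Q^T R(j - i) K) x_j.
merged = np.array([[X[i] @ (Q.T @ rope_rotation(j - i, d) @ K) @ X[j]
                    for j in range(n)] for i in range(n)])
assert np.allclose(scores, merged)
```

Whether precomputing merged matrices per relative offset is a net win is exactly what the benchmark should answer, since it trades one projection for an offset-dependent $d \times d$ matrix per relative position.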