karpathy / minGPT

A minimal PyTorch re-implementation of the OpenAI GPT (Generative Pretrained Transformer) training

What is the purpose of `c_proj` here? #135

Open brynhayder opened 2 months ago

brynhayder commented 2 months ago

https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/model.py#L42

Why do we need an additional linear transformation after the MHA and before the MLP when the dimensions are the same?

(I understand that this is how the original Transformer was written, but I took this operation to be for dimension consistency between sequential attention operations. It seems superfluous here, since the first linear layer in the MLP can already take linear combinations of the attention outputs.)
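For reference, here is a rough sketch of where `c_proj` sits (paraphrasing the linked `model.py`, not the exact minGPT code; dropout and the causal mask are omitted):

```python
import torch
import torch.nn as nn

class SelfAttentionSketch(nn.Module):
    # Paraphrase of the relevant part of CausalSelfAttention, not the exact minGPT code.
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.n_head = n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # produces q, k, v in one shot
        self.c_proj = nn.Linear(n_embd, n_embd)      # the projection in question (W^O)

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(C, dim=2)
        # reshape into heads: (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
        att = att.softmax(dim=-1)                    # causal mask omitted for brevity
        y = att @ v                                  # (B, n_head, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # concatenate heads
        return self.c_proj(y)                        # re-mix across heads before the MLP
```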

theicfire commented 2 months ago

I think the point is that W^O can change the dimensionality if the concatenation of the heads comes out larger or smaller than the model width (i.e. if d_k or d_v were chosen differently). The paper ultimately used parameters that make W^O square, but keeping the projection probably made it easier for the authors to try other settings.
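To make that concrete, a rough sketch of the general case (hypothetical sizes, not the paper's or minGPT's) where the per-head d_v differs from n_embd / n_head, so W^O is what maps the concatenation back to the residual width:

```python
import torch
import torch.nn as nn

n_embd, n_head, d_v = 768, 12, 96      # hypothetical: d_v != n_embd // n_head (which would be 64)
w_v = nn.Linear(n_embd, n_head * d_v)  # per-head value projections, fused into one matrix
w_o = nn.Linear(n_head * d_v, n_embd)  # W^O: concatenated heads -> back to n_embd

x = torch.randn(2, 10, n_embd)         # (batch, seq, n_embd)
v = w_v(x)                             # (batch, seq, n_head * d_v)
# ... attention weights would mix v across the sequence here ...
out = w_o(v)                           # (batch, seq, n_embd), regardless of the choice of d_v
```

When d_v is fixed to n_embd // n_head (as in minGPT), W^O ends up square, which is what makes it look redundant next to the MLP's first linear layer.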

Related: https://github.com/karpathy/minGPT/issues/118#issuecomment-2095063510