Open lxww302 opened 8 months ago
The formula is out of the PaLM paper. It's been a while since I looked at this, but my initial reaction is that the additional *2 comes from the relative positional embeddings used in PaLM. (So I think you'd be correct.)
In the attention forward pass we have Q*K^T*V, which is (L x d) x (d x L) x (L x d).

(L x d) x (d x L): L*L*2d flops (2 = 1 multiplication + 1 addition).
(L x L) x (L x d): L*d*2L flops.

So for each token we have (L*L*2d + L*d*2L) / L flops, which is 4dL, where d equals H*Q as defined in the equation. Adding the backward pass (about 2x the forward), we get 3 * 4dL = 12*H*Q*T per layer, i.e. 12*Layer*H*Q*T in total. Does this make sense?
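As a quick sanity check of that arithmetic, here is a minimal sketch (the sizes are illustrative, roughly GPT-2 small, and are my assumption rather than numbers from this thread):

```python
# Check: forward attention matmul flops per token come out to 4*d*L,
# and forward + backward (~2x forward) gives 12*H*Q*T per layer.
# Illustrative sizes (roughly GPT-2 small), not taken from the thread.
H, Q, T = 12, 64, 1024      # n_head, head size, sequence length (the L above)
d = H * Q                   # width of the attention matmuls

qk_flops = T * T * 2 * d    # (L x d) @ (d x L): 2*d multiply-adds per output element
av_flops = T * d * 2 * T    # (L x L) @ (L x d): 2*L multiply-adds per output element
fwd_per_token = (qk_flops + av_flops) // T

assert fwd_per_token == 4 * d * T            # forward: 4*d*L per token
assert 3 * fwd_per_token == 12 * H * Q * T   # + backward (2x forward): 12*H*Q*T per layer
print(fwd_per_token, 3 * fwd_per_token)
```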
In Scaling Laws for Neural Language Models,

C_forward = 2N + 2 * n_layer * n_ctx * d_attn

Since the backward pass is approximately twice the compute of the forward pass, the total compute should be

6N + 6 * n_layer * n_ctx * d_attn

while in https://github.com/karpathy/nanoGPT/blob/325be85d9be8c81b436728a420e85796c57dba7e/model.py#L296,

flops_per_token = 6*N + 12*L*H*Q*T

Should it be flops_per_token = 6*N + 6*L*H*Q*T?
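For concreteness, here is a small sketch comparing the two candidate expressions; the hyperparameters are illustrative GPT-2 124M-ish values of my own choosing, not from the issue:

```python
# Compare the two candidate flops_per_token expressions from the question.
# Hyperparameters are illustrative (roughly GPT-2 124M), my assumption only.
N = 124e6                        # parameter count
L, H, Q, T = 12, 12, 64, 1024    # n_layer, n_head, head size, sequence length

nanogpt      = 6 * N + 12 * L * H * Q * T   # as in model.py (PaLM-paper accounting)
scaling_laws = 6 * N +  6 * L * H * Q * T   # 3 * (2N + 2*n_layer*n_ctx*d_attn)

print(f"model.py:     {nanogpt:.3e} flops per token")
print(f"scaling laws: {scaling_laws:.3e} flops per token")
print(f"attention term ratio: {(nanogpt - 6 * N) / (scaling_laws - 6 * N):.1f}")  # -> 2.0
```

The only disagreement is the 12 vs 6 factor on the attention term, i.e. exactly the factor of 2 discussed above.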