ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Split layernorm into two layers #299

Closed by danbraunai-apollo 6 months ago

danbraunai-apollo commented 6 months ago

We can split layernorm into two sequential nn.Modules:

  1. The first will calculate the variance of the residual stream plus epsilon and return it, along with the raw residual stream and, in the case of Pythia's DualLayerNorm, the attn_resid. (UPDATE: the original plan was to return the full denominator, i.e. the standard deviation including epsilon. DO NOT TAKE THE SQUARE ROOT, return just VAR + EPSILON.)
  2. The second will calculate the layernorm, taking the output of the first (var + epsilon rather than the std, per the update) as input.

The idea is that a layernorm layer induces high connectivity, but all of this connectivity is the result of the std, rather than the mean. So we should be able to isolate this connectivity to one layer.
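
A minimal PyTorch sketch of the two-module split described above. The class names `LayerNormVar` and `LayerNormOut` are hypothetical and are not taken from the rib codebase, and the Pythia DualLayerNorm case (passing `attn_resid` through as well) is left out for brevity. It assumes the second module applies the square root itself and owns the learned scale and bias.

```python
import torch
from torch import nn


class LayerNormVar(nn.Module):
    """Stage 1 (hypothetical name): return var + eps alongside the raw residual stream."""

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, resid: torch.Tensor):
        # Biased variance over the feature dimension, as in torch.nn.LayerNorm.
        centered = resid - resid.mean(dim=-1, keepdim=True)
        var = centered.pow(2).mean(dim=-1, keepdim=True)
        # No square root here: only var + eps is passed on.
        return var + self.eps, resid


class LayerNormOut(nn.Module):
    """Stage 2 (hypothetical name): finish the layernorm given var + eps."""

    def __init__(self, d_model: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, var_eps: torch.Tensor, resid: torch.Tensor) -> torch.Tensor:
        # Mean-centering happens here; the variance arrives as an input.
        centered = resid - resid.mean(dim=-1, keepdim=True)
        return centered / var_eps.sqrt() * self.weight + self.bias


# Usage: the two stages chained together should agree with nn.LayerNorm.
x = torch.randn(2, 4, 16)
var_eps, resid = LayerNormVar()(x)
out = LayerNormOut(16)(var_eps, resid)
reference = nn.LayerNorm(16)(x)
print(torch.allclose(out, reference, atol=1e-5))
```

In this sketch, the only quantity that mixes all residual-stream dimensions nonlinearly is `var_eps`, and it is produced entirely by the first module.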