We can split layernorm into two sequential nn.Modules:
The first will calculate the variance of the residual stream plus epsilon (the quantity under the square root in the layernorm denominator; the square root itself is NOT taken here) and return it, along with the raw residual stream and, in the case of pythia DualLayerNorm, the attn_resid.
The second will take the first module's output (var + epsilon, not yet square-rooted) as input and calculate the layernorm from it.
The idea is that a layernorm layer induces high connectivity, but all of this connectivity is the result of the std, rather than the mean. So we should be able to isolate this connectivity to one layer.
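A minimal sketch of the two-module split, assuming PyTorch and a plain single-input layernorm. The class names `LayerNormVar` and `LayerNormApply` and the `d_model` parameter are hypothetical, chosen only for illustration. The pythia DualLayerNorm case, where attn_resid is passed through as well, is left out because its handling is not specified here.

```python
import torch
import torch.nn as nn


class LayerNormVar(nn.Module):
    """Stage 1 (hypothetical name): return (var + eps, raw residual stream)."""

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, resid: torch.Tensor):
        # Biased variance over the last dimension, as in nn.LayerNorm.
        centered = resid - resid.mean(dim=-1, keepdim=True)
        var_eps = centered.pow(2).mean(dim=-1, keepdim=True) + self.eps
        # No square root here: only var + eps is passed on.
        return var_eps, resid


class LayerNormApply(nn.Module):
    """Stage 2 (hypothetical name): finish the layernorm given var + eps."""

    def __init__(self, d_model: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, inputs):
        var_eps, resid = inputs
        # Mean-centering happens here, separately from the variance stage.
        centered = resid - resid.mean(dim=-1, keepdim=True)
        # The square root is only taken in this second stage.
        return centered * torch.rsqrt(var_eps) * self.weight + self.bias


# nn.Sequential hands the tuple from stage 1 to stage 2 as a single argument.
split_ln = nn.Sequential(LayerNormVar(eps=1e-5), LayerNormApply(d_model=8))
```

With default weights and the same epsilon, `split_ln(x)` should agree with `nn.LayerNorm(8, eps=1e-5)(x)` up to floating-point tolerance, which gives a quick check that the split has not changed the computation.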