This is not a bug in hk.LayerNorm or Haiku, but rather an unexpected result of using hk.LayerNorm in transformers.
When used in transformers, hk.LayerNorm introduces an information channel across token positions, which violates the autoregressive condition.
In the following, I have a single LayerNorm accepting a 2-D input of shape [batch, context]. In a transformer decoder, the autoregressive condition requires that any output position along the context axis not depend on any position to its right. I'm using an out_grads probe with 1s at a single output position (position 4 of the context) and checking where the non-zero gradients land in the input; a sketch of the probe is below.
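Roughly, the probe looks like the following minimal sketch (the exact code isn't reproduced here; names such as input_grad_probe and probe_position are just for illustration). It applies one hk.LayerNorm to a [batch, context] input and back-propagates an out_grads of 1s from a single output position via jax.vjp:

```python
import jax
import jax.numpy as jnp
import haiku as hk


def input_grad_probe(axis, probe_position=4, batch=2, context=8, seed=0):
    """Back-propagate 1s from one output position through hk.LayerNorm
    and return the gradients with respect to the input."""

    def fn(x):
        # No scale/offset params, so only the normalization itself is probed.
        return hk.LayerNorm(axis=axis, create_scale=False, create_offset=False)(x)

    model = hk.transform(fn)
    rng = jax.random.PRNGKey(seed)
    x = jax.random.normal(rng, (batch, context))
    params = model.init(rng, x)  # empty, since no params are created

    def forward(x):
        return model.apply(params, None, x)

    # out_grads: 1s at a single position of the output context, 0s elsewhere.
    out_grads = jnp.zeros((batch, context)).at[:, probe_position].set(1.0)
    _, vjp_fn = jax.vjp(forward, x)
    (in_grads,) = vjp_fn(out_grads)
    return in_grads
```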
With hk.LayerNorm(axis=(0,1), ...), all context positions have non-zero input gradients, which violates the autoregressive condition. Using hk.LayerNorm(axis=(0,), ...) "fixes" this (see the comparison below), although I'm not sure that's the right way to go about it.
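For reference, here is how the two axis settings compare with the sketch above; the comments are my reading of why the gradient patterns differ:

```python
# axis=(0, 1): mean/variance are computed over both batch and context, so the
# normalization couples all positions and every input position ends up with a
# non-zero gradient.
print(jnp.any(jnp.abs(input_grad_probe(axis=(0, 1))) > 0, axis=0))

# axis=(0,): statistics are computed per context position (over batch only),
# so non-zero input gradients stay confined to the probed position.
print(jnp.any(jnp.abs(input_grad_probe(axis=(0,))) > 0, axis=0))
```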
In any case, just wanted folks to be aware of this.