google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Warning: hk.LayerNorm when used in transformer decoder causes violation of autoregressive property #700

Closed hrbigelow closed 1 year ago

hrbigelow commented 1 year ago

This is not a bug in hk.LayerNorm or Haiku, but rather an unexpected result of using hk.LayerNorm in transformers.

When used in a transformer with normalization axes that include the token (context) axis, hk.LayerNorm introduces an information channel across token positions, which violates the autoregressive condition.

In the following, I have a single LayerNorm accepting a 2-D input of shape [batch, context]. In a transformer decoder, the autoregressive condition requires that no output position attend to any position to its right. I'm using an out_grads cotangent probe with 1s at a single output position (position 4) and zeros elsewhere, and checking where the non-zero gradients appear in the input.

If using hk.LayerNorm(axis=(0,1), ...), all positions of the context have non-zero input gradients, which would violate the autoregressive condition. Using hk.LayerNorm(axis=(0,), ...) "fixes" this, although I'm not sure that's the right way to go about it (see the feature-axis sketch after the repro output below).

In any case, just wanted folks to be aware of this.

import haiku as hk
import jax
import jax.numpy as jnp

batch, context = 3, 7
rng_key = jax.random.PRNGKey(42)
input = jax.random.normal(rng_key, (batch, context))
# Cotangent probe: 1s at output position 4, zeros elsewhere
out_grads = jnp.zeros((batch, context))
out_grads = out_grads.at[:,4].set(1.0)

def make_grads(axis, input, out_grads):
    # Build a LayerNorm with the given normalization axes and print the
    # input gradient produced by pulling out_grads back through it.
    def wrap_fn(*args):
        return hk.LayerNorm(axis=axis, create_scale=True, create_offset=True)(*args)

    rng_key = jax.random.PRNGKey(42)
    layer = hk.transform(wrap_fn)
    params = layer.init(rng_key, input)
    # Pull the probe back through apply to get gradients w.r.t. params, rng, and input
    primal, vjp_fn = jax.vjp(layer.apply, params, rng_key, input)
    param_grad, rng_grad, input_grad = vjp_fn(out_grads)
    jnp.set_printoptions(precision=2, linewidth=150)
    print(input_grad)

print('input gradient for layer norm with axis=(0,)')
make_grads((0,), input, out_grads)

print('\n\ninput gradient for layer norm with axis=(0,1)')
make_grads((0,1), input, out_grads)

"""
input gradient for layer norm with axis=(0,)
[[ 0.00e+00  0.00e+00  0.00e+00  0.00e+00  9.83e-08  0.00e+00  0.00e+00]
 [ 0.00e+00  0.00e+00  0.00e+00  0.00e+00 -7.06e-08  0.00e+00  0.00e+00]
 [ 0.00e+00  0.00e+00  0.00e+00  0.00e+00 -2.77e-08  0.00e+00  0.00e+00]]

input gradient for layer norm with axis=(0,1)
[[-0.09 -0.21 -0.17 -0.17  1.03 -0.21 -0.17]
 [-0.15 -0.24 -0.15 -0.19  0.98 -0.11 -0.17]
 [-0.17 -0.11 -0.19 -0.14  0.99 -0.2  -0.15]]
"""
hrbigelow commented 1 year ago

I realize this may be common knowledge, so I'll close this issue.