google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Transformer Block Implementation gives NotImplementedError #203

Open issue, opened by esnvidia 4 months ago


Hi,

I implemented a basic transformer block with residual connections and am getting the following error:

NotImplementedError: `FanInSum` is only implemented for the case where all input layers guaranteed to be mean-zero Gaussian, i.e. having all `is_gaussian` set to `True`, got [True, False].

It appears to be caused by the stax.Identity() skip branch.

Here is the implementation:

from jax import jit, random
import jax.numpy as jnp
from neural_tangents import stax

def FeedForwardNetwork(hidden_dim, output_dim):
    return stax.serial(stax.Dense(hidden_dim), stax.Relu(),
                       stax.Dense(output_dim))

AttnBlock = stax.serial(stax.FanOut(2),
                        stax.parallel(
                            stax.serial(
                                stax.GlobalSelfAttention(
                                    n_chan_out=1,
                                    n_chan_key=1,
                                    n_chan_val=1,
                                    pos_emb_type='SUM',
                                    W_pos_emb_std=1,
                                    # pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
                                    attention_mechanism='SOFTMAX',
                                    linear_scaling=True,
                                    n_heads=1)
                            ),
                            stax.Identity()
                        ),
                        stax.FanInSum()
                        )

def TransformerBlock(ff_dim, d_model):
    return stax.serial(AttnBlock,
                       stax.LayerNorm(),
                       stax.FanOut(2),
                       stax.parallel(
                           FeedForwardNetwork(ff_dim, d_model),
                           stax.Identity()
                       ),
                       stax.FanInSum(),
                       stax.LayerNorm()
                       )

def Transformer(num_layers, ff_dim, d_model):
    layers = []
    for _ in range(num_layers):
        layers.append(TransformerBlock(ff_dim, d_model))
    layers.append(stax.Dense(out_dim=1))
    return stax.serial(*layers)

num_layers = 1
ff_dim = 128
d_model = 256

init_fn, apply_fn, kernel_fn = Transformer(num_layers, ff_dim, d_model)

I then use the example data from the cookbook:

key = random.PRNGKey(10)
train_points = 5
test_points = 50
noise_scale = 1e-1

target_fn = lambda x: jnp.sin(x)

key, x_key, y_key = random.split(key, 3)

train_xs = random.uniform(x_key, (train_points, 1), minval=-jnp.pi, maxval=jnp.pi)

train_ys = target_fn(train_xs)
train_ys += noise_scale * random.normal(y_key, (train_points, 1))
train = (train_xs, train_ys)

test_xs = jnp.linspace(-jnp.pi, jnp.pi, test_points)
test_xs = jnp.reshape(test_xs, (test_points, 1))

test_ys = target_fn(test_xs)
test = (test_xs, test_ys)

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnames='get')

kernel = kernel_fn(test_xs, test_xs, 'nngp')
std_dev = jnp.sqrt(jnp.diag(kernel))

The error is raised by the kernel_fn call above.

What is odd is that the very similar ResBlock from the cookbook works:

ResBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(
            stax.Erf(),
            stax.Dense(512, W_std=1.1, b_std=0),
        ),
        stax.Identity()
    ),
    stax.FanInSum()
)

It also appears that, with linear_scaling=True, the attention layer sets is_gaussian=True, per this line: https://github.com/google/neural-tangents/blob/c17e770bb74f1771da7be4a69fabfa68b6078960/neural_tangents/_src/stax/linear.py#L2464C14-L2468C39
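To make the flag bookkeeping concrete, here is a toy model of how is_gaussian seems to propagate. It is my own simplification, not the library's code, and it rests on assumptions: raw inputs enter with is_gaussian=False, Dense layers output True, Erf/Relu output False, and (per the line linked above) GlobalSelfAttention with linear_scaling=True outputs True. The helper names (dense, residual, etc.) are hypothetical. Under these assumptions the AttnBlock skip branch carries the raw input's False, while the cookbook ResBlock sees True on both branches, assuming it is preceded by a stax.Dense layer as in the cookbook's ResNet example.

```python
# Toy model of `is_gaussian` bookkeeping (hypothetical simplification,
# NOT neural-tangents' implementation). Each "layer" maps the incoming
# flag to the outgoing flag.

RAW_INPUT = False  # assumption: raw inputs are not flagged as Gaussian


def dense(flag: bool) -> bool:
    # Affine layer with Gaussian weights: output treated as Gaussian.
    return True


def nonlinearity(flag: bool) -> bool:
    # Erf / Relu: output is no longer Gaussian.
    return False


def identity(flag: bool) -> bool:
    # Skip connection: passes the flag through unchanged.
    return flag


def attention_linear_scaling(flag: bool) -> bool:
    # Assumption based on the linked line: linear_scaling=True yields True.
    return True


def fan_in_sum(flags: list) -> bool:
    # Mirrors the check in the error message: every branch must be Gaussian.
    if not all(flags):
        raise NotImplementedError(f"FanInSum needs all is_gaussian=True, got {flags}")
    return True


def residual(branch, flag: bool) -> bool:
    # FanOut(2) -> parallel(branch, Identity) -> FanInSum
    return fan_in_sum([branch(flag), identity(flag)])


# AttnBlock applied directly to raw inputs: the skip branch carries False.
try:
    residual(attention_linear_scaling, RAW_INPUT)
except NotImplementedError as err:
    print(err)  # FanInSum needs all is_gaussian=True, got [True, False]

# Cookbook-style ResBlock placed after a Dense layer: both branches are True.
cookbook_flag = residual(lambda f: dense(nonlinearity(f)), dense(RAW_INPUT))

# Possible workaround (assumption): embed the inputs with a Dense layer first.
embedded_flag = residual(attention_linear_scaling, dense(RAW_INPUT))
print(cookbook_flag, embedded_flag)  # True True
```

If this reading is right, one thing to try would be a stax.Dense(d_model) embedding layer in front of the first TransformerBlock, though I have not verified that against the library itself.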

Eventually I would like to add causal masking as well. Any pointers there would be appreciated, since it is not clear to me how to apply an upper-triangular mask in the infinite-width case over the sequence-length axis.
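For reference, this is what I mean by the upper-triangular mask in the ordinary finite-width setting: a minimal NumPy sketch of single-head softmax attention in which scores above the diagonal (future positions) are set to -inf before the softmax. The function name causal_softmax_attention is hypothetical. It does not use neural-tangents and says nothing about how, or whether, GlobalSelfAttention can express this mask in the infinite-width kernel.

```python
import numpy as np


def causal_softmax_attention(q, k, v):
    """Single-head finite-width attention with a causal mask (illustrative sketch).

    q, k, v: arrays of shape (seq_len, d). Position i may only attend to j <= i.
    Returns the attended values and the attention weights.
    """
    seq_len, d = q.shape
    scores = q @ k.T / np.sqrt(d)  # (seq_len, seq_len)
    # Strictly upper-triangular entries (j > i) are the "future" positions.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Subtract the row max for numerical stability; the diagonal keeps it finite.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)  # exp(-inf) == 0, so future positions get no weight
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))  # 4 positions, 8 channels
out, weights = causal_softmax_attention(x, x, x)
print(np.allclose(np.triu(weights, k=1), 0.0))  # True: no weight on future positions
```

Because the first position can only attend to itself, its output equals its own value vector; each later position averages over itself and its predecessors only.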