google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

kernel function decorated with batch doesn't return the right shape #87

Closed qixuanf closed 3 years ago

qixuanf commented 3 years ago

As the example below shows, I expected that decorating a kernel function with batch would not change the shape of its return value, i.e. both res_5 and res_10 below should have shape (50, 50, 10). Instead, the returned shape depends on batch_size.

import numpy as np

import haiku as hk
import jax
import neural_tangents as nt

def mlp(x):
    net = hk.Sequential(
        [
            hk.Flatten(),
            hk.Linear(300),
            jax.nn.relu,
            hk.Linear(100),
            jax.nn.relu,
            hk.Linear(10),
        ]
    )
    return net(x)

key = jax.random.PRNGKey(42)
net_transformed = hk.without_apply_rng(hk.transform(mlp))
params = net_transformed.init(key, np.zeros((1, 1, 28, 28)))

test_x = np.random.rand(50, 1, 28, 28)
kernel_fn = nt.empirical_ntk_fn(f=net_transformed.apply, trace_axes=(), diagonal_axes=(-1,),
                                vmap_axes=0, implementation=2)
batched_fn = nt.batch(kernel_fn, device_count=-1, batch_size=10)
res_10 = batched_fn(test_x, test_x, params)
print(res_10.shape)  # observed: (25, 10, 10, 10); expected: (50, 50, 10)

kernel_fn = nt.empirical_ntk_fn(f=net_transformed.apply, trace_axes=(), diagonal_axes=(-1,),
                                vmap_axes=0, implementation=2)
batched_fn = nt.batch(kernel_fn, device_count=-1, batch_size=5)
res_5 = batched_fn(test_x, test_x, params)
print(res_5.shape)  # observed: (100, 5, 5, 10); expected: (50, 50, 10)

# Even when flattened, the two results do not match elementwise.
assert not np.allclose(res_10.flatten(), res_5.flatten())

I am using neural_tangents version 0.3.5.
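For reference, the property the report expects from batching can be sketched without the library: evaluating a kernel tile by tile and stitching the tiles back together should give the same array as one unbatched call, whatever the batch size. The code below is a minimal NumPy sketch under stated assumptions. `toy_kernel` and `batched_kernel` are hypothetical stand-ins written for illustration, not the neural_tangents implementation.

```python
import numpy as np


def toy_kernel(x1, x2):
    # Hypothetical stand-in for a kernel with one diagonal output axis.
    # x1: (n1, out), x2: (n2, out) -> result: (n1, n2, out).
    return np.einsum('ia,ja->ija', x1, x2)


def batched_kernel(kernel, x1, x2, batch_size):
    # Evaluate the kernel on (batch_size x batch_size) tiles of the inputs,
    # then concatenate the tiles so the result has the unbatched shape.
    n1, n2 = x1.shape[0], x2.shape[0]
    rows = []
    for i in range(0, n1, batch_size):
        row = [kernel(x1[i:i + batch_size], x2[j:j + batch_size])
               for j in range(0, n2, batch_size)]
        rows.append(np.concatenate(row, axis=1))  # stitch along the x2 axis
    return np.concatenate(rows, axis=0)  # stitch along the x1 axis


rng = np.random.default_rng(0)
x = rng.normal(size=(50, 10))

full = toy_kernel(x, x)
res_10 = batched_kernel(toy_kernel, x, x, batch_size=10)
res_5 = batched_kernel(toy_kernel, x, x, batch_size=5)

print(full.shape, res_10.shape, res_5.shape)
```

In this sketch all three arrays have shape (50, 50, 10) and agree elementwise, which is the behaviour the report expects from nt.batch.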

sschoenholz commented 3 years ago

Hi @qixuanf, thanks for the message. You've definitely found a bug! I think the issue is that we introduced diagonal_axes into the empirical kernel but didn't fully build out support for it in the batching code. For now, can you try setting diagonal_axes=()? I'll try to submit a fix soon, but it may take slightly longer than normal since NeurIPS is going on at the moment.
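If only the diagonal over the output axis is needed, the workaround can be followed by extracting that diagonal from the full kernel. The following is a minimal sketch, assuming that with diagonal_axes=() and trace_axes=() the kernel for this network has layout (n1, n2, out, out); that layout is an assumption and should be checked against the library documentation. The helper `diagonal_over_outputs` is hypothetical, and a random array stands in for the real kernel.

```python
import numpy as np


def diagonal_over_outputs(full_kernel):
    # Assumed layout: (n1, n2, out, out).
    # np.diagonal removes the two selected axes and appends the diagonal
    # as the last axis, giving shape (n1, n2, out).
    return np.diagonal(full_kernel, axis1=-2, axis2=-1)


rng = np.random.default_rng(0)
# Stand-in for the array a kernel function with diagonal_axes=() would return.
full_kernel = rng.normal(size=(50, 50, 10, 10))

diag = diagonal_over_outputs(full_kernel)
print(diag.shape)  # (50, 50, 10)
```

Each entry diag[i, j, a] equals full_kernel[i, j, a, a], so the result has the shape originally expected from diagonal_axes=(-1,), at the cost of computing the off-diagonal entries first.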

qixuanf commented 3 years ago

Thanks! Sure, I will not use this argument for the moment.

dvtailor commented 3 years ago

Is there an update on this? Setting diagonal_axes=() is not really a satisfactory solution: in many settings it increases evaluation time considerably, and it is unnecessary when the off-diagonal entries are not needed.
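The cost difference can be illustrated with a toy calculation. An empirical NTK is a contraction of two Jacobians over the parameter axis; keeping both output axes produces `out` times as many entries as keeping only their diagonal. The sketch below uses random arrays as stand-in Jacobians with made-up sizes, and it is not how neural_tangents computes the kernel internally.

```python
import numpy as np

rng = np.random.default_rng(0)
n, out, n_params = 6, 10, 40  # illustrative sizes, not taken from the issue

# Stand-in Jacobians of the network outputs with respect to the parameters.
j1 = rng.normal(size=(n, out, n_params))
j2 = rng.normal(size=(n, out, n_params))

# Full kernel over both output axes: shape (n, n, out, out).
full = np.einsum('iap,jbp->ijab', j1, j2)

# Diagonal over the output axis only: shape (n, n, out).
diag_only = np.einsum('iap,jap->ija', j1, j2)

# The full kernel holds `out` times as many entries as the diagonal version.
print(full.size // diag_only.size)  # 10
```

The diagonal of `full` over its last two axes matches `diag_only`, so the extra entries are pure overhead when only the diagonal is wanted.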

sschoenholz commented 3 years ago

Thanks for pinging this; it had slipped my mind. As of commit fd1611660c87edcb0c2e50403f691b60d2cc252b, batching should work correctly with diagonal_axes and trace_axes for empirical kernels. If something still seems off, please don't hesitate to let us know!

romanngg commented 3 years ago

I will close this as fixed. Please create a new issue if something is still not working!