google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Significant difference in empirical NTK for batched and non-batched versions #122

Closed gortizji closed 2 years ago

gortizji commented 3 years ago

After updating my environment to work with more recent versions of JAX and Flax, I have noticed that the empirical NTK Gram matrices computed using nt.batch applied to nt.empirical_kernel_fn differ significantly depending on the batch size.

The code to reproduce this error is:

import jax.numpy as jnp
import flax.linen as nn
import functools
import jax
import neural_tangents as nt

class LeNet(nn.Module):
    kernel_size = (5, 5)
    strides = (2, 2)
    window_shape = (2, 2)
    num_classes = 1
    features = (6, 16, 120, 84, 1)
    pooling = True
    padding = "SAME"

    @nn.compact
    def __call__(self, x):
        conv = functools.partial(nn.Conv, padding=self.padding)
        x = conv(features=self.features[0], kernel_size=tuple(self.kernel_size))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=tuple(self.window_shape), strides=tuple(self.strides))

        x = conv(features=self.features[1], kernel_size=tuple(self.kernel_size))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=tuple(self.window_shape), strides=tuple(self.strides))

        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.features[2])(x)
        x = nn.relu(x)
        x = nn.Dense(self.features[3])(x)
        x = nn.relu(x)

        x = nn.Dense(self.num_classes)(x)
        return x

model_key, data_key = jax.random.split(jax.random.PRNGKey(42))
data = jax.random.normal(data_key, [500, 32, 32, 3])
model = LeNet()
init_params = model.init(model_key, jnp.zeros([1, 32, 32, 3]))

# Compute NTK Gram matrix using the fully parallel version
kernel_full_fn = nt.batch(
    nt.empirical_kernel_fn(model.apply, vmap_axes=0, implementation=2, trace_axes=()),
    batch_size=500,
    device_count=-1,
    store_on_device=False,
)
K_full = kernel_full_fn(data, None, "ntk", init_params)

# Compute NTK Gram matrix using minibatches
kernel_batch_fn = nt.batch(
    nt.empirical_kernel_fn(model.apply, vmap_axes=0, implementation=2, trace_axes=()),
    batch_size=100,
    device_count=-1,
    store_on_device=False,
)
K_batch = kernel_batch_fn(data, None, "ntk", init_params)

# Compute difference between two matrices. It should technically be 0.
print("Average error per entry:", jnp.linalg.norm(K_full - K_batch) / K_full.size)

Surprisingly, if I run this with my old environment I get an average error of the order of 1e-8, while with the new environment the error is of the order of 1e-1. Also, this error remains exactly the same as long as batch_size<data.shape[0].
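For context, here is a minimal NumPy sketch of why the two results are expected to agree. It is not the neural-tangents implementation: a random float32 matrix stands in for the per-example Jacobians, and `gram_full`, `gram_batched`, and `compare` are illustrative names. Assembling the Gram matrix tile by tile should reproduce the single-call result up to floating-point rounding. The sketch also reports the maximum and mean absolute difference, because the Frobenius norm divided by the number of entries (the quantity printed above) is only a lower bound on the mean absolute error per entry.

```python
import numpy as np


def gram_full(jac):
    """Gram matrix J @ J.T computed in a single call."""
    return jac @ jac.T


def gram_batched(jac, batch_size):
    """Same Gram matrix, assembled tile by tile from row blocks of jac."""
    n = jac.shape[0]
    out = np.empty((n, n), dtype=jac.dtype)
    for i in range(0, n, batch_size):
        rows = jac[i:i + batch_size]
        for j in range(0, n, batch_size):
            cols = jac[j:j + batch_size]
            out[i:i + batch_size, j:j + batch_size] = rows @ cols.T
    return out


def compare(a, b):
    """Difference metrics between two matrices, accumulated in float64."""
    a64 = a.astype(np.float64)
    diff = np.abs(a64 - b.astype(np.float64))
    return {
        "max_abs": float(diff.max()),
        "mean_abs": float(diff.mean()),
        "rel_fro": float(np.linalg.norm(diff) / np.linalg.norm(a64)),
    }


# Hypothetical stand-in data: 50 examples, 200 parameters.
rng = np.random.default_rng(0)
jac = rng.standard_normal((50, 200)).astype(np.float32)

report = compare(gram_full(jac), gram_batched(jac, batch_size=10))
print(report)
```

Under these assumptions, any remaining difference comes from rounding and summation order alone, so a relative error many orders of magnitude above float32 precision points to a problem outside the batching logic itself.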

My old environment consisted of:

python=3.7.4
jax=0.2.8
jaxlib=0.1.57+cuda102
flax=0.3.0
neural-tangents=0.3.7

and my new environment has:

python=3.7.4
jax=0.2.19
jaxlib=0.1.70+cuda102
flax=0.3.4
neural-tangents=0.3.7
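
A version list like the ones above can be generated rather than typed by hand. The following is a minimal sketch using the standard-library `importlib.metadata` module, which requires Python 3.8 or newer (so it would not run on the Python 3.7.4 interpreter listed here). The helper name `installed_versions` is illustrative, and the sketch does not report the CUDA version.

```python
from importlib import metadata  # standard library in Python 3.8+


def installed_versions(packages):
    """Map each distribution name to its installed version string."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is absent from the current environment.
            versions[name] = "not installed"
    return versions


report = installed_versions(["jax", "jaxlib", "flax", "neural-tangents"])
for name, version in report.items():
    print(f"{name}={version}")
```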

sschoenholz commented 3 years ago

Thanks for reaching out! Unfortunately I wasn't able to reproduce the issue you're having. Here's a colab notebook (note: I had to make the model slightly smaller to fit in public colab GPU memory).

All of the versions seem similar to your current environment. However, I notice your GPU drivers seem a little out of date (CUDA version 10.2 instead of 11.2). Is it possible that this is the issue?

romanngg commented 3 years ago

Btw, also consider trying implementation=1. We often find it faster than 2 for convolutions, especially if you only have one output logit (2 scales better with the number of logits, but this shouldn't matter if you only have one).

gortizji commented 3 years ago

Thanks for the quick answers! Curiously, it seems that using implementation=1 does make things more stable, i.e., an error of around 1e-8. Still, for small batches of around 10 the error climbs up to 1e-5, but it is definitely not 0.1.

On the other hand, I also cannot reproduce this strange behaviour on Google Colab. I will try to update my CUDA version and get back to you.

gortizji commented 3 years ago

After updating to CUDA-11.4, I can confirm that the issue indeed only happens on the old CUDA version. With this new version, both implementation=1 and implementation=2 yield an error of the order of 1e-8, regardless of the batch_size.

In fact, it seems that support for CUDA-10.2 will be phased out in the next release of JAX. Even if this happens soon, I would still recommend directly specifying CUDA-11.x as a dependency of neural-tangents. I do not know what the root cause of that very strange behaviour was, but I am worried it might silently break other functionality of the library when it is used with the wrong CUDA version.

In any case, thank you very much for all your help! You were really helpful.

romanngg commented 2 years ago

Thanks a lot for figuring this out! Just pushed a release (https://github.com/google/neural-tangents/releases/tag/v0.3.9) bumping up our minimum JAX version to 0.2.25, which itself should only work with CUDA-11 and higher, so hopefully this should be fixed! Please feel free to re-open if the issue remains.