Closed geajack closed 3 years ago
I think 10K x 10K NTK in one go will OOM, but in general evaluating 100 batches of 1K x 1K NTK should be quite doable (maybe minutes, but not hours), but ofc it depends on exact input/output dimension and your hardware.
I'm not sure how good TF autodiff is, but we do have some tricks that might not be available elsewhere (at least not when naively using JAX's autodiff), so it may be worth a try.
For details see https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html, but in short
1) If you don't have batch-norm or any other sample-to-sample interactions in your NN, try setting vmap_axes=0. This will evaluate the Jacobian for each sample separately. A priori, autodiff may not know that samples are independent (and they aren't if, for example, you use BatchNorm), so it can incur a higher cost than we do.
2) Try the different implementation=1/2 settings, which may have substantially different performance for different tasks. implementation=2 uses implicit differentiation, which results in a different underlying computation than naively instantiating the Jacobian and contracting it (and in fact it may give slightly different results due to a different order of contractions).
3) Depending on what exactly you need the NTK for, you may want to check the trace_axes and diagonal_axes arguments, as well as consider evaluating an NTK with a single output logit (vs e.g. 10 for CIFAR10 classification). There are more details on this in readthedocs, and see also https://github.com/google/neural-tangents/issues/68
4) You can find some prior discussion and performance numbers on empirical NTK here, https://github.com/google/neural-tangents/issues/30#issuecomment-729342789, notably you can try running the colab https://colab.sandbox.google.com/gist/romanngg/66574cca1dc1a6a7781c14745aeb1141/empirical_ntk_speedup.ipynb to get a feeling for how fast or slow NT will be for your task.
Lmk if this helps or you have any other questions!
Also related: https://github.com/google/neural-tangents/issues/100
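To illustrate point 1, here is a rough plain-JAX sketch of why per-sample differentiation can be cheaper. It is not how Neural Tangents implements vmap_axes; the names init_params, f, ntk_per_sample and ntk_full_jacobian are made up for this example, and the network is a small one-hidden-layer ReLU model with a scalar output. Both functions should produce the same Gram matrix, but the second one differentiates the whole batch of outputs at once and so does not exploit the fact that output i depends only on input i.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in, width):
    # Hypothetical toy initialisation, only for this illustration.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "w2": jax.random.normal(k2, (width, 1)) / jnp.sqrt(width),
    }


def f(params, x):
    # x is a single input of shape (d_in,); the output is a scalar.
    h = jax.nn.relu(x @ params["w1"])
    return (h @ params["w2"])[0]


def flatten_per_sample(tree, n):
    # Turn a pytree whose leaves have shape (n, ...) into an (n, n_params) matrix.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.concatenate([leaf.reshape(n, -1) for leaf in leaves], axis=1)


def ntk_per_sample(params, xs):
    # One gradient per sample, vectorised over the batch axis with vmap.
    grads = jax.vmap(jax.grad(f), in_axes=(None, 0))(params, xs)
    flat = flatten_per_sample(grads, xs.shape[0])
    return flat @ flat.T


def ntk_full_jacobian(params, xs):
    # Jacobian of the whole batch of outputs; autodiff is not told that
    # different samples do not interact.
    batched_f = lambda p: jax.vmap(f, in_axes=(None, 0))(p, xs)
    jac = jax.jacobian(batched_f)(params)
    flat = flatten_per_sample(jac, xs.shape[0])
    return flat @ flat.T


key_params, key_data = jax.random.split(jax.random.PRNGKey(0))
params = init_params(key_params, d_in=8, width=16)
xs = jax.random.normal(key_data, (5, 8))

ntk_a = ntk_per_sample(params, xs)
ntk_b = ntk_full_jacobian(params, xs)
print(ntk_a.shape, jnp.allclose(ntk_a, ntk_b, rtol=1e-3, atol=1e-3))
```

With BatchNorm or any other cross-sample interaction, the per-sample version would no longer be valid, which is why the option has to be requested explicitly.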
Okay, it seems that neural_tangents is a significant improvement over my own code, possibly enough to solve my problem.
To answer the question that I asked initially, here is how to compute the Gram matrix (train-train kernel) of a finite-width network with this library:
from jax import random
from neural_tangents import stax
import neural_tangents
import numpy as np

n_samples = 500
input_dimension = 784

x_train = np.random.normal(0, 1, (n_samples, input_dimension))

init_fn, f, _ = stax.serial(
    stax.Dense(1000, parameterization="standard"),
    stax.Relu(),
    stax.Dense(1, parameterization="standard")
)

key = random.PRNGKey(1)
_, weights = init_fn(key, x_train.shape)

kernel = neural_tangents.empirical_kernel_fn(
    f,
    implementation=1
)

gram = kernel(x_train, None, "ntk", weights)
print(gram.shape)
Here is the implementation I'm using with Tensorflow:
import numpy as np
import keras
import tensorflow

def tangent_feature(keras_model, xs):
    # Returns an (n, n_params) matrix: one flattened weight-gradient per input.
    n = len(xs)
    x_variable = tensorflow.Variable(xs, dtype=tensorflow.float32)
    with tensorflow.GradientTape() as tape:
        output = keras_model(x_variable)
    # Jacobian of all n outputs with respect to every trainable weight.
    subgradients = tape.jacobian(output, keras_model.trainable_weights)
    flattened = [tensorflow.reshape(sg, [n, -1]) for sg in subgradients]
    gradients = tensorflow.concat(flattened, 1)
    return gradients.numpy()

def gram_matrix(keras_model, x):
    # Gram matrix = pairwise dot products of the tangent features.
    gradients = tangent_feature(keras_model, x)
    return gradients.dot(gradients.T)

model = keras.models.Sequential(
    [
        keras.layers.Dense(1000, activation="relu"),
        keras.layers.Dense(1, activation="linear")
    ]
)
model.compile()

n_samples = 500
input_dimension = 784

model.build(input_shape=(1, input_dimension))

x_train = np.random.normal(0, 1, (n_samples, input_dimension))
gram = gram_matrix(model, x_train)
print(gram.shape)
On inputs of dimension 784 (that of MNIST) and 500 training examples, my code ran for a full minute before I just killed it. The neural_tangents code runs in about 4 seconds for the same parameters. I verified separately that the two scripts compute the same Gram matrix to within a tolerable error (by creating a Keras model and then copying its weights over to an equivalent Stax model - see here for some example code).
I should be able to break the problem into batches and compute my 10k x 10k Gram matrix in about half an hour, taking advantage of the symmetry of the matrix.
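A minimal sketch of that batching idea, assuming a kernel function that takes two batches of inputs and returns the corresponding block of the Gram matrix. The helper blockwise_gram and the stand-in linear_kernel are hypothetical names for this example; only the blocks on or above the diagonal are computed, and the rest are filled in by transposition.

```python
import numpy as np


def blockwise_gram(kernel_fn, x, batch_size):
    # kernel_fn(x1, x2) is assumed to return an array of shape (len(x1), len(x2)).
    n = x.shape[0]
    gram = np.empty((n, n), dtype=np.float64)
    starts = range(0, n, batch_size)
    for i in starts:
        xi = x[i:i + batch_size]
        for j in starts:
            if j < i:
                continue  # this block is filled in by symmetry below
            xj = x[j:j + batch_size]
            block = np.asarray(kernel_fn(xi, xj))
            gram[i:i + len(xi), j:j + len(xj)] = block
            if j > i:
                gram[j:j + len(xj), i:i + len(xi)] = block.T
    return gram


def linear_kernel(x1, x2):
    # Stand-in kernel so the example runs without a neural network.
    return x1 @ x2.T


rng = np.random.default_rng(0)
x = rng.normal(size=(10, 4))
gram = blockwise_gram(linear_kernel, x, batch_size=3)
print(gram.shape, np.allclose(gram, x @ x.T))
```

With the empirical NTK from this thread, kernel_fn would presumably be something like lambda a, b: kernel(a, b, weights) for a single-output network, though that wiring is an assumption and not tested here. Keeping every batch the same size (padding or dropping the last one) also avoids extra recompilation if the kernel is jitted.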
Awesome! FYI, I believe you can further speed up your case by using vmap_axes and jit, something like
kernel = jax.jit(neural_tangents.empirical_ntk_fn(f, vmap_axes=0))
(when you jit, the first call may take a bit longer to compile, but subsequent calls should be much faster)
I think you forgot the static_argnums in that JIT - the correct call is
kernel = jax.jit(neural_tangents.empirical_kernel_fn(f, vmap_axes=0), static_argnums=[2])
Otherwise it complains that argument 2 of the kernel function you're trying to jit is a string.
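For context, a small plain-JAX sketch of why a string argument needs static_argnums. The function toy_kernel is a made-up stand-in with a string selector in position 2, loosely mimicking the get argument; it is not part of neural_tangents.

```python
import jax
import jax.numpy as jnp


def toy_kernel(x1, x2, get):
    # `get` is an ordinary Python string that selects which quantity to return.
    if get == "linear":
        return x1 @ x2.T
    if get == "squared":
        return (x1 @ x2.T) ** 2
    raise ValueError(f"unknown selector: {get}")


# Strings are not valid array inputs for a jitted function, so position 2 is
# marked as static: JAX treats it as a compile-time constant and compiles one
# version of the function per distinct value.
toy_kernel_jit = jax.jit(toy_kernel, static_argnums=(2,))

x = jnp.ones((3, 2))
out_linear = toy_kernel_jit(x, x, "linear")
out_squared = toy_kernel_jit(x, x, "squared")
print(out_linear.shape, out_squared.shape)
```

Static arguments must be hashable, which strings are; each new value of a static argument triggers a fresh compilation.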
The vmap_axes=0 argument makes a big difference - does it reduce memory footprint as well? Without vmap_axes=0 I can do five 500x500 grams in 40 seconds. With vmap_axes=0 I can do the same number of 1000x1000 grams in 1 minute, about 2.6 times more work-per-second. Jitting does not seem to make a big difference, actually it made the code take slightly longer in some cases (when computing kernels in a loop), so I'll probably leave it. Thanks for the tip!
Ah good catch! I accidentally replaced empirical_kernel_fn with empirical_ntk_fn, and the latter does not take a get argument and computes just the NTK, so it doesn't need static_argnums.
Re jit, I'm a bit surprised that it doesn't help or makes it slower, are you by any chance jitting inside the loop? Ideally you want to jit once, kernel = jax.jit(neural_tangents.empirical_ntk_fn(f, vmap_axes=0)), and then call kernel on different inputs. A common reason for jitting to make things slower is when it recompiles the function repeatedly. This can happen if either you redefine the function yourself inside the loop (e.g. by calling jit(kernel(x_train, None, weights)) instead of defining kernel = jit(kernel) once outside the loop), or if the shapes of inputs to your function change (it recompiles for every new shape).
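A small plain-JAX sketch of those two recompilation triggers. The function gram_fn and the counter are made up for this example; the counter relies on the fact that Python side effects inside a jitted function only run while JAX traces it, so it counts traces rather than calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def gram_fn(x):
    global trace_count
    trace_count += 1  # executed only when JAX traces the function
    return x @ x.T


# Jit once, outside the loop, and reuse the jitted function.
gram_jit = jax.jit(gram_fn)
for _ in range(3):
    gram_jit(jnp.ones((4, 2)))
traces_same_shape = trace_count

# A new input shape forces a new trace and compilation.
gram_jit(jnp.ones((5, 2)))
traces_after_new_shape = trace_count

# Re-wrapping a fresh function object on every iteration defeats the cache,
# so the function is traced again on each pass through the loop.
for _ in range(3):
    jax.jit(lambda x: gram_fn(x))(jnp.ones((4, 2)))
traces_after_rewrap = trace_count

print(traces_same_shape, traces_after_new_shape, traces_after_rewrap)
```

When timing jitted code it is also worth calling block_until_ready() on the result, since JAX dispatches work asynchronously.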
Here's my code - for me this takes about 42 seconds with JIT and about 39 seconds without JIT:
from jax import random
import jax
from neural_tangents import stax
import neural_tangents
import numpy as np

n_samples = 1000
input_dimension = 784

init_fn, f, _ = stax.serial(
    stax.Dense(1000, parameterization="standard"),
    stax.Relu(),
    stax.Dense(1, parameterization="standard")
)

key = random.PRNGKey(1)
_, weights = init_fn(key, (1, input_dimension))

kernel = neural_tangents.empirical_ntk_fn(
    f,
    implementation=1,
    vmap_axes=0
)

for i in range(3):
    x_train = np.random.normal(0, 1, (n_samples, input_dimension))
    gram = kernel(x_train, None, weights)
    print(gram[0,0])
I tried timing the individual iterations of the loop too - they each take about the same amount of time with JIT.
Hm, I wonder if this is just due to only 3 iterations in the loop, i.e. you spend some time to compile the function, but don't use it enough to make up for it?
I ran this in colab https://colab.research.google.com/gist/romanngg/7f7ab03b0b7e0b9b8889bcd53d2552e6/jit_timing.ipynb
and if I loop over 30 samples, timing seems to be about 26 seconds for JIT and 30 seconds for non-JIT. I guess it's fair that JIT may not help much in a simple one-layer network, but I'd still recommend using it, as especially for deeper networks I believe it will save you a lot of RAM compared to the non-jitted version.
And regarding vmap_axes - yes, for implementation=1 it should reduce the memory requirement by a factor on the order of batch_size. For implementation=2, the interaction is less understood; I believe it does not reduce memory, and it has a small effect on time.
Suppose I have a regular old neural network with its weights set to some values. Then the NTK k(x, y) is well-defined as the dot product of df/dw at each input, that is, the dot product of the gradients of the network's output with respect to the weights. In some of my own code I'm computing this kernel using Keras with tensorflow's automatic differentiation capabilities, but it chokes on even moderate-sized models (trying to compute the train-train kernel with 1000 neurons and 10k training inputs).
I've been looking at neural_tangents.utils.empirical, but I thought I'd ask - does this codebase contain some magical code that will allow me to compute my Gram matrix in a reasonable amount of time and memory?