google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

How to calculate empirical NTK of model being used in a classification task #68

Open uditsaxena opened 4 years ago

uditsaxena commented 4 years ago

For a model used for classification with k classes and n datapoints, the NTK should be of size nk x nk. How would we get that with neural-tangents?

Currently, I'm able to get an n x n matrix.

romanngg commented 4 years ago

Set trace_axes=() (see more details on this argument in https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html - happy to elaborate if needed!)
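
For example, a minimal sketch (the toy stax model, sizes, and variable names below are placeholders of mine, just to show the shapes):

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# Toy classifier: n datapoints, d features, k logits (sizes are illustrative).
n, d, k = 64, 32, 10
init_fn, apply_fn, _ = stax.serial(stax.Dense(128), stax.Relu(), stax.Dense(k))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (n, d))
_, params = init_fn(key, x.shape)

# trace_axes=() keeps the logit axes instead of tracing them out,
# so the empirical NTK has shape (n, n, k, k) rather than (n, n).
ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=())
ntk = ntk_fn(x, None, params)
print(ntk.shape)  # (64, 64, 10, 10)
```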

uditsaxena commented 4 years ago

Hey @romanngg - thanks for the quick response. That helped!

I tried to play with the trace_axes option. For n = 64 and k = 10, setting trace_axes to 0 or to 1 doesn't give me the shape I expect.

Shouldn't I be expecting something along the lines of 640 x 640? How would I extrapolate from this to what I need?

I'm sure I'm getting something wrong here. I would appreciate it if you could elaborate. Thanks!

romanngg commented 4 years ago

To clarify - you need to set trace_axes=() (empty tuple, not 0 or 1).

In general, if your outputs f1 and f2 have shapes (N1_0, N1_1, N1_2, ..., N1_K) and (N2_0, N2_1, N2_2, ..., N2_K),

then the output kernel will have shape (N1_0, N2_0, N1_1, N2_1, N1_2, N2_2, ..., N1_K, N2_K) (not 2D, but rather 2(K+1)-D),

BUT the pair of axes with subscript i will be missing if i is in trace_axes (and a similar mechanism applies to diagonal_axes, which keeps only a single axis of each pair). Lmk if this helps!
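
To make that concrete, a sketch (reusing the placeholder apply_fn, params, x from the snippet above; the final reshape is one way to view the (n, n, k, k) kernel as the nk x nk block matrix asked about earlier):

```python
import jax.numpy as jnp
import neural_tangents as nt

# apply_fn outputs have shape (n, k): axis 0 indexes datapoints, axis 1 indexes logits.
full = nt.empirical_ntk_fn(apply_fn, trace_axes=())(x, None, params)
print(full.shape)    # (64, 64, 10, 10): both logit axes are kept

traced = nt.empirical_ntk_fn(apply_fn, trace_axes=(1,))(x, None, params)
print(traced.shape)  # (64, 64): the logit axis pair is traced out

diag = nt.empirical_ntk_fn(apply_fn, trace_axes=(), diagonal_axes=(1,))(x, None, params)
print(diag.shape)    # (64, 64, 10): only the diagonal of the logit axis pair is kept

# One way to view the full kernel as the 640 x 640 block matrix from the question above:
block = jnp.transpose(full, (0, 2, 1, 3)).reshape(64 * 10, 64 * 10)
print(block.shape)   # (640, 640)
```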

uditsaxena commented 4 years ago

Ah - that definitely helps.

With the empty parens, I get an output kernel of shape (n, n, k, k). The way I understand it, for an entry indexed (i, j, :, :), the pair (i, j) refers to datapoints i and j, and the k x k submatrix refers to their logits.

Does that sound alright to you?

romanngg commented 4 years ago

Yep, correct!

uditsaxena commented 4 years ago

Awesome. Thanks!

PS: I was trying to tag this as a question, but I wasn't able to.

uditsaxena commented 4 years ago

Also, I just went from a batch of 64 to a batch of 256. Computing the output empirical NTK for n = 256 and k = 10 takes about 7 times longer (4.12 sec -> 27 sec).

How would you suggest I optimize this? If I have to run this for a lot of epochs, calculating the empirical NTK at each epoch (and not only at the beginning of training) might take a while. All options are on the table for now.

romanngg commented 4 years ago

Re the question label: likely only repo owners are allowed to do this; I don't see an option to let users set it...

Sadly the increased time is expected (in fact it would grow quadratically with batch size). See https://github.com/google/neural-tangents/issues/30 for ongoing discussion about performance.

Depending on your application, you may want to compute the empirical NTK with a single output logit to save yourself a factor of k x k. Note that in many use cases (if you have a stax.Dense layer on top), all those k x k tensor slices will converge to a constant-diagonal matrix in the infinite-width/sample limit, so you may be justified in computing the NTK for a single logit only if your goal is to approximate the infinite-width/sample behavior.
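
One way to do that (a sketch; the wrapper below is just an illustration, not a library API) is to wrap your apply_fn so it returns a single logit before taking the NTK:

```python
import neural_tangents as nt

# Illustrative wrapper: keep only logit 0, so the NTK is computed for a single output.
def single_logit_apply_fn(params, x):
    return apply_fn(params, x)[..., :1]  # shape (n, 1)

# With the default trace_axes=(-1,), the size-1 logit axis is traced out,
# leaving an (n, n) kernel, i.e. k * k times fewer entries than (n, n, k, k).
ntk_single = nt.empirical_ntk_fn(single_logit_apply_fn)(x, None, params)
print(ntk_single.shape)  # (64, 64)
```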

uditsaxena commented 4 years ago

Okay - what you're saying is that in the infinite-width/sample limit we may be justified in computing the n x n matrix (computed for a single logit) instead of the whole n x n x k x k matrix, since each k x k slice converges to a constant-diagonal matrix and a single logit therefore carries the same information. Got it - that should help with the speed-up.

How do you think that would change for sparse layers/networks though? Wouldn't the empirical NTK be more accurate for sparse networks if computed across all logits as compared to only a single logit? Maybe there's no answer to that question yet.

Re: #30, I did comment on that earlier yesterday. I'm not sure I follow the method of accumulating the empirical output NTK, since that ignores cross-layer weight covariances. Unless we're doing that on purpose, which probably translates to accumulating the diagonal matrix (same as what we do above using the diagonal_axes option) for a single logit.

romanngg commented 3 years ago

Re accuracy, it likely does depend on how exactly you measure accuracy. For example, if your measure is how close the empirical kernel is, in Frobenius norm, to the infinite-width, infinite-sample NTK (or to the infinite-sample but finite-width NTK), I think you would still get better accuracy / FLOPS using a single logit than multiple. If you want the best linear approximation to your finite-width network, then having multiple logits is more accurate.

Re #30, I have not looked into the per-layer implementation, but AFAIK "cross-layer weight covariances" are not needed in the NTK. In fact, no cross-weight covariances are present in the expression

df/dp(p, x1) df/dp(p, x2)^T = sum over all individual scalar weights w in p of [df/dw(p, x1) * df/dw(p, x2)],

so you only need to compute covariances between [gradients of outputs wrt] the same weight at different x1 and x2 (in contrast to df/dp(p, x1)^T df/dp(p, x2), which would be a #p x #p matrix of cross-weight covariances). But perhaps this is a trivial point and not what you mean; I may need to look at their code more to better understand it...
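
To illustrate that sum, a sketch in plain jax (the helper name and the toy linear-model usage are made up for illustration):

```python
import jax
import jax.numpy as jnp

def ntk_by_parameter_contraction(f, params, x1, x2):
    """Empirical NTK of a scalar-output f(params, x) -> (n,):
    sum over every individual weight w of df/dw(., x1) * df/dw(., x2)."""
    # Jacobians of the outputs w.r.t. every parameter, evaluated at x1 and x2.
    jac1 = jax.jacobian(f)(params, x1)  # pytree like params; leaves of shape (n1, *w.shape)
    jac2 = jax.jacobian(f)(params, x2)  # pytree like params; leaves of shape (n2, *w.shape)

    def contract(j1, j2):
        # Flatten each weight's axes and contract same-weight gradients across x1 and x2.
        return j1.reshape(j1.shape[0], -1) @ j2.reshape(j2.shape[0], -1).T  # (n1, n2)

    # Sum per-parameter contributions; no cross-weight terms ever appear.
    return sum(map(contract, jax.tree_util.tree_leaves(jac1), jax.tree_util.tree_leaves(jac2)))

# Tiny illustrative usage with a linear model:
p = {'w': jnp.ones((3,))}
f_lin = lambda params, x: x @ params['w']  # outputs of shape (n,)
x1, x2 = jnp.ones((4, 3)), jnp.ones((2, 3))
print(ntk_by_parameter_contraction(f_lin, p, x1, x2).shape)  # (4, 2)
```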

Also, I've just pushed https://github.com/google/neural-tangents/commit/f15b6528a47a73b1940f069309e69111b5235e13 which should make computing the empirical NTK faster with respect to n, especially for CNNs, so hopefully this will also help (you'd need to pass something like vmap_axes=0 to your nt.empirical_ntk_fn - see https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html for more details).
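
For reference, a sketch of what that might look like (reusing the placeholder apply_fn, params, x from the snippets above; vmap_axes=0 assumes the output for each input does not depend on the other inputs in the batch, e.g. no BatchNorm):

```python
import neural_tangents as nt

# vmap_axes=0 tells empirical_ntk_fn that inputs/outputs are batched along axis 0,
# letting it vectorize the Jacobian computation over the batch instead of looping.
ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=(), vmap_axes=0)
ntk = ntk_fn(x, None, params)
print(ntk.shape)  # still (64, 64, 10, 10), just computed faster
```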