Open MichaelMMeskhi opened 1 year ago
Hey Michael, could you please explain what the 60,000 is in "my kernels can be of size (100, 60000)"? And by kernels, do you mean NTKs? As far as I know, when computing infinite-width NTKs (as is the case in the notebook that you shared), the input width for the network doesn't affect the cost (neither computational nor memory) of computing the kernels. For instance, you can see that `kernel_fn` doesn't need the parameters of the net as input.
As far as I know, there's no easy way to reduce the memory footprint of computing infinite-width NTKs, except maybe deriving the exact NTK formulas analytically and computing them directly (as they did in https://github.com/LeoYu/neural-tangent-kernel-UCI for Convolutional-NTKs) as opposed to this repo which does it compositionally (for the sake of generality).
However, I would suggest using empirical NTKs instead of infinite-width NTKs. Regarding the work you mentioned in particular: if I understand things correctly, they treat the network as the fixed object and the data as the trainable parameters, as opposed to treating the data as the fixed object and the network's weights as the trainable parameters. In this case, I strongly suspect that using an empirical NTK with the trained weights at the end of the training procedure would produce better results than using the infinite-width NTK, as the generalization of a finite-width network at the end of (proper) training is often better than that of the corresponding infinite-width network.
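To make "empirical NTK" concrete, here is a minimal pure-JAX sketch. The toy two-layer MLP (`init_params`, `apply_fn`) and all sizes are hypothetical stand-ins, not taken from the KIP notebook; in practice `params` would be the trained weights. Neural Tangents also ships its own empirical-kernel utilities, which are likely the better choice for real models.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_params(key, d_in=8, width=16, d_out=3):
    # Hypothetical toy network: a two-layer MLP without biases.
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
    }


def apply_fn(params, x):
    # x: (n, d_in) -> logits: (n, d_out)
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def empirical_ntk(params, x1, x2):
    # Flatten the parameters so each Jacobian has a single parameter axis.
    flat, unravel = ravel_pytree(params)

    def f(p, x):
        return apply_fn(unravel(p), x)

    j1 = jax.jacobian(f)(flat, x1)  # (n1, d_out, P)
    j2 = jax.jacobian(f)(flat, x2)  # (n2, d_out, P)
    # Contract over the parameter axis: one (d_out, d_out) block per input pair.
    return jnp.einsum("iap,jbp->ijab", j1, j2)


params = init_params(jax.random.PRNGKey(0))
x1 = jax.random.normal(jax.random.PRNGKey(1), (5, 8))
x2 = jax.random.normal(jax.random.PRNGKey(2), (7, 8))
ntk = empirical_ntk(params, x1, x2)  # shape (5, 7, 3, 3)
```

Note that this materializes the full Jacobians, so memory grows with the number of inputs times outputs times parameters; it is only meant to show what is being computed.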
If you decide to use empirical NTKs, I would again suggest using the pseudo-NTK (pNTK, https://proceedings.mlr.press/v202/mohamadi23a.html), which approximates the empirical NTK almost perfectly at the end of training and is orders of magnitude cheaper in both computational and memory complexity. The paper shows that you can use pNTK to compute full 50,000 x 50,000 kernels on datasets like CIFAR-10 with a ResNet18 on a reasonable machine available in academia.
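A rough sketch of the pNTK idea, under the assumption (my reading of the paper) that the pNTK is the inner product of parameter gradients of the summed logits scaled by 1/sqrt(k), where k is the number of outputs. The toy MLP (`init_params`, `apply_fn`) is again a hypothetical stand-in, not code from the paper or the notebook.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_params(key, d_in=8, width=16, d_out=3):
    # Hypothetical toy network: a two-layer MLP without biases.
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
    }


def apply_fn(params, x):
    # x: (n, d_in) -> logits: (n, d_out)
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def pseudo_ntk(params, x1, x2):
    flat, unravel = ravel_pytree(params)

    def summed_logits(p, x):
        # Scalar output for a single input: sum of logits / sqrt(k).
        out = apply_fn(unravel(p), x[None, :])[0]
        return jnp.sum(out) / jnp.sqrt(out.shape[0])

    # One gradient vector per input instead of one Jacobian row per logit.
    grad_fn = jax.vmap(jax.grad(summed_logits), in_axes=(None, 0))
    g1 = grad_fn(flat, x1)  # (n1, P)
    g2 = grad_fn(flat, x2)  # (n2, P)
    return g1 @ g2.T        # (n1, n2): one scalar per input pair


params = init_params(jax.random.PRNGKey(0))
x1 = jax.random.normal(jax.random.PRNGKey(1), (5, 8))
x2 = jax.random.normal(jax.random.PRNGKey(2), (7, 8))
pntk = pseudo_ntk(params, x1, x2)  # shape (5, 7)
```

The saving comes from the output shape: one scalar per input pair rather than a k x k block, and one backward pass per input rather than one per logit.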
Let me know if it helps!
Hi @mohamad-amin thank you for your feedback. I will definitely look into that but at this moment I have to finalize the project as is.
So, looking into the code more closely, I understand that the limitation isn't in computing `k(x,x)` but rather in doing backprop. If I understand correctly, `nt.batch` is mainly for kernel computation (forward pass). Is there any way to break up the gradient calculation within NTK? If not, I assume that is something to be done via JAX.
Hey Michael,
Unfortunately I'm not an expert on autograd, and I don't know many tricks in this regard. I just skimmed the code, and it seems that in `loss_acc_fn` they use `sp.linalg.solve` to compute the kernel regression predictions. I'm not exactly sure how the gradient for this step is computed, but if it's taking the gradient of iterative LU operations, that could require a lot of memory (also see https://github.com/google/jax/issues/1747).
I'd suggest replacing the `sp.linalg.solve` call in that function with the Cholesky solve alternative (see https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_solve.html and https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_factor.html) for possible improvements both memory-wise and speed-wise.
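Something along these lines is what I have in mind. This is only a sketch: `krr_predict`, the regularization scheme, and the toy data are my assumptions and are not copied from the notebook. Cholesky also requires the (regularized) kernel matrix to be symmetric positive definite.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def krr_predict(k_ss, k_ts, y_s, reg=1e-2):
    """Kernel regression predictions via a Cholesky solve.

    k_ss: (n_s, n_s) support-support kernel (symmetric PSD)
    k_ts: (n_t, n_s) target-support kernel
    y_s:  (n_s, c) support labels
    """
    n = k_ss.shape[0]
    # Assumed regularization: ridge term scaled by the mean diagonal entry.
    k_reg = k_ss + reg * jnp.trace(k_ss) / n * jnp.eye(n)
    c, lower = cho_factor(k_reg)         # factor once
    alpha = cho_solve((c, lower), y_s)   # solve k_reg @ alpha = y_s
    return k_ts @ alpha


# Toy data: linear kernels built from random features.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x_s = jax.random.normal(k1, (10, 20))
x_t = jax.random.normal(k2, (30, 20))
y_s = jax.random.normal(k3, (10, 2))
k_ss, k_ts = x_s @ x_s.T, x_t @ x_s.T

pred = krr_predict(k_ss, k_ts, y_s)  # shape (30, 2)

# The solve is differentiable, e.g. with respect to the support labels.
grad_y = jax.grad(lambda y: jnp.sum(krr_predict(k_ss, k_ts, y) ** 2))(y_s)
```

Whether this actually lowers peak memory in the backward pass is something you would need to profile on your setup.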
And yes, `nt.batch` is for computing the NTK kernel in batches (see https://neural-tangents.readthedocs.io/en/latest/batching.html).
I am trying to improve upon this paper https://arxiv.org/pdf/2011.00050.pdf, where they optimize some subset using NTK. They optimize their loss in batches, with smaller batches for more complex architectures (e.g. Conv, Myrtle). In my case, I am unable to optimize in batches and have to load the entire dataset into memory. For instance, I am optimizing a subset of 100 parameters, where my kernels can be of size (100, 60000). Using a single FC NTK of width 1024, things run fine. But when I try to use a 2-layer Conv of width 64, OOM errors occur. Are there ways to reduce the memory footprint via NTK?
Their code is here: kip_open_source.ipynb (Colaboratory)
Thank you