thegregyang / NTK4A

Code for the paper: "Tensor Programs II: Neural Tangent Kernel for Any Architecture"
https://arxiv.org/abs/2006.14548

[Question] Does NTK tend to train slower or faster than regular DNN? #1

Closed gaceladri closed 4 years ago

gaceladri commented 4 years ago

Hello,

I am reading your paper, but I am not sure about the training/resource cost of this kind of kernel. Is it faster or slower to train than a regular network? Does it demand more resources than a regular one?

Thanks!

thegregyang commented 4 years ago

Hi @gaceladri

Thanks for your interest. This is a great question. The short answer: for small datasets and simpler architectures, the NTK can train faster, but for larger datasets or more complicated architectures, the DNN can train faster.

The long answer: There are two steps to "training" an NTK of a DNN.

  1. You need to first compute the kernel (NTK) corresponding to the architecture.
  2. Then you need to perform kernel regression with this kernel.

For really simple architectures like MLPs, step 1 is very cheap, and step 2 is the bottleneck. Even then, if we are talking about CIFAR10 or MNIST, step 2 just involves inversion and multiplication of roughly 60k x 60k matrices, which isn't so bad. For slightly more advanced architectures like CNNs, or the architectures in this repo, step 1 is relatively more expensive than step 2, but still doable. (You can look at the neural tangents library for efficient implementations and some efficiency numbers.)

In both of these cases, if your dataset is relatively small (say 1k or even up to 10k points; for really optimized architectures you can probably go up to 100k), then the NTK can be faster to "train" than the DNN. However, if you go beyond this dataset size, or if your architecture is more complicated, then the DNN is probably faster to train.
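The two steps above can be sketched in a few lines of numpy. This is only an illustration, not the paper's method: it uses the *empirical* (finite-width) NTK of a tiny hand-written 1-hidden-layer MLP, `K[i, j] = <grad_params f(x_i), grad_params f(x_j)>`, rather than the infinite-width kernel, and all names and sizes here are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

d, h = 3, 256                        # input dim, hidden width (arbitrary)
W1 = rng.standard_normal((h, d)) / np.sqrt(d)
w2 = rng.standard_normal(h) / np.sqrt(h)

def jacobian(x):
    """Gradient of the scalar output f(x) = w2 . relu(W1 x) w.r.t. all parameters."""
    pre = W1 @ x                     # pre-activations, shape (h,)
    act = np.maximum(pre, 0.0)       # relu
    dact = (pre > 0).astype(float)   # relu derivative
    dW1 = np.outer(w2 * dact, x)     # df/dW1, shape (h, d)
    dw2 = act                        # df/dw2, shape (h,)
    return np.concatenate([dW1.ravel(), dw2])

def ntk(X1, X2):
    """Step 1: empirical NTK matrix between two batches of inputs."""
    J1 = np.stack([jacobian(x) for x in X1])
    J2 = np.stack([jacobian(x) for x in X2])
    return J1 @ J2.T

# Toy data.
X_train = rng.standard_normal((20, d))
y_train = np.sin(X_train[:, 0])
X_test = rng.standard_normal((5, d))

# Step 2: kernel (ridge) regression with the NTK.
K = ntk(X_train, X_train)
alpha = np.linalg.solve(K + 1e-6 * np.eye(len(K)), y_train)
y_pred = ntk(X_test, X_train) @ alpha
print(y_pred.shape)
```

The cost split is visible here: step 1 is one Jacobian per data point plus a Gram matrix, while step 2 is a dense solve that scales cubically in the number of training points, which is why the dataset size drives which method is faster.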

Hope this helps!

thegregyang commented 4 years ago

Closing now, but feel free to reply if you have more questions.