google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Memory and running time issues for CNN #29

Closed HornHehhf closed 3 years ago

HornHehhf commented 4 years ago

Hi,

 I currently use neural-tangents to compute kernels for CIFAR-10 images. I need to compute the kernel matrix for 10000 x 10000 images, where each image has 3x32x32 pixels. If I use a 2-layer feedforward network on the reshaped 3072-dimensional input, it takes about 3 GB of memory and several minutes to compute the kernel.

However, if I use a simple one-layer CNN, it fails with the error "failed to allocate request 381T memory". I can only reduce the minibatch size, but that makes the computation much slower, and this is just a one-layer CNN; I expect a multilayer CNN to cost even more time. Even for one batch (100 images), it still takes much more time than the 2-layer feedforward network.

Another strange thing: I expected to be able to compute the kernel matrix with a batch size of 200 (out of 10000) at a time, because the server has 394 GB of memory. But it still runs out of memory (checked manually) after several minutes and is killed without an error message.

So I am wondering how to use your tools to compute the kernel matrix for CNNs; on my end it costs either too much memory or too much time. Do you have any suggestions for dealing with this issue? I am not sure about the underlying mechanism you use to compute the CNN kernel, but I expect it shouldn't cost so much memory and run so slowly, because [Arora et al. 2019](https://arxiv.org/pdf/1904.11955.pdf) compute the kernel for a 21-layer CNN.

It is really a good tool, and I hope you can help with the CNN memory and running time issues.

Thanks, Hangfeng

romanngg commented 4 years ago

Hey Hangfeng, I'm afraid that computing CNN kernels with pooling layers is in general very costly, so I don't have a good solution for you right now. Some specific thoughts:

1) 381T sounds like roughly the size of the (10000 x 32 x 32)^2 float32 covariance matrix, which you need to compute as an intermediate step if your network has pooling layers. You need to use the nt.batch decorator on your kernel_fn, and the batch_size would need to be somewhere in the range of 10-100 depending on your CPU/GPU RAM (a concrete sketch follows after point 4 below).

2) I'm not certain why a batch size of 200 runs out of memory (are you running it on CPU RAM?), but I suspect it may be due to JAX having to allocate 3-4X the tensor size (instead of 2X) due to https://github.com/google/jax/issues/1273 and https://github.com/google/jax/issues/1733, in which case there isn't much we can do (please upvote those issues!), and you would need to scale down the batch size even further. If this explanation is correct, I think a batch size of 130-150 should work.

3) If you can, please post your code snippet and perhaps I can see if there are any performance improvements possible. But in general it seems about right to me that the FC kernel takes minutes while the CNN+pooling kernel takes thousands of hours, since in intermediate layers they effectively work with kernels of sizes 10000^2 and (10000 x 32 x 32)^2 respectively. Note that one middle ground is to not have any pooling layers and have a stax.Flatten() layer at the top - in this case the effective kernel size will be 10000^2 x 32 x 32. [If you are seeing these errors with stax.Flatten() and no pooling layers - please let me know, as in that case it might be a fixable bug.]

4) FYI, under the hood CNN kernel computation (both with and without pooling layers) uses convolutions, which AFAIK are much more efficient on GPUs than CPUs, especially for some common settings (e.g. 3x3 kernels, SAME padding). This might be a secondary reason why the computation is especially slow, and you might get noticeably better results with a GPU and a much smaller batch size (it will be very costly still though).
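For reference, here is a quick back-of-the-envelope check of the covariance size mentioned in point 1, followed by a minimal sketch of wrapping a kernel_fn with nt.batch. The pooling architecture, batch_size and toy inputs below are illustrative placeholders, not a tuned recommendation:

```python
import neural_tangents as nt
from neural_tangents import stax
from jax import random

# Back-of-the-envelope size of the intermediate pixel-pixel covariance for a
# pooling CNN over 10000 CIFAR-10 images: one float32 entry per pair of pixels.
n, h, w = 10_000, 32, 32
print((n * h * w) ** 2 * 4 / 2 ** 40)  # ~381 TiB, matching the error above.

# Toy pooling CNN; the architecture is a placeholder, not a recommendation.
_, _, kernel_fn = stax.serial(
    stax.Conv(256, (3, 3), padding='SAME'), stax.Relu(),
    stax.GlobalAvgPool(), stax.Dense(10))

# Wrap kernel_fn so the kernel is computed over small batches of input pairs.
kernel_fn = nt.batch(kernel_fn, batch_size=25, store_on_device=False)

x = random.normal(random.PRNGKey(0), (100, 32, 32, 3))  # toy NHWC inputs
k = kernel_fn(x, x, 'nngp')                              # shape (100, 100)
```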

We'll definitely take note to prioritize CNN+pooling performance, but FYI I'm not aware of simple ways of improving it currently, so it might take a while.

HornHehhf commented 4 years ago

Thanks for your quick reply. I want to double-check that your current suggestions are the following two things:

  1. Trying a CNN without pooling layers will use less memory.
  2. Use a GPU instead of a CPU, together with batching.

BTW, another question: do you have any idea why the CNTK implementation in Arora's paper doesn't seem so costly? Their model is not as general as yours, but maybe their implementation can give you some insights into dealing with the CNN issue.

Thanks again for your quick reply. It is really helpful!

romanngg commented 4 years ago

1) Yes, i.e. something like stax.serial(stax.Conv(...), stax.Relu(), stax.Conv(...), ..., stax.Flatten(), stax.Dense(10)) - a concrete sketch is at the end of this comment. If this is not considerably faster and much less memory-hungry, let me know; that would mean there's a bug. It should still be slower than the fully-connected kernel, though.

2) Yes, the batch size will have to be ~O(10) if you use pooling, but might still be faster due to GPU efficiency.

3) That's definitely on my todo list! They seem really efficient; off the top of my head I have a few suspected reasons.

This is a bit handwavy, I'll definitely need to benchmark/debug in more detail to pinpoint the reason precisely. Hope this helps, lmk if you have any other questions!
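To make point 1 concrete, here is a minimal runnable sketch of such a pooling-free model; the depth, channel counts, filter shapes, batch_size and toy inputs are all illustrative placeholders:

```python
import neural_tangents as nt
from neural_tangents import stax
from jax import random

# Pooling-free CNN: Conv/Relu blocks, then a Flatten + Dense readout.
# Depth, channel counts and filter shapes here are illustrative placeholders.
_, _, kernel_fn = stax.serial(
    stax.Conv(512, (3, 3), padding='SAME'), stax.Relu(),
    stax.Conv(512, (3, 3), padding='SAME'), stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))

# Batching still helps: without pooling, the intermediate covariance per batch
# pair is only ~batch^2 * 32 * 32 entries instead of (batch * 32 * 32)^2.
kernel_fn = nt.batch(kernel_fn, batch_size=100, store_on_device=False)

x = random.normal(random.PRNGKey(1), (200, 32, 32, 3))  # toy NHWC inputs
k_ntk = kernel_fn(x, x, 'ntk')                           # shape (200, 200)
```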

HornHehhf commented 4 years ago

Yes, the CNN without pooling is much faster, which is acceptable to me. Thanks for your help.

jaehlee commented 4 years ago

@HornHehhf I just want to add to @romanngg's reply that the CNTK code requires on the order of 1000 GPU hours and processes the data in small batches to deal with the large memory issue. Their custom CUDA kernel is more efficient than our library, which uses JAX/XLA primitives, at the moment; nonetheless, these kernels are inherently compute-intensive.

HornHehhf commented 4 years ago

Got it, thanks a lot!

romanngg commented 4 years ago

Btw, I just made a change to how the CNN kernel is computed without pooling; it should give a ~25% speedup on GPU (no improvement to CNN w/ pooling though).

252ed85eb3c697bd634f4400d437711aa1bd9104

romanngg commented 4 years ago

Good news: I found a hack that speeds up CNNs w/ pooling by >4X on NVIDIA GPUs! 100aface242be5cee01352e59a122f99680d65b8, should be there in NT >= 0.2.1.

We should now be comparable in performance to Arora et al, but as @jaehlee remarked, the kernel computation is still inherently costly. This is not yet extensively tested though, so please let us know if there are any issues!

I've added some benchmarks to https://github.com/google/neural-tangents#performance, you can use them to estimate how long your task should take.

One sad takeaway from the table, though, is that even a very beefed-up CPU is 40X slower than a single V100 on this task, which makes the CPU especially ill-suited. I noticed very low CPU utilization when doing single-channel CNNs (which is what we do) and filed a bug with the XLA team; hope they can help us with this!

Finally, if you're aiming for top performance on images, you probably do need pooling layers, so there is a tradeoff between speed (Flatten) and task accuracy (GlobalAvgPool). We discussed this phenomenon in https://arxiv.org/pdf/1810.05148.pdf (Figure 1, section 5.1).
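To illustrate the tradeoff, here is a sketch with the two readouts side by side; the convolutional body is an arbitrary placeholder, and the two models differ only in the layers after it:

```python
from neural_tangents import stax

# Shared convolutional body; depth and filter sizes are arbitrary placeholders.
body = stax.serial(
    stax.Conv(256, (3, 3), padding='SAME'), stax.Relu(),
    stax.Conv(256, (3, 3), padding='SAME'), stax.Relu())

# Faster readout: flatten pixels, no pooling
# (intermediate covariance ~10000^2 * 32 * 32).
_, _, kernel_fn_flatten = stax.serial(body, stax.Flatten(), stax.Dense(10))

# Slower but typically more accurate readout: global average pooling
# (intermediate covariance ~(10000 * 32 * 32)^2).
_, _, kernel_fn_pool = stax.serial(body, stax.GlobalAvgPool(), stax.Dense(10))
```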

HornHehhf commented 4 years ago

Great, thanks for your notes. I want to double-check the speed comparison between CPU and GPU. For the CNN with pooling, the CPU is 40X slower than a single V100; how about the CNN without pooling, does the speed difference still hold? The time for computing the CNN kernel without pooling on a CPU is acceptable to me, but I still want to confirm the CPU vs GPU comparison for that case.

romanngg commented 4 years ago

Yes, without pooling the ratio also seems to be in 30-40X.

Otherwise, CNN-Flatten seems about 1000X faster than CNN-Pool, which makes sense since the covariance tensor size in that case is 32 * 32 = 1024 times smaller.

HornHehhf commented 4 years ago

Great, thanks very much!