Closed: HornHehhf closed this issue 3 years ago.
Hey Hangfeng, I'm afraid that computing CNN kernels with pooling layers is in general very costly, so I don't have a good solution for you right now. Some specific thoughts:
1) 381T sounds like roughly the size of the (10000 x 32 x 32)^2 float32 covariance matrix ((10000 x 32 x 32)^2 entries x 4 bytes ≈ 381 TiB), which you need to compute as an intermediate step if your network has pooling layers. You need to use the nt.batch decorator on your kernel_fn, and the batch_size would need to be somewhere in the range of 10-100, depending on your CPU/GPU RAM (see the sketch at the end of this comment).
2) I'm not certain why a batch size of 200 runs out of memory (are you running it on CPU RAM?), but I suspect it may be due to JAX having to allocate 3-4X the tensor size (instead of 2X) due to https://github.com/google/jax/issues/1273 and https://github.com/google/jax/issues/1733, in which case there isn't much we can do (please upvote those issues!), and you would need to scale down the batch size even further. If this explanation is correct, I think a batch size of 130-150 should work.
3) If you can, please post your code snippet and perhaps I can see if there are any performance improvements possible. But in general it seems about right to me that the FC kernel takes minutes while the CNN+pooling kernel takes thousands of hours, since in intermediate layers they effectively work with kernels of sizes 10000^2 and (10000 x 32 x 32)^2 respectively. Note that one middle ground is to not have any pooling layers and use a stax.Flatten() layer at the top - in this case the effective kernel size will be 10000^2 x 32 x 32. [If you are seeing these errors with stax.Flatten() and no pooling layers, please let me know, as in that case it might be a fixable bug.]
4) FYI, under the hood the CNN kernel computation (both with and without pooling layers) uses convolutions, which AFAIK are much more efficient on GPUs than CPUs, especially for some common settings (e.g. 3x3 kernels, SAME padding). This might be a secondary reason why the computation is especially slow, and you might get noticeably better results with a GPU and a much smaller batch size (it will still be very costly though).
We'll definitely take note to prioritize CNN+pooling performance, but FYI I'm not aware of simple ways of improving it currently, so it might take a while.
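To make 1) and 3) concrete, here is a minimal sketch of what I have in mind - the layer widths, filter sizes, batch size and input shapes below are placeholders, not your actual model:

```python
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# Pooling-free ConvNet read out with Flatten (the cheaper middle ground).
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(256, (3, 3), strides=(1, 1), padding='SAME'), stax.Relu(),
    stax.Conv(256, (3, 3), strides=(1, 1), padding='SAME'), stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))

# Batch the kernel computation so the covariance is built in small pieces;
# batch_size should be ~10-100 depending on available RAM, and should divide
# the number of inputs evenly.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=50, store_on_device=False)

x_train = jnp.ones((1000, 32, 32, 3))                     # NHWC inputs
k_train_train = batched_kernel_fn(x_train, None, 'nngp')  # (1000, 1000) kernel
```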
Thanks for your quick reply. I want to double-check that your current suggestions are the following two things:
1) Drop the pooling layers and use a stax.Flatten() layer at the top instead.
2) If I keep the pooling layers, wrap kernel_fn with nt.batch and use a much smaller batch size.
BTW, another question: do you have any ideas about why the CNTK implementation in Arora's paper seems not so costly? Their model is not as general as yours, but maybe their implementation can give you some insights on dealing with the CNN issue.
Thanks again for your quick reply. It is really helpful!
1) Yes, i.e. something like stax.serial(stax.Conv(...), stax.Relu(), stax.Conv(...), ..., stax.Flatten(), stax.Dense(10)). If this is not considerably faster and much less memory-hungry, let me know - that would mean there's a bug. It should still be slower than the fully-connected kernel though.
2) Yes, the batch size will have to be ~O(10) if you use pooling, but might still be faster due to GPU efficiency.
3) That's definitely on my todo list! They do seem really efficient; off the top of my head, the suspected reasons are:
Our current implementation computes the CNN kernel propagation as a sequence of convolutions with an identity-matrix kernel, i.e. for a CNN layer with a 3x3 filter we convolve the covariance matrix with [[1, 0, 0], [0, 1, 0], [0, 0, 1]] twice. So the effective operation has 3 summands (the diagonal), but the actual computation sums over 3x3 = 9 summands, the whole kernel (see the toy snippet after this list). I imagine (I haven't studied their code yet) they might have implemented this efficiently directly in CUDA, while we are using the JAX/XLA primitives available, and among those this seems like the best solution. A while ago we did try some tricks to reduce the operations to 3x1 convolutions, but this only yielded us ~1.3X speedup (and not the expected 3X in this case), so perhaps there is also some overhead related to using JAX/XLA and not CUDA directly.
google/jax#1273 google/jax#1733 - perhaps halving the memory footprint would allow us much bigger batch sizes and better performance.
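To illustrate the first point, here is a toy snippet (my own, not library code) showing that convolving with a 3x3 identity kernel only ever picks up the 3 diagonal entries, while a generic convolution still pays for all 9 multiply-adds per output:

```python
import jax.numpy as jnp
from jax import lax

cov = jnp.arange(36.0).reshape(1, 1, 6, 6)  # NCHW slice of a "covariance"
eye = jnp.eye(3).reshape(1, 1, 3, 3)        # OIHW 3x3 identity kernel

# Generic convolution: 9 multiply-adds per output pixel, 6 of them with zeros.
out = lax.conv_general_dilated(cov, eye, window_strides=(1, 1), padding='SAME')

# The same result as a sum of just 3 diagonal shifts (up to boundary effects,
# since 'SAME' zero-pads while jnp.roll wraps around).
out_diag = sum(jnp.roll(cov, (d, d), axis=(2, 3)) for d in (-1, 0, 1))
```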
Btw, one slightly silly way in which we are slow right now is that we always compute the whole kernel matrix, instead of only the upper or lower triangle (the kernel is symmetric). It's on our TODO list to fix, but in the meantime you could work around this by writing your own double for-loop, where you loop over batches of inputs and construct only the lower-triangular part yourself.
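A rough sketch of what that double for-loop could look like - lower_triangular_kernel is my own name, and I assume kernel_fn(x1, x2, get) returns a dense kernel block and that batch_size divides the number of inputs:

```python
import numpy as np

def lower_triangular_kernel(kernel_fn, x, batch_size, get='nngp'):
  """Builds the symmetric train-train kernel from lower-triangular blocks only."""
  n = x.shape[0]
  assert n % batch_size == 0
  k = np.zeros((n, n))
  for i in range(0, n, batch_size):
    for j in range(0, i + batch_size, batch_size):  # only blocks with j <= i
      block = np.array(kernel_fn(x[i:i + batch_size], x[j:j + batch_size], get))
      k[i:i + batch_size, j:j + batch_size] = block
      k[j:j + batch_size, i:i + batch_size] = block.T  # mirror into upper half
  return k
```

This performs roughly half the kernel_fn calls of the full matrix (the diagonal blocks are still computed in full).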
This is a bit handwavy; I'll definitely need to benchmark/debug in more detail to pinpoint the reason precisely. Hope this helps, lmk if you have any other questions!
Yes, the CNN without pooling is much faster, which is acceptable to me. Thanks for your help.
@HornHehhf I just want to add to @romanngg's reply that the CNTK code requires on the order of 1000 GPU hours and processes inputs in small batches to deal with the large memory footprint. Their custom CUDA kernel is more efficient than our library, which uses JAX/XLA primitives at the moment; nonetheless, these kernels are inherently compute-intensive.
Got it, thanks a lot!
Btw, I just made a change to how the CNN kernel is computed without pooling; it should give about a 25% speedup on GPU (no improvements to CNN w/ pooling though).
252ed85eb3c697bd634f4400d437711aa1bd9104
Good news: I found a hack that speeds up CNNs w/ pooling by >4X on NVIDIA GPUs! 100aface242be5cee01352e59a122f99680d65b8; it should be there in NT >= 0.2.1.
We should now be comparable in performance to Arora et al, but as @jaehlee remarked, the kernel computation is still inherently costly. This is not yet extensively tested though, so please let us know if there are any issues!
I've added some benchmarks to https://github.com/google/neural-tangents#performance, you can use them to estimate how long your task should take.
One sad takeaway from the table though is that even a very beefed-up CPU is 40X slower than a single V100 on this task, which makes CPUs especially ill-suited for it. I noticed that there is very low CPU utilization when doing single-channel CNNs (which is what we do), and filed a bug with the XLA team; hope they can help us with this!
Finally, if you're aiming for top performance on images, you probably do need pooling layers, so there is a tradeoff between speed (Flatten) and task accuracy (GlobalAvgPool). We discussed this phenomenon in https://arxiv.org/pdf/1810.05148.pdf (Figure 1, section 5.1).
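For concreteness, a minimal sketch of the two readout choices (the conv stack and widths are hypothetical, not a recommendation):

```python
from neural_tangents import stax

def conv_net(readout_layer):
  # Same hypothetical conv stack; only the readout differs.
  return stax.serial(
      stax.Conv(256, (3, 3), strides=(1, 1), padding='SAME'), stax.Relu(),
      stax.Conv(256, (3, 3), strides=(1, 1), padding='SAME'), stax.Relu(),
      readout_layer,
      stax.Dense(10))

_, _, fast_kernel_fn = conv_net(stax.Flatten())        # fast, less accurate
_, _, slow_kernel_fn = conv_net(stax.GlobalAvgPool())  # far costlier, more accurate
```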
Great, thanks for your notes. I want to double-check the speed comparison between CPU and GPU. For the CNN with pooling, CPU is 40X slower than a single V100. How about the CNN without pooling - does the speed difference still hold? The time for computing the CNN kernel without pooling on CPU is acceptable to me, but I still want to double-check the CPU/GPU comparison in that case.
Yes, without pooling the ratio also seems to be in the 30-40X range.
Otherwise, CNN-Flatten seems about 1000X faster than CNN-Pool, which makes sense since the covariance tensor size in that case is 32 * 32 = 1024 times smaller.
Great, thanks very much!
Hi,
Thanks, Hangfeng