Open · dllu opened this issue 1 year ago
@dllu thanks for the report and the suggestions. I have a remark concerning your input numbers, which look rather excessive to me as well.
In your example, conv2d fails to allocate CPU memory because it tries to allocate 24882763560000 / 1024 / 1024 / 1024 ~= 23173.87 GB of RAM. This size, i.e. `H * W * Kx * Ky * dtype_size`, is the internal representation used for efficient convolution computation; see https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold (im2col).
Even if we implemented Gaussian blur as a separable convolution, we would still need to allocate `H * W * Kx * dtype_size` of memory, which in your case would be 16860 * 10750 * 131 * 8 / 1024 / 1024 / 1024 ~= 176.90 GB of RAM.
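For reference, the same arithmetic spelled out as a quick calculation (numbers taken from this example; 8 bytes per element assumes float64):

```python
# Memory needed by the im2col buffer for a full 2-D convolution vs. a single
# separable (1-D) pass, using the sizes from this example.
H, W = 16860, 10750          # image height and width
Kx, Ky = 131, 131            # kernel height and width
dtype_size = 8               # bytes per element (float64)

im2col_bytes = H * W * Kx * Ky * dtype_size   # unfolded patches for the k x k conv2d
separable_bytes = H * W * Kx * dtype_size     # unfolded patches for one 1-D pass

print(im2col_bytes / 1024**3)     # ~23173.87 GB
print(separable_bytes / 1024**3)  # ~176.90 GB
```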
> Even if we implemented Gaussian blur as a separable convolution, we would still need to allocate `H * W * Kx * dtype_size` of memory, which in your case would be 16860 * 10750 * 131 * 8 / 1024 / 1024 / 1024 ~= 176.90 GB of RAM.
True, but a 131x saving is still not bad. We could also implement Gaussian blur with O(H * W) memory usage by doing an iterated box blur, which can be computed using a running sum. This can still be parallelized well across each row/column.
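A rough sketch of the running-sum box blur idea (untested; assumes a floating-point image tensor of shape (H, W) or (C, H, W), and the replicate border padding is just one reasonable choice). Iterating the box blur a few times approximates a Gaussian by the central limit theorem, and each pass is just a cumulative sum plus a subtraction, so the cost is independent of the radius:

```python
import torch
import torch.nn.functional as F

def box_blur_1d(x: torch.Tensor, radius: int, dim: int = -1) -> torch.Tensor:
    """One box-blur pass of width 2*radius + 1 along `dim`, computed with a
    running (cumulative) sum, so extra memory stays O(numel(x))."""
    x = x.transpose(dim, -1)
    # Pad so every output position has a full window (border values replicated).
    padded = F.pad(x, (radius + 1, radius), mode="replicate")
    csum = padded.cumsum(dim=-1)
    window = 2 * radius + 1
    # Window sum at position i is csum[i + window] - csum[i], then normalize.
    out = (csum[..., window:] - csum[..., :-window]) / window
    return out.transpose(dim, -1)

def approx_gaussian_blur(img: torch.Tensor, radius: int, passes: int = 3) -> torch.Tensor:
    """Approximate a Gaussian blur by iterating a separable box blur."""
    out = img
    for _ in range(passes):
        out = box_blur_1d(out, radius, dim=-1)  # blur along rows
        out = box_blur_1d(out, radius, dim=-2)  # blur along columns
    return out
```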
When Kx >> log(N) we can also opt for cuFFT (FFT-based convolution costs O(N log N) regardless of kernel size).
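For illustration, a minimal frequency-domain sketch using torch.fft, which dispatches to cuFFT on CUDA tensors. Note that this computes a *circular* convolution, so a real implementation would pad the image by half the kernel size first; `gaussian_kernel_2d` is just an illustrative helper:

```python
import torch

def gaussian_kernel_2d(kernel_size: int, sigma: float) -> torch.Tensor:
    # Dense 2-D Gaussian kernel built as the outer product of two 1-D Gaussians.
    ax = torch.arange(kernel_size, dtype=torch.float64) - (kernel_size - 1) / 2
    g = torch.exp(-0.5 * (ax / sigma) ** 2)
    g = g / g.sum()
    return torch.outer(g, g)

def fft_gaussian_blur(img: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
    """Blur by pointwise multiplication in the frequency domain.
    `img` is (..., H, W) with a floating-point dtype."""
    H, W = img.shape[-2:]
    kernel = gaussian_kernel_2d(kernel_size, sigma).to(dtype=img.dtype, device=img.device)
    # Embed the kernel in an H x W canvas and roll it so its center lands at (0, 0).
    canvas = torch.zeros(H, W, dtype=img.dtype, device=img.device)
    canvas[:kernel_size, :kernel_size] = kernel
    canvas = torch.roll(canvas, shifts=(-(kernel_size // 2), -(kernel_size // 2)), dims=(-2, -1))
    return torch.fft.irfft2(torch.fft.rfft2(img) * torch.fft.rfft2(canvas), s=(H, W))
```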
+1. torchvision.transforms.GaussianBlur seems significantly slower to me than the OpenCV equivalent.
Are there any alternatives that work with CUDA tensors and are more efficient? Thanks
🐛 Describe the bug
The Gaussian blur implementation here: https://github.com/pytorch/vision/blob/7d2acaa7d7fc600fa08fca18e9230f8651147025/torchvision/transforms/_functional_tensor.py#L746
is slow and consumes excessive memory because it allocates a $k \times k$ kernel and uses Conv2D.
As such, trying to Gaussian blur a modestly sized image fails:
You can see that the `conv2d` implementation is trying to allocate excessive memory: $O(n_h n_w k_h k_w)$ for convolving an $n_h \times n_w$ matrix with a $k_h \times k_w$ kernel. Notice that $16860 \times 10750 \times 131 \times 131 \times 8 = 24882763560000$.

However, Gaussian blur is a separable convolution, so it can be done as two convolutions with a $k \times 1$ kernel and a $1 \times k$ kernel, which would immensely speed it up. Furthermore, we could technically convolve each column and each row separately, further reducing memory usage, but this would come at a cost to parallelization.
A random slide I found by googling: http://www-edlab.cs.umass.edu/~smaji/cmpsci370/slides/hh/lec02_hh_advanced_edges.pdf
We could implement it something like this (I haven't tested this, so I don't know if it works, but it should be something along these lines):
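As an untested sketch of the two-pass idea (`gaussian_kernel_1d` is an illustrative helper, and the reflect padding is just one reasonable choice):

```python
import torch
import torch.nn.functional as F

def gaussian_kernel_1d(kernel_size: int, sigma: float, dtype, device) -> torch.Tensor:
    # Normalized 1-D Gaussian kernel.
    ax = torch.arange(kernel_size, dtype=dtype, device=device) - (kernel_size - 1) / 2
    kernel = torch.exp(-0.5 * (ax / sigma) ** 2)
    return kernel / kernel.sum()

def separable_gaussian_blur(img: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
    """Gaussian blur as two 1-D convolutions (1 x k, then k x 1) instead of one k x k conv2d.
    `img` is (N, C, H, W) with a floating-point dtype; kernel_size is odd."""
    n, c, h, w = img.shape
    k = gaussian_kernel_1d(kernel_size, sigma, img.dtype, img.device)
    pad = kernel_size // 2
    img = F.pad(img, (pad, pad, pad, pad), mode="reflect")
    # Horizontal pass: 1 x k kernel, grouped so each channel is filtered independently.
    img = F.conv2d(img, k.view(1, 1, 1, -1).expand(c, 1, 1, -1), groups=c)
    # Vertical pass: k x 1 kernel.
    img = F.conv2d(img, k.view(1, 1, -1, 1).expand(c, 1, -1, 1), groups=c)
    return img
```

The intermediate buffers then scale with $k$ rather than $k^2$, which is where the ~131x memory saving comes from.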
We can also introduce an if-statement for small kernel sizes, where the extra memory usage of a single 2-D convolution is insignificant and not worth the small overhead of doing two convolutions.
Also, Gaussian blur can be implemented efficiently as an iterated box blur or via an FFT.
Versions