pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Gaussian blur slow and consumes excessive memory #7413

Open dllu opened 1 year ago

dllu commented 1 year ago

🐛 Describe the bug

The Gaussian blur implementation here: https://github.com/pytorch/vision/blob/7d2acaa7d7fc600fa08fca18e9230f8651147025/torchvision/transforms/_functional_tensor.py#L746

is slow and consumes excessive memory because it builds a full $k \times k$ 2D kernel and applies it with a single conv2d call.

As a result, blurring a modestly sized image fails:

ipdb> from torchvision.transforms.functional import gaussian_blur
ipdb> population_smoothed = gaussian_blur(population_hi_res_gpu[None], kernel_size).squeeze(0).cpu().numpy()
*** RuntimeError: [enforce fail at alloc_cpu.cpp:75] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 24882763560000 bytes. Error code 12 (Cannot allocate memory)
ipdb> kernel_size
131
ipdb> population_hi_res_gpu.shape
torch.Size([16860, 10750])
ipdb> population_hi_res_gpu.dtype
torch.float64

The conv2d call fails because it tries to allocate $O(n_h n_w k_h k_w)$ memory to convolve an $n_h \times n_w$ matrix with a $k_h \times k_w$ kernel. Indeed, $16860 \times 10750 \times 131 \times 131 \times 8 = 24882763560000$ bytes, where 8 is the size in bytes of a float64 element.

However, the Gaussian kernel is separable, so the blur can be computed as two 1D convolutions: one with a $k \times 1$ kernel and one with a $1 \times k$ kernel. This reduces the work per pixel from $k^2$ to $2k$ multiply-adds and would speed it up immensely. We could go further and convolve each row and each column separately to reduce memory usage even more, but that would come at the cost of parallelism.

[image: a slide on separable filters, found by googling: http://www-edlab.cs.umass.edu/~smaji/cmpsci370/slides/hh/lec02_hh_advanced_edges.pdf]
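Written out for a $k \times k$ kernel with radius $r = \lfloor k/2 \rfloor$ and a normalized 1D Gaussian $g$, so that $G(i, j) = g(i)\,g(j)$, the 2D sum factors into a pass along $x$ followed by a pass along $y$:

$$
(I * G)(x, y) = \sum_{j=-r}^{r} \sum_{i=-r}^{r} I(x - i, y - j)\, g(i)\, g(j)
= \sum_{j=-r}^{r} g(j) \underbrace{\sum_{i=-r}^{r} I(x - i, y - j)\, g(i)}_{\text{1D pass along } x}
$$

Per output pixel this drops the cost from $k^2$ to $2k$ multiply-adds; for $k = 131$ that is $17161$ versus $262$.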

We could implement it along these lines:

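def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    # Meant as a drop-in for torchvision/transforms/_functional_tensor.py, where Tensor, List,
    # conv2d, torch_pad and the private helpers used below are already imported/defined.
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    # kernel_size[0] / sigma[0] act along x (width), kernel_size[1] / sigma[1] along y (height)
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(img.device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(img.device, dtype=dtype)

    # Cast/unsqueeze once, so both passes run on an (N, C, H, W) tensor of the kernel dtype
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [dtype])
    channels = img.shape[-3]

    # padding = (left, right, top, bottom); pad both axes up front, then do two "valid" convs
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")

    # Horizontal pass: one (1 x kx) kernel per channel, weight shape (C, 1, 1, kx)
    img = conv2d(img, kernel1d_x.view(1, 1, 1, -1).expand(channels, 1, 1, -1), groups=channels)
    # Vertical pass: one (ky x 1) kernel per channel, weight shape (C, 1, ky, 1)
    img = conv2d(img, kernel1d_y.view(1, 1, -1, 1).expand(channels, 1, -1, 1), groups=channels)

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img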

(I haven't tested this, so I'm not sure it works as-is, but it should be something along those lines.)

We could also add an if-statement that keeps the single 2D convolution for small kernel sizes, where its extra memory usage is insignificant compared with the small overhead of running two convolutions.
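A self-contained sketch of that dispatch, using only public PyTorch ops, might look like this. The names `gaussian_kernel1d` and `gaussian_blur_dispatch` and the `max_2d_kernel_size` threshold are hypothetical and chosen for illustration, not benchmarked; it assumes a floating-point `(N, C, H, W)` input and odd kernel sizes.

```python
from typing import List

import torch
import torch.nn.functional as F


def gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Sampled, normalized 1D Gaussian (same formula as torchvision's private _get_gaussian_kernel1d)
    half = (kernel_size - 1) * 0.5
    x = torch.linspace(-half, half, steps=kernel_size, dtype=dtype, device=device)
    pdf = torch.exp(-0.5 * (x / sigma) ** 2)
    return pdf / pdf.sum()


def gaussian_blur_dispatch(
    img: torch.Tensor, kernel_size: List[int], sigma: List[float], max_2d_kernel_size: int = 9
) -> torch.Tensor:
    """Hypothetical helper. img: float (N, C, H, W); kernel_size and sigma are [x, y]."""
    kx, ky = kernel_size
    channels = img.shape[-3]
    gx = gaussian_kernel1d(kx, sigma[0], img.dtype, img.device)
    gy = gaussian_kernel1d(ky, sigma[1], img.dtype, img.device)
    # padding = (left, right, top, bottom), reflect like torchvision
    img = F.pad(img, [kx // 2, kx // 2, ky // 2, ky // 2], mode="reflect")

    if max(kx, ky) <= max_2d_kernel_size:
        # Small kernel: a single grouped conv2d with the full (ky x kx) kernel
        kernel2d = torch.outer(gy, gx).expand(channels, 1, ky, kx)
        return F.conv2d(img, kernel2d, groups=channels)

    # Large kernel: horizontal (1 x kx) pass, then vertical (ky x 1) pass
    img = F.conv2d(img, gx.view(1, 1, 1, kx).expand(channels, 1, 1, kx), groups=channels)
    return F.conv2d(img, gy.view(1, 1, ky, 1).expand(channels, 1, ky, 1), groups=channels)
```

Padding once up front is enough for both paths: because the 2D kernel is `outer(gy, gx)`, the two valid 1D passes over the padded array sum exactly the same products as the one 2D pass. Since only `torch` ops are used, the same code runs on CUDA tensors.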

Gaussian blur can also be implemented efficiently as an iterated box blur (which approximates it) or via an FFT.

Versions

Collecting environment information...
PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.10 (default, Nov 14 2022, 12:59:47)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Quadro M4000
Nvidia driver version: 510.108.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          56
On-line CPU(s) list:             0-55
Thread(s) per core:              2
Core(s) per socket:              14
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           79
Model name:                      Intel(R) Xeon(R) CPU E5-2660 v4 @ 2.00GHz
Stepping:                        1
CPU MHz:                         1200.000
CPU max MHz:                     3200.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        3990.84
L1d cache:                       896 KiB
L1i cache:                       896 KiB
L2 cache:                        7 MiB
L3 cache:                        70 MiB
NUMA node0 CPU(s):               0-13,28-41
NUMA node1 CPU(s):               14-27,42-55
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Mitigation; Clear CPU buffers; SMT vulnerable
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.4
[pip3] torch==1.13.0
[conda] Could not collect

vfdev-5 commented 1 year ago

@dllu thanks for the report and the suggestions. I have a remark about your input sizes, which look rather excessive to me as well. In your example, conv2d fails because it tries to allocate 24882763560000 / 1024 / 1024 / 1024 ~= 23173.88 GiB of CPU RAM. This size, i.e. H * W * Kx * Ky * sizeof(dtype), is an internal representation (im2col) used for efficient convolution computation; see https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold.
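To see where the H * W * Kx * Ky factor comes from, `torch.nn.functional.unfold` (the functional form of `torch.nn.Unfold`) materializes the im2col buffer explicitly. The sizes below are small arbitrary values; this shows the buffer's shape, while conv2d's actual internal path depends on the device and backend.

```python
import torch
import torch.nn.functional as F

# Arbitrary small sizes for illustration
N, C, H, W, k = 1, 1, 64, 48, 9
img = torch.rand(N, C, H, W, dtype=torch.float64)

# "Same"-size output needs k // 2 pixels of padding on each side
cols = F.unfold(img, kernel_size=k, padding=k // 2)

# One column of C * k * k values per output pixel: shape (N, C*k*k, H*W)
print(cols.shape)  # torch.Size([1, 81, 3072])
print(cols.numel() * cols.element_size(), "bytes vs", img.numel() * img.element_size(), "for the image")
```

The buffer is k * k = 81 times larger than the image itself, which is exactly the blow-up seen in the error message.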

Even if we implemented Gaussian blur as a separable convolution, we would still need to allocate H * W * Kx * sizeof(dtype) bytes, which in your case is 16860 * 10750 * 131 * 8 / 1024 / 1024 / 1024 ~= 176.90 GiB of RAM.
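Both estimates are plain arithmetic and can be reproduced as follows; dividing by 1024³ gives GiB.

```python
H, W, K = 16860, 10750, 131
ITEMSIZE = 8  # bytes per float64 element
GIB = 1024 ** 3

full_2d = H * W * K * K * ITEMSIZE  # im2col buffer for one k x k conv2d
separable = H * W * K * ITEMSIZE    # im2col buffer for one 1D (k x 1 or 1 x k) pass

print(full_2d, round(full_2d / GIB, 2))
print(separable, round(separable / GIB, 2))
```

The ratio between the two is exactly K = 131.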

dllu commented 1 year ago

> Even if we implemented gaussian blur as separable convolution, we would still need to allocate H * W * Kx * dtype of memory and in your case it would be 16860 * 10750 * 131 * 8 / 1024 / 1024 / 1024 ~= 176.90 GB of RAM.

True, but a 131x saving is still worthwhile. We could also implement Gaussian blur with O(H * W) memory usage by approximating it with an iterated box blur, where each box blur is computed with a running sum, so its cost per pixel does not depend on the kernel size. This still parallelizes well across rows and columns.
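A minimal sketch of the iterated box blur idea, assuming a floating-point `(N, C, H, W)` tensor. Each box pass is a running (prefix) sum via `torch.cumsum`, so its cost and temporary memory do not grow with the kernel size. The three passes, the box-width formula (chosen so the variance of `passes` boxes is close to σ²), replicate borders, and the function names are all assumptions of this sketch; the result only approximates a true Gaussian and will not match torchvision's output exactly.

```python
import math

import torch
import torch.nn.functional as F


def box_blur_last_dim(x: torch.Tensor, radius: int) -> torch.Tensor:
    """Mean over a window of 2 * radius + 1 along the last dim of a 4D tensor, via a running sum."""
    k = 2 * radius + 1
    padded = F.pad(x, (radius, radius, 0, 0), mode="replicate")  # replicate borders
    csum = torch.cumsum(padded, dim=-1)
    csum = F.pad(csum, (1, 0, 0, 0))  # leading zero: csum[..., j] == padded[..., :j].sum(-1)
    return (csum[..., k:] - csum[..., :-k]) / k  # window sums, back to the original width


def iterated_box_blur(img: torch.Tensor, sigma: float, passes: int = 3) -> torch.Tensor:
    """Approximate Gaussian blur of a float (N, C, H, W) tensor with repeated box blurs."""
    # A box of width k has variance (k^2 - 1) / 12; pick k so `passes` boxes give ~sigma^2
    width = math.sqrt(12.0 * sigma**2 / passes + 1.0)
    radius = max(0, round((width - 1.0) / 2.0))
    for _ in range(passes):
        img = box_blur_last_dim(img, radius)  # along W (rows)
        img = box_blur_last_dim(img.transpose(-1, -2), radius).transpose(-1, -2)  # along H (columns)
    return img
```

Prefix sums over very long rows can lose precision in float32 because of the large running totals, so float64 (as in the original report) is the safer choice for big images.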

When Kx >> log(N), we could also use an FFT-based convolution (cuFFT on the GPU), whose cost does not grow with the kernel size.
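A sketch of the FFT route with `torch.fft`, which uses cuFFT for CUDA tensors. It assumes a floating-point `(N, C, H, W)` input and odd kernel sizes; the helper names are hypothetical. The image is reflect-padded first, so the circular convolution computed by the FFT only wraps around inside the border that is cropped away afterwards.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Sampled, normalized 1D Gaussian
    half = (kernel_size - 1) * 0.5
    x = torch.linspace(-half, half, steps=kernel_size, dtype=dtype, device=device)
    pdf = torch.exp(-0.5 * (x / sigma) ** 2)
    return pdf / pdf.sum()


def fft_gaussian_blur(img: torch.Tensor, kernel_size: list, sigma: list) -> torch.Tensor:
    """Gaussian blur of a float (N, C, H, W) tensor via FFT; kernel_size and sigma are [x, y]."""
    kx, ky = kernel_size
    px, py = kx // 2, ky // 2
    height, width = img.shape[-2:]
    # Reflect-pad so circular wrap-around only affects the border we crop off below
    padded = F.pad(img, [px, px, py, py], mode="reflect")
    ph, pw = padded.shape[-2:]

    gx = gaussian_kernel1d(kx, sigma[0], img.dtype, img.device)
    gy = gaussian_kernel1d(ky, sigma[1], img.dtype, img.device)
    # Embed the (ky x kx) kernel in a (ph x pw) array with its center moved to index (0, 0)
    kernel = torch.zeros(ph, pw, dtype=img.dtype, device=img.device)
    kernel[:ky, :kx] = torch.outer(gy, gx)
    kernel = torch.roll(kernel, shifts=(-py, -px), dims=(0, 1))

    # Pointwise product of spectra == circular convolution in the spatial domain
    spectrum = torch.fft.rfft2(padded) * torch.fft.rfft2(kernel)
    blurred = torch.fft.irfft2(spectrum, s=(ph, pw))
    return blurred[..., py:py + height, px:px + width]
```

The cost is O(HW log(HW)) per channel regardless of kernel size, so this should only pay off for large kernels; the Gaussian is symmetric, so convolution and conv2d's cross-correlation give the same result.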

blutjens commented 8 months ago

+1. torchvision.transforms.GaussianBlur seems significantly slower to me than the OpenCV equivalent.

montmejat commented 5 months ago

Are there any more efficient alternatives that work with CUDA tensors? Thanks!