eldarkurtic opened 2 years ago
cc @zasdfgbnm, 1d sort should just call cub directly?
I am not surprised by the memory usage. 1d sort calls `segmented_sort_pairs_by_full_sort`:
https://github.com/pytorch/pytorch/blob/595a51b951f6dbe58c73390d910c6644d6074c82/aten/src/ATen/native/cuda/Sort.cu#L253
Whose theoretical memory usage could be:
- output values == 400,000,000 * 4 bytes = 1.6 GB
- output indices == 400,000,000 * 8 bytes = 3.2 GB
- indices_and_segment == 400,000,000 * 8 bytes = 3.2 GB
- indices_and_segment2 == 400,000,000 * 8 bytes = 3.2 GB
- cub temporary space >= 400,000,000 * 8 bytes = 3.2 GB
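Summing the buffers above gives a quick sanity check on the observed peak (decimal GB, matching the figures; the buffer names are taken directly from the breakdown):

```python
n = 400_000_000
GB = 1e9  # decimal GB, matching the figures above

buffers = {
    "output values (float32)":  n * 4,
    "output indices (int64)":   n * 8,
    "indices_and_segment":      n * 8,
    "indices_and_segment2":     n * 8,
    "cub temporary space (>=)": n * 8,  # lower bound, per the breakdown
}
total = sum(buffers.values())
print(total / GB)  # 14.4 -- plus the ~1.6 GB input tensor, close to the ~18 GB observed
```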
@zasdfgbnm given that I only need k-th value in this huge tensor, do you know if there is any other way to get it?
What I've tried so far:
- `topk` and `sort` are both super fast, but have the issue with huge memory allocations
- `kthvalue` has almost no additional memory overhead, but is super slow (20x slower than `sort`, documented here: https://github.com/pytorch/pytorch/issues/75599)

I would recommend you to try whether `topk` will work for you for now. It was just improved this year by @yueyericardo for large slice sizes, and there is another PR https://github.com/pytorch/pytorch/pull/74267 to further improve `topk`.
I think @yueyericardo's kernel could apply to `kthvalue` as well, so it is possible that `kthvalue` could potentially be as fast as `topk`. I can take a look at implementing this when I get some spare cycles.
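For reference, getting the k-th *largest* value with `kthvalue` (near-zero extra memory, but slow on CUDA per the issue linked above) looks like this; the tensor size here is a made-up small example:

```python
import torch

t = torch.rand(1_000_000)  # hypothetical size, stand-in for the 400M-element tensor
k = 200

# kthvalue returns the k-th *smallest* value, so index from the other end
# to get the k-th largest.
kth_largest, idx = t.kthvalue(t.numel() - k + 1)
```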
Oh, sorry, I misread your message. You were saying that `topk` had the same huge memory allocation? I think this is because `topk` was using `sort` for that case, and https://github.com/pytorch/pytorch/pull/74267 should resolve this issue because the sort path is removed in that PR. (You need CUDA 11.6 or newer to be able to use this PR.)
For now, if your `k` is small and if you don't need the index of the `k`-th value, you could do a workaround like:
```python
import torch

# Split the 400M-element tensor into 400 rows, take the per-row top-k,
# then take the top-k of the surviving candidates; every global top-k
# element is guaranteed to be among its row's top-k.
tensor = torch.rand(400_000_000, device='cuda:0')
tensor = tensor.reshape(400, 1_000_000)
tensor.topk(k=200, dim=-1).values.reshape(-1).topk(k=200, dim=0).values
```
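A quick CPU sanity check (with small, made-up sizes) that this chunked two-stage `topk` matches a direct top-k, since every global top-k element necessarily survives its row's top-k:

```python
import torch

torch.manual_seed(0)
t = torch.rand(1_000)  # stand-in for the 400M-element tensor
k = 10

# two-stage: per-row top-k, then top-k over the surviving candidates
chunked = t.reshape(10, 100).topk(k=k, dim=-1).values.reshape(-1).topk(k=k, dim=0).values
direct = t.topk(k=k).values  # reference answer

assert torch.equal(chunked, direct)
```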
@zasdfgbnm thanks a lot for the useful suggestions! I will definitely try both: the new commits, and the workaround that avoids calling `sort` within `topk` on the older version.
🐛 Describe the bug
As mentioned in the title, when sorting huge tensors on a GPU device, the peak GPU memory usage is extremely high. For example:
The first summary (i.e. after allocating this huge tensor):
and then the second summary after sorting:
Notice the ~18 GB peak memory usage in the second summary for sorting a tensor that occupies only ~1.5 GB.
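A minimal repro sketch of the above (the full 400M-element run needs a GPU with roughly 20 GB free, so this falls back to a small CPU tensor just to stay runnable; the shrunken size is an assumption for illustration):

```python
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# full-size repro needs ~18 GB of GPU memory; shrink on CPU just to run
n = 400_000_000 if device != "cpu" else 1_000_000

if device != "cpu":
    torch.cuda.reset_peak_memory_stats(device)

t = torch.rand(n, device=device)  # ~1.5 GB of float32 at full size
values, indices = t.sort()        # hits segmented_sort_pairs_by_full_sort on CUDA

if device != "cpu":
    # compare with torch.cuda.memory_summary(device) for the full breakdown
    print(f"peak allocated: {torch.cuda.max_memory_allocated(device) / 2**30:.1f} GiB")
```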
Versions
Collecting environment information...
PyTorch version: 1.11.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: version 3.13.4
Libc version: glibc-2.28

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.19.0-19-amd64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 3090
Nvidia driver version: 460.73.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.22.3 pypi_0 pypi
[conda] numpy-base 1.21.2 py38h79a1101_0
[conda] pytorch 1.11.0 py3.8_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchvision 0.12.0 py38_cu113 pytorch
cc @ngimel