rapidsai / cuml

cuML - RAPIDS Machine Learning Library
https://docs.rapids.ai/api/cuml/stable/
Apache License 2.0

[BUG] Multi GPU KMeans memory usage is 2x larger than expected. #5936

Closed tfeher closed 19 hours ago

tfeher commented 3 weeks ago

Describe the bug

The expected GPU memory usage of cuML's kmeans algorithm is (n_rows * n_cols + n_clusters * n_cols) * sizeof(MathT), where (n_rows, n_cols) is the shape of the input matrix. In practice, n_clusters is much smaller than n_rows, so the memory usage is expected to be only slightly larger than the input data size.
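To make the formula concrete, here is a small sketch that evaluates it for the single-GPU example below (4,000,000 x 250 float32 rows, 10,000 clusters). The `expected_kmeans_bytes` helper and its `n_workers` parameter are hypothetical. For the multi-GPU case it assumes that rows are split evenly and that every worker keeps a full copy of the centroids. It ignores smaller buffers such as labels and temporary workspace, which presumably explains why the measured peak (4.14 GB) is a little above the estimate.

```python
def expected_kmeans_bytes(n_rows, n_cols, n_clusters, itemsize=4, n_workers=1):
    """Estimate from the formula above:
    (local_rows * n_cols + n_clusters * n_cols) * sizeof(MathT).

    Assumption: with n_workers > 1, rows are split evenly, so each worker
    holds only its local chunk plus a full copy of the centroids.
    """
    local_rows = -(-n_rows // n_workers)  # ceiling division
    return (local_rows * n_cols + n_clusters * n_cols) * itemsize


single = expected_kmeans_bytes(4_000_000, 250, n_clusters=10_000)
print(f"{single / 1e9:.3f} GB")  # 4.010 GB
```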

Here is code that demonstrates the memory usage of single-GPU kmeans:

import rmm
upstream = rmm.mr.get_current_device_resource()
mr = rmm.mr.StatisticsResourceAdaptor(upstream)
rmm.mr.set_current_device_resource(mr)

from rmm.allocators.cupy import rmm_cupy_allocator
import cupy as cp
cp.cuda.set_allocator(rmm_cupy_allocator)
from cuml.cluster import KMeans
import numpy as np

dataset = cp.random.uniform(size=(4000000, 250), dtype=cp.float32)
print("Input dataset {}x{} {:6.1f} GB".format(dataset.shape[0], dataset.shape[1], dataset.size * dataset.dtype.itemsize / 1e9))
print("Peak memory usage after allocating input data:",mr.allocation_counts["peak_bytes"]/(1e9), "GB")

kmeans = KMeans(n_clusters=10000, max_iter=4, init='random')
kmeans.fit(dataset)
print("Peak memory after KMeans fit:",mr.allocation_counts["peak_bytes"]/(1e9), "GB")

Expected output

Input dataset 4000000x250    4.0 GB
Peak memory usage after allocating input data: 4.0 GB
Peak memory after KMeans fit: 4.142093072 GB

In contrast, when using cuml.dask.cluster.KMeans, the memory usage is twice as large.

Steps/Code to reproduce bug

# dask_experiment.py

import dask
import dask.array as da
from dask import config as cfg
cfg.set({'distributed.scheduler.worker-ttl': None})

import cupy as cp
import numpy as np
from dask_cuda import LocalCUDACluster
from dask.distributed import Client, wait
from raft_dask.common import Comms
from cuml.dask.cluster import KMeans

if __name__ == "__main__":
    n_gpus = 1
    cluster = LocalCUDACluster(n_workers=n_gpus, memory_limit=0) 
    client = Client(cluster)
    comms = Comms(client=client)
    comms.init()

    n_rows = 4000000*n_gpus
    n_cols = 250

    # Use integer division: Dask chunk sizes must be integers.
    dataset = da.random.random((n_rows, n_cols), chunks=(n_rows // n_gpus, n_cols)).astype(np.float32)

    def to_gpu(x):
        return cp.asarray(x)

    dataset_gpu = da.map_blocks(to_gpu, dataset, dtype=cp.float32)
    dataset_gpu = dataset_gpu.persist()
    wait(dataset_gpu)

    print("Input dataset {}x{} {:6.1f} GB".format(dataset_gpu.shape[0], dataset_gpu.shape[1], dataset_gpu.size * dataset_gpu.dtype.itemsize / 1e9))

    print("starting clustering")
    kmeans = KMeans(n_clusters=1000, max_iter=4, init='random')
    kmeans.fit(dataset_gpu)

    comms.destroy()
    client.close()
    cluster.close()

We will use nvidia-smi to monitor memory usage

nvidia-smi -i 0 --query-gpu=index,timestamp,memory.used,memory.total --format=csv -l 1 & bkg_pid=$!; python dask_experiment.py; kill $bkg_pid

Output

[2] 450782
index, timestamp, memory.used [MiB], memory.total [MiB]
0, 2024/06/14 15:20:23.433, 0 MiB, 16384 MiB
...
0, 2024/06/14 15:20:41.438, 985 MiB, 16384 MiB
0, 2024/06/14 15:20:42.439, 4801 MiB, 16384 MiB
0, 2024/06/14 15:20:43.439, 4801 MiB, 16384 MiB
Input dataset 4000000x250    4.0 GB
starting clustering
0, 2024/06/14 15:20:44.439, 8883 MiB, 16384 MiB
1000000000
0, 2024/06/14 15:20:45.439, 8981 MiB, 16384 MiB

We can see from the output that after allocating the 4.0 GB input array, memory usage was 4801 MiB, and it rose to 8883 MiB once clustering started. This is not expected.
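The same comparison can be automated. The sketch below uses hypothetical helpers (`parse_used_mib` and `extra_copies`, not part of any RAPIDS tool) to extract the `memory.used` column from the interleaved nvidia-smi and script output, then express the peak above the pre-allocation baseline in units of the dataset size. Note that nvidia-smi reports MiB (2^20 bytes) while the script prints GB (10^9 bytes).

```python
import re

# Matches rows like "0, 2024/06/14 15:20:44.439, 8883 MiB, 16384 MiB";
# the trailing memory.total column and any "# comment" are ignored.
ROW = re.compile(r"^\s*(\d+),\s*([\d/]+ [\d:.]+),\s*(\d+) MiB")


def parse_used_mib(log_text):
    """Return [(timestamp, used_mib), ...], skipping headers and lines
    printed by other programs (e.g. the Python script)."""
    samples = []
    for line in log_text.splitlines():
        m = ROW.match(line)
        if m:
            samples.append((m.group(2), int(m.group(3))))
    return samples


def extra_copies(samples, baseline_mib, dataset_bytes):
    """Peak usage above the baseline, measured in dataset-sized buffers."""
    peak_mib = max(used for _, used in samples)
    dataset_mib = dataset_bytes / 2**20
    return (peak_mib - baseline_mib) / dataset_mib
```

Applied to the log above with a baseline of 985 MiB and the 4.0 GB dataset, the ratio comes out at roughly 2.1, that is, about two dataset-sized buffers.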

Expected behavior: The multi-GPU kmeans implementation should use a similar amount of memory as the single-GPU one. It should need only slightly more space than the local chunk of input data.

Environment details: Checked with RAPIDS 24.04 and 24.06 on various GPUs (V100, A100, A30).

cjnolet commented 3 weeks ago

@tfeher I believe we already have an issue open for this (and have had for some time).

The problem is that kmeans 1) expects data in row-major layout, and 2) can't handle multiple partitions on each worker, so it concatenates the partitions on each worker instead, thus duplicating the data footprint.

We can fix this, but it will require each worker to support multiple partitions as input and to process them in place.
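A minimal illustration of the two copies described above follows. For portability it uses NumPy on the CPU as a stand-in for CuPy arrays on the GPU (an assumption; `cupy.concatenate` and `cupy.ascontiguousarray` likewise return newly allocated arrays). It does not reproduce cuML's internal input handling, only the memory effect.

```python
import numpy as np

rng = np.random.default_rng(0)
# Four row partitions that happen to live on the same worker.
parts = [rng.random((1000, 25), dtype=np.float32) for _ in range(4)]

# 1) Concatenating the partitions allocates a new buffer as large as all of
#    them together, so while both exist the footprint is doubled.
X = np.concatenate(parts, axis=0)
input_bytes = sum(p.nbytes for p in parts)
peak_bytes = input_bytes + X.nbytes

# 2) A column-major (Fortran-order) partition must be copied to become
#    row-major, which is another full-size allocation.
F = np.asfortranarray(parts[0])
C = np.ascontiguousarray(F)

print(peak_bytes / input_bytes)        # 2.0
print(np.shares_memory(X, parts[0]))   # False: X is a separate buffer
print(np.shares_memory(C, F))          # False: the layout change copied data
```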

cjnolet commented 3 weeks ago

Here is the issue I opened last year: https://github.com/rapidsai/cuml/issues/5212

(I’m looking for one that might have been opened earlier too. Maybe it was closed at some point).

tfeher commented 3 weeks ago

I don't think it is exactly the same issue. In the example above, I carefully created the input data so that there is a single partition per worker. No copy is needed.

Additionally, in a separate test, I tried skipping all the dask input utils and directly calling the KMeansMG wrapper (here is the script). This still results in double the memory usage.

cjnolet commented 3 weeks ago

I was initially going to say that we should verify whether your random input array is row major or column major (because, as you know, the input utils will copy the data to transpose it), but if you also encountered this problem directly in the C++ layer, then that is indeed a problem.

Wow, that makes me wonder if there have actually been multiple copies this whole time. That would explain a lot.

tfeher commented 3 weeks ago

Yes, it is row major, as expected by the C++ layer:

print(dataset_gpu.blocks.ravel()[0].compute().flags)

  C_CONTIGUOUS : True
  F_CONTIGUOUS : False
  OWNDATA : True
tfeher commented 3 weeks ago

The memory duplication seems to happen just around the time when clustering finishes. Here is a time trace of the memory usage:

nvidia-smi -i 0 --query-gpu=index,timestamp,memory.used --format=csv -l 1

index, timestamp, memory.used [MiB]
0, 2024/06/14 16:29:42.827, 679 MiB   # initial memory overhead
0, 2024/06/14 16:29:43.827, 4495 MiB  # input arrays allocated
0, 2024/06/14 16:29:44.828, 4495 MiB
0, 2024/06/14 16:29:45.828, 4519 MiB  # clustering
0, 2024/06/14 16:29:46.828, 4601 MiB
0, 2024/06/14 16:29:47.828, 4601 MiB
0, 2024/06/14 16:29:48.828, 4601 MiB
0, 2024/06/14 16:29:49.829, 4601 MiB
0, 2024/06/14 16:29:50.829, 4601 MiB
0, 2024/06/14 16:29:51.829, 4601 MiB
0, 2024/06/14 16:29:52.829, 4601 MiB
0, 2024/06/14 16:29:53.830, 4601 MiB
0, 2024/06/14 16:29:54.830, 8447 MiB   # clustering finishes
0, 2024/06/14 16:29:55.832, 703 MiB
cjnolet commented 3 weeks ago

Ooh, interesting. That makes me wonder if the output is somehow being duplicated on its way back to the user (in Dask, or when cuML is providing the properly formatted output). I admit I’m not sure how the cuml output types work in Dask at this point because I understand there was some refactoring of those internal APIs recently.

dantegd commented 3 weeks ago

On a quick audit of the code, I can't find where a copy is being made in this case yet. It shouldn't have an impact of this magnitude, but we also have a PR (#5931) that fixes assigning all the labels_ of the data correctly, which should increase memory usage a bit when a trained model is kept around, and even that hasn't been merged!

Interestingly, I see a different behavior on a T4:

0, 2024/06/15 00:00:14.221, 632 MiB
0, 2024/06/15 00:00:15.222, 632 MiB
0, 2024/06/15 00:00:16.222, 4448 MiB 
0, 2024/06/15 00:00:17.222, 4448 MiB # training begins
0, 2024/06/15 00:00:18.222, 4456 MiB 
0, 2024/06/15 00:00:19.222, 8630 MiB
0, 2024/06/15 00:00:20.223, 8630 MiB
0, 2024/06/15 00:00:21.223, 8630 MiB
0, 2024/06/15 00:00:22.223, 8660 MiB
0, 2024/06/15 00:00:23.223, 8226 MiB # clustering finishes
0, 2024/06/15 00:00:24.224, 382 MiB
dantegd commented 3 weeks ago

And when changing the initialization (scalable-k-means++), the duplication seems to happen pretty much at the beginning for me, on a T4 with 24.08:

0, 2024/06/15 00:15:38.020, 632 MiB, 15360 MiB
0, 2024/06/15 00:15:39.020, 632 MiB, 15360 MiB
0, 2024/06/15 00:15:40.021, 4448 MiB, 15360 MiB
0, 2024/06/15 00:15:41.021, 4448 MiB, 15360 MiB
0, 2024/06/15 00:15:42.021, 4456 MiB, 15360 MiB
Input dataset 4000000x250    4.0 GB
starting clustering
1000000000
0, 2024/06/15 00:15:43.021, 8546 MiB, 15360 MiB
0, 2024/06/15 00:15:44.021, 8634 MiB, 15360 MiB
0, 2024/06/15 00:15:45.022, 8636 MiB, 15360 MiB
0, 2024/06/15 00:15:46.022, 8636 MiB, 15360 MiB
0, 2024/06/15 00:15:47.022, 8636 MiB, 15360 MiB
0, 2024/06/15 00:15:48.022, 8638 MiB, 15360 MiB
0, 2024/06/15 00:15:49.022, 8638 MiB, 15360 MiB
0, 2024/06/15 00:15:50.023, 8638 MiB, 15360 MiB
0, 2024/06/15 00:15:51.023, 8638 MiB, 15360 MiB
0, 2024/06/15 00:15:52.023, 8640 MiB, 15360 MiB
0, 2024/06/15 00:15:53.023, 8640 MiB, 15360 MiB
0, 2024/06/15 00:15:54.023, 8640 MiB, 15360 MiB
0, 2024/06/15 00:15:55.024, 8640 MiB, 15360 MiB
0, 2024/06/15 00:15:56.024, 8640 MiB, 15360 MiB
0, 2024/06/15 00:15:57.024, 8642 MiB, 15360 MiB
0, 2024/06/15 00:15:58.024, 8642 MiB, 15360 MiB
0, 2024/06/15 00:15:59.025, 8642 MiB, 15360 MiB
0, 2024/06/15 00:16:00.025, 8642 MiB, 15360 MiB
0, 2024/06/15 00:16:01.025, 8642 MiB, 15360 MiB
0, 2024/06/15 00:16:02.025, 8642 MiB, 15360 MiB
0, 2024/06/15 00:16:03.025, 8642 MiB, 15360 MiB
0, 2024/06/15 00:16:04.026, 8644 MiB, 15360 MiB
0, 2024/06/15 00:16:05.026, 8644 MiB, 15360 MiB
0, 2024/06/15 00:16:06.026, 8644 MiB, 15360 MiB
0, 2024/06/15 00:16:07.026, 8644 MiB, 15360 MiB
0, 2024/06/15 00:16:08.027, 8644 MiB, 15360 MiB
0, 2024/06/15 00:16:09.027, 8644 MiB, 15360 MiB
0, 2024/06/15 00:16:10.027, 8644 MiB, 15360 MiB
0, 2024/06/15 00:16:11.027, 8644 MiB, 15360 MiB
0, 2024/06/15 00:16:12.028, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:13.028, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:14.028, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:15.028, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:16.029, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:17.029, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:18.029, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:19.029, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:20.029, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:21.030, 8648 MiB, 15360 MiB
0, 2024/06/15 00:16:22.030, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:23.030, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:24.030, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:25.031, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:26.031, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:27.031, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:28.031, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:29.032, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:30.032, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:31.032, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:32.032, 8682 MiB, 15360 MiB
0, 2024/06/15 00:16:33.033, 8650 MiB, 15360 MiB
0, 2024/06/15 00:16:34.033, 8630 MiB, 15360 MiB
0, 2024/06/15 00:16:35.033, 8630 MiB, 15360 MiB
0, 2024/06/15 00:16:36.033, 8630 MiB, 15360 MiB
0, 2024/06/15 00:16:37.033, 8568 MiB, 15360 MiB
0, 2024/06/15 00:16:38.034, 400 MiB, 15360 MiB
dantegd commented 3 weeks ago

Managed to track the issue to a part of our Python input utilities. I still don't know why it behaves this way specifically in Dask scenarios, but I should be able to triage it and create a PR soon.

dantegd commented 3 weeks ago

Triaged it to an issue in concatenation and submitted a PR that works on the fix.
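For the single-partition-per-worker setup in this report, one way such a concatenation copy could be avoided is to return the lone partition as-is. The sketch below is a hypothetical helper for illustration (`concat_partitions` is not a cuML function, and the exact change in the PR is not reproduced here). It uses NumPy as a stand-in for CuPy. With more than one partition a copy is still needed, which is the limitation cjnolet describes.

```python
import numpy as np


def concat_partitions(parts):
    """Hypothetical helper (not cuML's code): stack a worker's row partitions.

    With a single partition the input is returned unchanged, so no second
    full-size buffer is allocated; with several, a copy is unavoidable.
    """
    if len(parts) == 1:
        return parts[0]
    return np.concatenate(parts, axis=0)


one = np.zeros((4, 3), dtype=np.float32)
two = [np.zeros((4, 3), dtype=np.float32), np.ones((2, 3), dtype=np.float32)]
print(concat_partitions([one]) is one)   # True: no copy
print(concat_partitions(two).shape)      # (6, 3): a new buffer
```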

tfeher commented 3 weeks ago

Thanks @dantegd for the fix! After this, I believe there is still a duplication happening, as I described above. We suspect this is related to how we pass the input data to predict after fit finishes. @achirkin will test a fix for that and let us know.

dantegd commented 3 weeks ago

@tfeher @achirkin I thought that could've been the case as well, but I was still seeing the duplication even when removing that call entirely (and not calculating the labels).

After the fix in #5937, I can no longer see any copies in the example on two machines:

0, 2024/06/17 11:51:00.763, 634 MiB
0, 2024/06/17 11:51:01.264, 634 MiB
0, 2024/06/17 11:51:01.764, 4450 MiB
0, 2024/06/17 11:51:02.264, 4450 MiB
0, 2024/06/17 11:51:02.764, 4450 MiB
0, 2024/06/17 11:51:03.264, 4450 MiB
0, 2024/06/17 11:51:03.765, 4458 MiB
0, 2024/06/17 11:51:04.265, 4458 MiB
0, 2024/06/17 11:51:04.771, 4716 MiB
0, 2024/06/17 11:51:05.271, 4716 MiB
0, 2024/06/17 11:51:05.771, 4716 MiB
0, 2024/06/17 11:51:06.271, 4716 MiB
0, 2024/06/17 11:51:06.771, 4396 MiB
0, 2024/06/17 11:51:07.272, 384 MiB
0, 2024/06/17 11:51:07.772, 382 MiB
0, 2024/06/17 11:51:08.272, 374 MiB

Do you still see the doubled memory usage when testing the fix from that PR?

achirkin commented 3 weeks ago

I've added a bunch of print statements with memGetInfo in the reproducer script and in a few cuml files (kmeans_mg.pyx, kmeans.pyx, distributed/protocol/rmm.py), here's the output after #5937 is applied (I also disabled the cupy pool allocator using cp.cuda.set_allocator(None)):

Before da.random.random': free = 21105.410048 MB
Before 'to_gpu': free = 21105.410048 MB
to_gpu enter: free = 21105.410048 MB
to_gpu exit: free = 21105.410048 MB
After 'to_gpu': free = 21105.410048 MB
After 'persist': free = 21105.410048 MB
to_gpu enter: free = 21096.890368 MB
to_gpu exit: free = 17095.327744 MB
After 'wait': free = 17095.327744 MB
Input dataset 4000000x250    4.0 GB
starting clustering
After `kmeans = KMeans`: free = 17095.327744 MB
Constructed X_m: free = 17028.21888 MB
Constructed sample_weight_m: free = 17011.441664 MB
Constructed cluster_centers_: free = 17011.441664 MB
1000000000
Going to do the final step: predict labels: free = 17011.441664 MB
predict enter: free = 17011.441664 MB   # <-- this suggests the call to predict(X, ...) doesn't do extra allocations
predict exit: free = 16994.664448 MB
All done; cleaning up: free = 16994.664448 MB
After `kmeans.fit`: free = 17078.550528 MB

From this output it looks like there's no unnecessary allocation on cuml side anymore.

However, the script still fails with OOM if the allocation size is more than half of the available memory. Here's the full output (12 GB allocation):

Before da.random.random': free = 21121.466368 MB
Before 'to_gpu': free = 21121.466368 MB
to_gpu enter: free = 21121.466368 MB
to_gpu exit: free = 21121.466368 MB
After 'to_gpu': free = 21121.466368 MB
After 'persist': free = 21121.466368 MB
to_gpu enter: free = 21049.704448 MB
to_gpu exit: free = 9048.424448 MB
dask_deserialize_rmm_device_buffer: attempt to create buffer, size=12000.0 MB / avail = 8985.64096 MB
After 'wait': free = 8985.64096 MB
Input dataset 12000000x250   12.0 GB
starting clustering
After `kmeans = KMeans`: free = 8985.64096 MB
2024-06-17 17:41:28,517 - distributed.worker - ERROR - std::bad_alloc: out_of_memory: CUDA error at: $CONDA_PREFIX/include/rmm/mr/device/cuda_memory_resource.hpp:60: cudaErrorMemoryAllocation out of memory
Traceback (most recent call last):
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/buffer.py", line 184, in __getitem__
    return self.fast[key]
           ~~~~~~~~~^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/common.py", line 127, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/lru.py", line 117, in __getitem__
    result = self.d[key]
             ~~~~~~^^^^^
KeyError: ('random_sample-to_gpu-c77bbaa15027e6f58fe7f9248373b3e6', 0, 0)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker.py", line 219, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker.py", line 1940, in handle_stimulus
    super().handle_stimulus(*stims)
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 3719, in handle_stimulus
    instructions = self.state.handle_stimulus(*stims)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 1345, in handle_stimulus
    instructions += self._transitions(recs, stimulus_id=stim.stimulus_id)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 2693, in _transitions
    process_recs(recommendations.copy())
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 2687, in process_recs
    a_recs, a_instructions = self._transition(
                             ^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 2599, in _transition
    recs, instructions = func(self, ts, *args, stimulus_id=stimulus_id)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 1943, in _transition_memory_released
    recs, instructions = self._transition_generic_released(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 1873, in _transition_generic_released
    self._purge_state(ts)
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 1476, in _purge_state
    if ts.key in self.data:
       ^^^^^^^^^^^^^^^^^^^
  File "<frozen _collections_abc>", line 780, in __contains__
  File "$CONDA_PREFIX/lib/python3.11/site-packages/dask_cuda/device_host_file.py", line 270, in __getitem__
    return self.device_buffer[key]
           ~~~~~~~~~~~~~~~~~~^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/common.py", line 127, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/buffer.py", line 186, in __getitem__
    return self.slow_to_fast(key)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/buffer.py", line 153, in slow_to_fast
    value = self.slow[key]
            ~~~~~~~~~^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/func.py", line 55, in __getitem__
    return self.load(self.d[key])
           ^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/nvtx/nvtx.py", line 116, in inner
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/dask_cuda/device_host_file.py", line 156, in host_to_device
    return deserialize(s.header, s.frames)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/protocol/serialize.py", line 449, in deserialize
    return loads(header, frames)
           ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/protocol/serialize.py", line 66, in dask_loads
    return loads(header["sub-header"], frames)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/protocol/cupy.py", line 65, in dask_deserialize_cupy_ndarray
    frames = [dask_deserialize_cuda_buffer(header, frames)]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/protocol/rmm.py", line 48, in dask_deserialize_rmm_device_buffer
    buf = rmm.DeviceBuffer(ptr=ptr, size=size)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "device_buffer.pyx", line 98, in rmm._lib.device_buffer.DeviceBuffer.__cinit__
MemoryError: std::bad_alloc: out_of_memory: CUDA error at: $CONDA_PREFIX/include/rmm/mr/device/cuda_memory_resource.hpp:60: cudaErrorMemoryAllocation out of memory
2024-06-17 17:41:28,518 - distributed.scheduler - WARNING - Removing worker 'tcp://127.0.0.1:35693' caused the cluster to lose already computed task(s), which will be recomputed elsewhere: {('to_gpu-c77bbaa15027e6f58fe7f9248373b3e6', 0, 0)} (stimulus_id='handle-worker-cleanup-1718638888.5184376')
2024-06-17 17:41:28,519 - distributed.scheduler - ERROR - broadcast to tcp://127.0.0.1:35693 failed: CommClosedError: Address removed.
Traceback (most recent call last):
  File "$WORKSPACE/cuml/dask_experiment.py", line 53, in <module>
    kmeans.fit(dataset_gpu)
  File "$CONDA_PREFIX/lib/python3.11/site-packages/cuml/internals/memory_utils.py", line 87, in cupy_rmm_wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/cuml/dask/cluster/kmeans.py", line 159, in fit
    comms.init(workers=data.workers)
  File "$CONDA_PREFIX/lib/python3.11/site-packages/raft_dask/common/comms.py", line 195, in init
    worker_info = self.worker_info(self.worker_addresses)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/raft_dask/common/comms.py", line 160, in worker_info
    ranks = _func_worker_ranks(self.client, workers)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/raft_dask/common/comms.py", line 719, in _func_worker_ranks
    nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/client.py", line 2991, in run
2024-06-17 17:41:28,520 - tornado.application - ERROR - Exception in callback functools.partial(<bound method IOLoop._discard_future_result of <tornado.platform.asyncio.AsyncIOMainLoop object at 0x71205ef051d0>>, <Task finished name='Task-7' coro=<Worker.handle_scheduler() done, defined at $CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker.py:203> exception=MemoryError('std::bad_alloc: out_of_memory: CUDA error at: $CONDA_PREFIX/include/rmm/mr/device/cuda_memory_resource.hpp:60: cudaErrorMemoryAllocation out of memory')>)
Traceback (most recent call last):
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/buffer.py", line 184, in __getitem__
    return self.fast[key]
           ~~~~~~~~~^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/common.py", line 127, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/lru.py", line 117, in __getitem__
    result = self.d[key]
             ~~~~~~^^^^^
KeyError: ('random_sample-to_gpu-c77bbaa15027e6f58fe7f9248373b3e6', 0, 0)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "$CONDA_PREFIX/lib/python3.11/site-packages/tornado/ioloop.py", line 750, in _run_callback
    ret = callback()
          ^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/tornado/ioloop.py", line 774, in _discard_future_result
    future.result()
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker.py", line 206, in wrapper
    return await method(self, *args, **kwargs)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker.py", line 1302, in handle_scheduler
    await self.handle_stream(comm)
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/core.py", line 1053, in handle_stream
    handler(**merge(extra, msg))
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker.py", line 1925, in _
    self.handle_stimulus(event)
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker.py", line 219, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker.py", line 1940, in handle_stimulus
    super().handle_stimulus(*stims)
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 3719, in handle_stimulus
    instructions = self.state.handle_stimulus(*stims)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 1345, in handle_stimulus
    instructions += self._transitions(recs, stimulus_id=stim.stimulus_id)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 2693, in _transitions
    process_recs(recommendations.copy())
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 2687, in process_recs
    a_recs, a_instructions = self._transition(
                             ^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 2599, in _transition
    recs, instructions = func(self, ts, *args, stimulus_id=stimulus_id)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 1943, in _transition_memory_released
    recs, instructions = self._transition_generic_released(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 1873, in _transition_generic_released
    self._purge_state(ts)
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/worker_state_machine.py", line 1476, in _purge_state
    if ts.key in self.data:
       ^^^^^^^^^^^^^^^^^^^
  File "<frozen _collections_abc>", line 780, in __contains__
  File "$CONDA_PREFIX/lib/python3.11/site-packages/dask_cuda/device_host_file.py", line 270, in __getitem__
    return self.device_buffer[key]
           ~~~~~~~~~~~~~~~~~~^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/common.py", line 127, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/buffer.py", line 186, in __getitem__
    return self.slow_to_fast(key)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/buffer.py", line 153, in slow_to_fast
    value = self.slow[key]
            ~~~~~~~~~^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/zict/func.py", line 55, in __getitem__
    return self.load(self.d[key])
           ^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/nvtx/nvtx.py", line 116, in inner
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/dask_cuda/device_host_file.py", line 156, in host_to_device
    return deserialize(s.header, s.frames)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/protocol/serialize.py", line 449, in deserialize
    return loads(header, frames)
           ^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/protocol/serialize.py", line 66, in dask_loads
    return loads(header["sub-header"], frames)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/protocol/cupy.py", line 65, in dask_deserialize_cupy_ndarray
    frames = [dask_deserialize_cuda_buffer(header, frames)]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/protocol/rmm.py", line 48, in dask_deserialize_rmm_device_buffer
    buf = rmm.DeviceBuffer(ptr=ptr, size=size)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "device_buffer.pyx", line 98, in rmm._lib.device_buffer.DeviceBuffer.__cinit__
MemoryError: std::bad_alloc: out_of_memory: CUDA error at: $CONDA_PREFIX/include/rmm/mr/device/cuda_memory_resource.hpp:60: cudaErrorMemoryAllocation out of memory
    return self.sync(
           ^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/utils.py", line 358, in sync
    return sync(
           ^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/utils.py", line 434, in sync
    raise error
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/utils.py", line 408, in f
    result = yield future
             ^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/tornado/gen.py", line 766, in run
    value = future.result()
            ^^^^^^^^^^^^^^^
  File "$CONDA_PREFIX/lib/python3.11/site-packages/distributed/client.py", line 2896, in _run
    raise exc
distributed.comm.core.CommClosedError: Address removed.
2024-06-17 17:41:32,841 - distributed.nanny - WARNING - Worker process still alive after 4.0 seconds, killing
^C
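The "free = ... MB" checkpoints in the logs above come from memGetInfo calls. For anyone repeating this kind of instrumentation, here is a minimal sketch of such a checkpoint helper. It uses CuPy's `cupy.cuda.runtime.memGetInfo()`, which returns `(free, total)` in bytes. `report_free` is a hypothetical name, and the injectable `mem_get_info` argument exists only so the helper can be exercised without a GPU. As in the experiment above, it is worth calling `cp.cuda.set_allocator(None)` first, because otherwise memory cached by CuPy's pool still counts as used after arrays are freed.

```python
def report_free(label, mem_get_info=None):
    """Print and return free device memory, formatted like the log above.

    By default this queries the device through CuPy; any callable that
    returns (free_bytes, total_bytes) can be passed instead.
    """
    if mem_get_info is None:
        import cupy as cp  # imported lazily so the helper loads without a GPU
        mem_get_info = cp.cuda.runtime.memGetInfo
    free, _total = mem_get_info()
    print(f"{label}: free = {free / 1e6} MB")
    return free
```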
dantegd commented 3 weeks ago

Haven't had time to try it, but given this line

KeyError: ('random_sample-to_gpu-c77bbaa15027e6f58fe7f9248373b3e6', 0, 0)

I wonder whether the error would still happen in a similar way if, say, we used cuml.dask's make_blobs? Will make some time to try tomorrow if no one gets around to it.

achirkin commented 3 weeks ago

I've tried replacing the generation code as you suggested:

    # dataset = da.random.random((n_rows, n_cols), chunks=(n_rows/n_gpus, n_cols)).astype(np.float32)
    dataset = cuml.dask.datasets.make_blobs(n_samples=n_rows, n_features=n_cols, n_parts=n_gpus, order='C', dtype=np.float32)[0]

This doesn't make a difference w.r.t. memory usage. I believe this KeyError and the extra allocation happen after data generation, somewhere behind dask.array.map_blocks.

The error you showed occurs in zict and is the expected behavior of the zict.Buffer LRU eviction logic. In my tests, this logic triggers exactly when the amount of free GPU memory is less than what is needed for one more allocation of the data (i.e. with the 12 GB setting; when I set the data size to 4 GB, eviction does not happen). Actually, I'm not sure what the fast and slow memories are in this code, or what exactly triggers the eviction. But it looks like this eviction does not entail destroying/freeing the array allocated on the GPU; at the same time, the self.slow_to_fast function (in the KeyError handler above) does try to allocate a new array (rmm.DeviceBuffer) and fails.
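To make this hypothesis concrete, here is a toy model in pure Python. It is not zict or dask_cuda code, and `ToyDeviceSpillBuffer` and the sizes in GB are illustrative assumptions. It models a device tier whose eviction does not release the device allocation, followed by an unspill that needs a fresh full-size buffer. With roughly 21 GB free, as in the logs above, a 4 GB dataset survives the round trip while a 12 GB one does not. This matches the "more than half of available memory" threshold reported above.

```python
class ToyDeviceSpillBuffer:
    """Hypothetical toy model (not zict/dask_cuda code) of a device-backed
    "fast" tier whose eviction to the "slow" tier does NOT free the device
    allocation, as hypothesized above. Sizes are in GB."""

    def __init__(self, device_capacity):
        self.device_capacity = device_capacity
        self.device_used = 0
        self.fast = {}  # key -> size of the device-resident copy
        self.slow = {}  # key -> size of the spilled (host) copy

    def _allocate(self, size):
        free = self.device_capacity - self.device_used
        if size > free:
            raise MemoryError(f"out of memory: need {size} GB, only {free} GB free")
        self.device_used += size

    def put(self, key, size):
        self._allocate(size)
        self.fast[key] = size

    def evict(self, key):
        # Move the entry to the slow tier but leave device_used unchanged:
        # the original device array is assumed to still be referenced elsewhere.
        self.slow[key] = self.fast.pop(key)

    def get(self, key):
        if key in self.fast:
            return self.fast[key]
        size = self.slow[key]
        self._allocate(size)  # analogous to slow_to_fast creating a new device buffer
        self.fast[key] = self.slow.pop(key)
        return size


for size_gb in (4, 12):
    buf = ToyDeviceSpillBuffer(device_capacity=21)
    buf.put("dataset", size_gb)
    buf.evict("dataset")
    try:
        buf.get("dataset")
        print(f"{size_gb} GB: unspill succeeded, {buf.device_used} GB in use")
    except MemoryError as err:
        print(f"{size_gb} GB: {err}")
```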

tfeher commented 3 weeks ago

@achirkin, above you write:

> predict enter: free = 17011.441664 MB # <-- this suggests the call to predict(X, ...) doesn't do extra allocations

and

> From this output it looks like there's no unnecessary allocation on cuml side anymore.

This is because KMeans already receives GPU data. Could you confirm what happens with NumPy input data? (E.g. just remove the to_gpu call.)

(This is independent of the zict error.)

achirkin commented 3 weeks ago

Indeed, in this case it allocates another copy of the dataset on GPU for a short time during call to predict, but Dante has already added the fix you suggested for this in #5937.
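One way to avoid that short-lived extra copy is to convert host input to device memory once and reuse that copy for the final predict instead of converting it again. This is our reading of the fix discussed here and has not been checked against the PR. The sketch uses stand-ins so it runs without a GPU. `fit_then_label`, `to_device_counting` and `EchoModel` are hypothetical; with real data `to_device` would be something like `cupy.asarray`.

```python
def fit_then_label(model, X, to_device):
    """Hypothetical pattern, not cuML's implementation: convert host input
    once, then reuse the same device copy for fit and the final predict."""
    X_dev = to_device(X)         # the only host -> device copy
    model.fit(X_dev)
    return model.predict(X_dev)  # reuses X_dev instead of converting X again


# --- Stand-ins so the sketch runs without a GPU (all hypothetical) ---
conversions = []


def to_device_counting(x):
    # Would be e.g. cupy.asarray(x) with real data; here it records each call.
    conversions.append(len(x))
    return list(x)


class EchoModel:
    # Minimal estimator exposing fit/predict, like cuML's KMeans.
    def fit(self, X):
        self.n_seen = len(X)
        return self

    def predict(self, X):
        return [0] * len(X)


labels = fit_then_label(EchoModel(), [1.0, 2.0, 3.0], to_device_counting)
print(labels, "conversions:", len(conversions))
```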

tfeher commented 3 weeks ago

> Indeed, in this case it allocates another copy of the dataset on GPU for a short time during call to predict, but Dante has already added the fix you suggested for this in #5937.

I missed that. Thanks Artem for confirming, and thanks Dante for the fix!

tfeher commented 19 hours ago

Fixed by https://github.com/rapidsai/cuml/pull/5937