rapidsai / dask-cuda

Utilities for Dask and CUDA interactions
https://docs.rapids.ai/api/dask-cuda/stable/
Apache License 2.0

CPU Memory Usage for Tasks with CPU-GPU Transfer #1351

Open · ilan-gold opened this issue 2 weeks ago

ilan-gold commented 2 weeks ago

Issue + Reproducers

So I have an I/O job that reads data onto the CPU and passes it to the GPU in a map_blocks call, and then uses CuPy downstream in a non-standard map_blocks call. Here is the reproducer minus the I/O:


from dask_cuda import LocalCUDACluster
from dask.distributed import Client
import dask.array as da
import dask

import cupy as cp
import rmm
import numpy as np
from rmm.allocators.cupy import rmm_cupy_allocator
from cupyx.scipy import sparse as cp_sparse
from scipy import sparse

def set_mem():
    rmm.reinitialize(managed_memory=True)
    cp.cuda.set_allocator(rmm_cupy_allocator)

cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0")
client = Client(cluster)
client.run(set_mem)

M = 100_000
N = 4_000

def make_chunk():
    arr = np.random.random((M,N))
    chunk = cp.array(arr)
    del arr
    return chunk

arr = da.map_blocks(make_chunk, meta=cp.array((1.,), dtype=cp.float64), dtype=cp.float64, chunks=((M,) * 50, (N,) * 1))
blocks = arr.to_delayed().ravel()

def __gram_block(block):
    return block.T @ block

gram_chunk_matrices = da.map_blocks(__gram_block, arr, chunks=((arr.shape[1],) * len(blocks) , (arr.shape[1],)), dtype=arr.dtype, meta=cp.array([]))
gram_chunk_matrices = gram_chunk_matrices.reshape(len(blocks), arr.shape[1], arr.shape[1])
gram_matrix = gram_chunk_matrices.sum(axis=0).compute()

This uses an unaccountable amount of CPU memory, on the order of 4-8 GB, and I have no idea why. Nothing here should be using CPU memory except the initial read. And when the job completes, dask still reports that it is holding on to 4GB (!) of memory. I see at most 2 tasks running, with another 2-6 in memory. In total, the CPU memory being so high doesn't make sense since the individual numpy arrays are 320MB, so this should be at most 640MB (and even that seems high given how long they last on the CPU before I call del). I don't think this is a dask memory-reporting issue because top shows the same amount of memory usage.

I also don't think this has to do with the computation I chose as:

def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices = doubled_matrices.reshape(len(blocks), M, N)

doubled_matrix = doubled_matrices.sum(axis=0).compute()

has the same issue, albeit with a warning about chunk sizes (although I'm not sure why I'm getting the warning, since the reshape is right along the blocks). In any case, it's the same memory problem. Without the reshape, and so without that warning, the behavior is the same:

def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices.sum(axis=0).compute()

Some env info:

Python 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0] on linux
dask==2024.2.1
dask-cuda==24.6.0
dask-cudf-cu12==24.6.0
distributed==2024.2.1
distributed-ucxx-cu12==0.38.0
cuda-python==12.5.0
cudf-cu12==24.6.0
cugraph-cu12==24.6.1
cupy-cuda12x==13.2.0
numpy==1.26.4

Some monitoring screenshots:

[Screenshots: memory monitoring, 2024-06-21 16:19 and 16:18]

Addendum 1: I don't see this as spilling

I also don't think this is spilling because the GPU memory is not that high:

[Screenshot: GPU memory usage, 2024-06-24 10:48]

That number is also basically correct: ((4000 * 4000) * 4 + (100_000 * 4000) * 4) < 2GB where the first 4000 * 4000 is from holding the sum(-partial) in memory and then 100_000 * 4000 is the input data.

Addendum 2: This is GPU specific

This behavior does not happen on CPU dask:

import dask.distributed as dd
import numpy as np
cluster = dd.LocalCluster(n_workers=1)
client = dd.Client(cluster)

M = 100_000
N = 4_000
def make_chunk():
    arr = np.random.random((M,N))
    return arr

arr = da.map_blocks(make_chunk, meta=np.array((1.,), dtype=np.float64), dtype=np.float64, chunks=((M,) * 50, (N,) * 1))
def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices.sum(axis=0).compute()

I see the memory use fluctuating, but the baseline of 1.5GB makes sense given the in-memory/processing stats I cited above. It also releases the memory at the end.

Addendum 3: Not an rmm issue

I tried commenting out client.run(set_mem) and that also had no effect.

pentschev commented 2 weeks ago

In total, the CPU memory being so high doesn't make sense since the individual numpy arrays are 320MB, so this should be at most 640MB (and even that seems high given how long they last on the CPU before I call del)

Your chunk size is 100_000 * 4_000 * 8 bytes = 3.2GB, and not 320MB. That seems consistent with the memory utilization you're expecting. Am I misunderstanding anything here?

pentschev commented 2 weeks ago

That number is also basically correct: ((4000 * 4000) * 4 + (100_000 * 4000) * 4) < 2GB where the first 4000 * 4000 is from holding the sum(-partial) in memory and then 100_000 * 4000 is the input data.

Also where are you getting the * 4 from? np.random.random returns float64, which is 8 bytes, thus you're miscalculating by a factor of 2 here as well.
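
For concreteness, a quick back-of-the-envelope in Python (the variable names below are just for illustration):

chunk_bytes = 100_000 * 4_000 * 8   # float64 is 8 bytes: 3.2e9 bytes, i.e. ~3.2 GB per input chunk
gram_bytes = 4_000 * 4_000 * 8      # 1.28e8 bytes, i.e. ~128 MB per gram block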

ilan-gold commented 2 weeks ago

@pentschev Apologies for the bad math, shameful on my part.

However, there are still some things I don't understand that are affecting us:

  1. There is still memory hanging around post-completion of the task.
  2. CPU-based dask seems to use much less memory.

Perhaps both of these are aberrations. I will keep investigating.

ilan-gold commented 2 weeks ago

CPU-based dask seems to use much less memory.

Huh, now this also seems to be the opposite; not sure what's up with that.

ilan-gold commented 2 weeks ago

Sorry, the reshape operation is costly.

So

import dask.distributed as dd
import numpy as np
cluster = dd.LocalCluster(n_workers=1)
client = dd.Client(cluster)

M = 100_000
N = 4_000
def make_chunk():
    arr = np.random.random((M,N))
    return arr

arr = da.map_blocks(make_chunk, meta=np.array((1.,), dtype=np.float64), dtype=np.float64, chunks=((M,) * 50, (N,) * 1))
def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices.sum(axis=0).compute()

uses less memory than the GPU equivalent: it peaks at a similar amount but then falls back faster to near 0. I must have copy-and-pasted the wrong cell.

pentschev commented 2 weeks ago
  • There is still memory hanging around post-completion of the task.

When you say "post-completion of the task", are you referring to the return of map_blocks/make_chunk or some other task? Note that del does not guarantee that the object will be immediately freed; only after garbage collection runs can you be sure of that.

2. CPU-based dask seems to use much less memory.

With your CPU version of the code I also see memory peaks at ~7GB on my end, which is similar to what I see in your original/GPU code. I'm not sure exactly what you're seeing on your end, but I cannot reproduce the CPU-only setup using much less memory than the GPU one, as you're reporting.

ilan-gold commented 2 weeks ago

When you say "post-completion of the task", are you referring to the return of map_blocks/make_chunk or some other task? Note that del does not guarantee that the object will be immediately freed; only after garbage collection runs can you be sure of that.

I am saying that after the final compute call, when nothing else is happening, the memory that both dask and top report as still being held is about 4GB.

With your CPU version of the code I also see memory peaks at ~7GB on my end, which is similar to what I see in your original/GPU code. I'm not sure exactly what you're seeing on your end, but I cannot reproduce the CPU-only setup using much less memory than the GPU one, as you're reporting.

Yup, I'm saying that the peak is the same, about 7GB, but it returns to near-0 GB during processing instead of staying at a 4GB baseline as in the GPU version.

pentschev commented 2 weeks ago

Yup, I'm saying that the peak is the same, about 7GB, but it returns to near-0 GB during processing instead of staying at a 4GB baseline as in the GPU version.

Measuring by eyeballing may be tricky here. I wouldn't be surprised if the GPU implementation seemingly doesn't drop as much because the processing is much faster and values simply do not update fast enough. With that said, unless you're measuring at high frequency, I wouldn't put too much trust in the appearance that they are behaving differently.

ilan-gold commented 2 weeks ago

I wouldn't be surprised if the GPU implementation seemingly doesn't drop as much because the processing is much faster and values simply do not update fast enough

I would tend to agree here, but top is also reporting the memory usage, and I can clearly see workers dying on account of CPU memory usage in my non-contrived, real-world GPU use case.

ilan-gold commented 2 weeks ago

I have considered just turning off Dask's memory monitoring, but I am not sure what that will accomplish if top is still reporting the memory usage.
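
(For reference, what I mean by turning off the monitoring is roughly the sketch below, assuming the standard distributed.worker.memory config keys and that the config is set before the cluster is created; top would of course still show the process RSS.)

import dask

dask.config.set({
    "distributed.worker.memory.target": False,     # don't start spilling at a target fraction
    "distributed.worker.memory.spill": False,      # don't spill to disk
    "distributed.worker.memory.pause": False,      # don't pause workers
    "distributed.worker.memory.terminate": False,  # don't restart workers at the terminate fraction
})
# ...then create the LocalCUDACluster/Client as before.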

pentschev commented 2 weeks ago

I would tend to agree here, but top is also reporting the memory usage, and I can clearly see workers dying on account of CPU memory usage in my non-contrived, real-world GPU use case.

In your real use case, do you have controlled data sizes you're reading, like in the code you shared here, or does that vary across chunks? Is it possible that you're running out of memory because you really do have too much data being read/stored in memory? Note that depending on your system setup, data access patterns, etc., it's possible that you end up with a different load pattern because GPU compute completes faster than the CPU did. One thing you may try is adding gc.collect() at certain points in your code to try and evict data, as del alone doesn't suffice for that purpose.
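
As a minimal sketch of what I mean, using the make_chunk from your reproducer:

import gc

def make_chunk():
    arr = np.random.random((M, N))
    chunk = cp.array(arr)  # host-to-device copy
    del arr                # drops the reference, but the host buffer may not be freed yet
    gc.collect()           # force a collection so the buffer is reclaimed before the next chunk
    return chunk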

ilan-gold commented 2 weeks ago

In your real use case, do you have controlled data sizes you're reading, like in the code you shared here, or does that vary across chunks?

Yup, it's not an exact match, but the overhead should not be too high; we're talking on the order of a few extra megabytes per read, and the chunk size is rather small. I will check tomorrow to be 100% sure, though.

One thing you may try is adding gc.collect() at certain points in your code to try and evict data, as del alone doesn't suffice for that purpose.

I will also have to double-check this to be 100% sure (you can see I am a bit careless sometimes), but I have definitely tried this as well, to no avail.

I think what's strange still is the memory hanging around after job completion. I will also try an even dumber reproducer for that tomorrow, like one job, no chunks, just cpu-gpu, and see if the memory hangs around.

pentschev commented 2 weeks ago

I think what's strange still is the memory hanging around after job completion. I will also try an even dumber reproducer for that tomorrow, like one job, no chunks, just cpu-gpu, and see if the memory hangs around.

This generally means there is some circular reference that the garbage collector couldn't break and release. I haven't checked what happens to memory after your sample(s) above complete while the Dask cluster remains alive; do you observe memory still resident in those samples too or just in your real code?

ilan-gold commented 2 weeks ago

do you observe memory still resident in those samples too or just in your real code?

Yes, the reproducer definitely does show the "bad behavior." I've checked the garbage collector and don't see anything too big hanging around but will have another look.

ilan-gold commented 2 weeks ago

So when I print out the leftover objects (after del), I get <class 'rmm._lib.device_buffer.DeviceBuffer'>. I suspect this is something of an aberration, though, since without rmm nothing comes up, yet the hanging memory is still there.

import gc
import sys
import dask.distributed

def make_chunk():
    arr = np.random.random((M,N))
    chunk = cp.array(arr)
    del arr
    # print any objects larger than 100MB that the garbage collector still tracks
    dask.distributed.print([type(obj) for obj in gc.get_objects() if (sys.getsizeof(obj) / 1_000_000) > 100])
    return chunk

ilan-gold commented 2 weeks ago

Ah, I also forgot another counterfactual. The following displays identical behavior:

def make_chunk():
    return np.random.random((M,N))
arr = da.map_blocks(make_chunk, meta=np.array((1.,), dtype=np.float64), dtype=np.float64, chunks=((M,) * 15, (N,) * 1))
arr = arr.map_blocks(cp.array, meta=cp.array((1.,), dtype=cp.float64), dtype=cp.float64)

followed by some operation like:

def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices.sum(axis=0).compute()

So I'm not sure about garbage collection.

After that operation completes you see:

[Screenshot, 2024-06-26 10:59]

and as confirmation, 3658116 is the PID for the worker:

[Screenshot: top output for worker PID 3658116, 2024-06-26 11:00]

pentschev commented 1 week ago

I can reproduce that too. Using your latest example, I see that commenting out arr = arr.map_blocks(cp.array, meta=cp.array((1.,), dtype=cp.float64), dtype=cp.float64) avoids the problem, so it's almost definitely related to CUDA or CuPy directly. I couldn't identify where that memory lives, though; I tried reproducing something similar without Dask in a simple for loop (see below), but I can't see anything resembling the same memory leak.

for i in range(50):
    a = np.random.random((100_000, 4_000))
    b = cp.array(a)
    print(a)
    print(b)

While I can confirm the behavior on my end, I don't think I'll have much time to push this further in the short term. In any case, I'll summarize my findings as you did above: I can see about 4GB of host memory still being held by the Dask worker after execution completes (see the images in https://github.com/rapidsai/dask-cuda/issues/1351#issuecomment-2191190861 for examples); the code I ran is below:

from dask_cuda import LocalCUDACluster
from dask.distributed import Client
import dask.array as da
import dask

import cupy as cp
import rmm
import numpy as np
from rmm.allocators.cupy import rmm_cupy_allocator

if __name__ == "__main__":
    def set_mem():
        rmm.reinitialize(managed_memory=True)
        cp.cuda.set_allocator(rmm_cupy_allocator)

    cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0")
    client = Client(cluster)
    client.run(set_mem)

    M = 100_000
    N = 4_000

    def make_chunk():
        return np.random.random((M,N))
    arr = da.map_blocks(make_chunk, meta=np.array((1.,), dtype=np.float64), dtype=np.float64, chunks=((M,) * 15, (N,) * 1))

    # Commenting this line out, thus making the code CPU-only, prevents the 4GB from
    # being held by the worker.
    arr = arr.map_blocks(cp.array, meta=cp.array((1.,), dtype=cp.float64), dtype=cp.float64)

    def __double_block(block):
        return block * 2

    doubled_matrices = da.map_blocks(__double_block, arr)
    doubled_matrices.sum(axis=0).compute()

    import time
    time.sleep(600)

Besides the above, commenting out the set_mem call and using LocalCluster(n_workers=1) instead of LocalCUDACluster also causes the same behavior; with that I conclude the issue is not RMM related and not Dask-CUDA specific, but instead Dask+GPU-related.

Nevertheless, it would be interesting to know what your use case is, @ilan-gold, as this might help us prioritize this issue.

@charlesbluca @quasiben @rjzamora @wence- pinging you in case anybody has time or can think of ways to debug this further.

ilan-gold commented 1 week ago

Nevertheless, it would be interesting to know what your use case is, @ilan-gold, as this might help us prioritize this issue.

We are reading off disk from HDF5 files and then doing downstream processing. I don't have a PR to show at the moment (it's somewhat based on https://github.com/scverse/rapids_singlecell/pull/179, but that is really only the processing and not the I/O, which is the problem).

The issue is really that we are limited to 256GB of CPU memory, so this becomes a tradeoff between the number of workers and memory, since the problem grows with the number of workers.

rjzamora commented 1 week ago

... I conclude the issue is not RMM related and not Dask-CUDA specific, but instead Dask+GPU-related

I'm not so sure this behavior has anything to do with dask. It definitely seems like cupy is somehow holding on to a host-memory reference somewhere. As far as I can tell, creating and then deleting a cupy array (in the absence of dask) also leaves behind a residual host-memory footprint.
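
A minimal Dask-free sketch of how one might check that, assuming psutil is available to read the process RSS:

import gc
import cupy as cp
import numpy as np
import psutil

proc = psutil.Process()
print(f"RSS before: {proc.memory_info().rss / 1e9:.2f} GB")
a = np.random.random((100_000, 4_000))  # ~3.2 GB host buffer
b = cp.array(a)                         # host-to-device copy
del a, b
gc.collect()
cp.get_default_memory_pool().free_all_blocks()  # also release CuPy's device pool
print(f"RSS after:  {proc.memory_info().rss / 1e9:.2f} GB")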

pentschev commented 1 week ago

I'm not so sure this behavior has anything to do with dask. It definitely seems like cupy is somehow holding on to a host-memory reference somewhere. As far as I can tell, creating and then deleting a cupy array (in the absence of dask) also leaves behind a residual host-memory footprint.

There's definitely some residual memory, but what I observed is in the ~50MB range, whereas in this case here we're seeing more like 4-5GB, which seems consistent with the original size of the array. However, as per my previous post, I couldn't reproduce more than a ~50MB memory footprint with the following simple loop:

for i in range(50):
    a = np.random.random((100_000, 4_000))
    b = cp.array(a)
    print(a)
    print(b)

I'd be happy to be proven wrong, as it would be much easier to debug too; perhaps I missed some detail and did not really write equivalent code above.

ilan-gold commented 1 week ago

Just as some more evidence here of a leak:

[Screenshots, 2024-06-28 11:42]

Only the first layer of this graph reads into CPU memory. Everything else is GPU-based via https://github.com/scverse/rapids_singlecell/pull/179. And yet, 14GB of memory usage. This gets worse both as time goes on with one worker and as one uses more workers. It was 4GB before, and I would not be surprised if this OOMs from CPU usage before the job finishes.

ilan-gold commented 1 week ago

And the reported memory usage comes from the summation, but that is done in GPU memory (or at least it's supposed to be).

[Screenshot, 2024-06-28 11:45]
rjzamora commented 1 week ago

we're seeing more like 4-5GB, which seems consistent with the original size of the array

The global array size would be much larger than 4-5GB (or maybe I'm mistaken?). 4-5GB seems more like an accumulation of many smaller chunks, no?

I'm certainly not 100% sure that Dask is not the problem; I just suspect that dask is amplifying some strange cupy behavior. Another interesting note: if cupy is used to generate the array in the original repro, the problem seems to go away?
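
(For reference, the cupy-generated variant I mean is roughly the hypothetical sketch below, where the chunk is created directly on the device and no host numpy buffer is ever involved.)

def make_chunk_gpu():
    # generate the chunk directly on the GPU, skipping the host array and the copy
    return cp.random.random((M, N))

arr = da.map_blocks(
    make_chunk_gpu,
    meta=cp.array((1.,), dtype=cp.float64),
    dtype=cp.float64,
    chunks=((M,) * 15, (N,) * 1),
)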

ilan-gold commented 1 week ago

Another interesting note: if cupy is used to generate the array in the original repro, the problem seems to go away?

Yes, but this makes sense, no? I'm specifically saying the CPU-GPU transfer is what I suspect to be problematic (and probably its garbage collection), Dask or not.

The global array size would be much larger than 4-5GB (or maybe I'm mistaken?). 4-5GB seems more like an accumulation of many smaller chunks, no?

No idea!

rjzamora commented 1 week ago

Yes, but this makes sense, no? I'm specifically saying the CPU-GPU transfer is what I suspect to be problematic (and probably its garbage collection), Dask or not.

Yes, I agree that this is expected, so maybe not so "interesting" :)

ilan-gold commented 6 days ago

When I run my real world job I see:

2024-07-04 16:08:01,630 - distributed.worker.memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 22.85 GiB -- Worker memory limit: 32.00 GiB

So, more evidence that this is a memory leak. There is definitely no reason my job needs 22 GB, much less 32 GB.
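
(The linked worker-memory docs discuss memory that glibc has freed but not yet returned to the OS; a minimal sketch of the malloc_trim check I could try, assuming Linux workers with glibc:)

import ctypes

def trim_memory() -> int:
    # ask glibc to return free heap pages to the OS; a nonzero return means something was released
    libc = ctypes.CDLL("libc.so.6")
    return libc.malloc_trim(0)

client.run(trim_memory)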

ilan-gold commented 6 days ago
[Screenshot, 2024-07-04 16:14]

So here's something interesting. Only the block-id task has anything to do with the CPU, but the job is reporting GBs upon GBs of memory. This could be on the GPU, but even then the memory usage is extremely high, as can be seen in that bar.

[Screenshot, 2024-07-04 16:16]
ilan-gold commented 6 days ago

Is it possible that dask itself (i.e., maybe not something in this package) is secretly allocating CPU memory because it is not properly aware of GPU arrays?