rapidsai / dask-cuda

Utilities for Dask and CUDA interactions
https://docs.rapids.ai/api/dask-cuda/stable/
Apache License 2.0

Add CLI option to enable PyTorch to use the same memory pool as RAPIDS. #1281

Open VibhuJawa opened 7 months ago

VibhuJawa commented 7 months ago

Currently we need to do the following to make PyTorch use the RMM memory pool on a dask-cuda cluster. We should be able to do this via a CLI option.

# Making PyTorch use the same memory pool as RAPIDS.
def _set_torch_to_use_rmm():
    """
    This function sets up the pytorch memory pool to be the same as the RAPIDS memory pool.
    This helps avoid OOM errors when using both pytorch and RAPIDS on the same GPU.
    See article:
    https://medium.com/rapids-ai/pytorch-rapids-rmm-maximize-the-memory-efficiency-of-your-workflows-f475107ba4d4
    """
    import torch
    from rmm.allocators.torch import rmm_torch_allocator
    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

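# Apply in the client process first, then on every worker in the cluster.
# `client` is an existing distributed.Client connected to the dask-cuda cluster.
_set_torch_to_use_rmm()
client.run(_set_torch_to_use_rmm)

VibhuJawa commented 4 months ago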

@quasiben, wondering if we have an opinion on this? Happy to open a PR here to make life easier for folks like me.

CC: @alexbarghi-nv, @jnke2016 who have seen customer problems around a similar setup.

pentschev commented 4 months ago

I have no objections to this. My only suggestion would be to make it a generic, extensible option that specifies which libraries should use RMM as their memory manager, something like this:

--set-rmm-allocator=torch,another_future_library,...

Do you think that makes sense? @VibhuJawa, if you want to get started on a PR for this, I'm happy to help address any issues you may find along the way.
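
One way the comma-separated value suggested above might be handled is with a small registry that maps each library name to a function pointing that library at RMM. This is only a sketch, not dask-cuda's implementation: `RMM_ALLOCATOR_SETTERS`, `parse_rmm_allocator_option`, and `set_rmm_allocators` are hypothetical names, and the only real calls are the torch and rmm ones already shown in this issue.

```python
from typing import Callable, Dict, List


def _set_torch_allocator() -> None:
    """Make PyTorch allocate GPU memory through RMM (same calls as the snippet in this issue)."""
    # Imported lazily so the registry can be built without torch or rmm installed.
    import torch
    from rmm.allocators.torch import rmm_torch_allocator

    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


# Hypothetical registry: library name -> function that points it at RMM.
# Supporting another library would mean adding one entry here.
RMM_ALLOCATOR_SETTERS: Dict[str, Callable[[], None]] = {
    "torch": _set_torch_allocator,
}


def parse_rmm_allocator_option(value: str) -> List[str]:
    """Split a value such as "torch,other" into validated, de-duplicated names."""
    names: List[str] = []
    for raw in value.split(","):
        name = raw.strip().lower()
        if not name:
            continue  # tolerate stray commas and whitespace
        if name not in RMM_ALLOCATOR_SETTERS:
            supported = ", ".join(sorted(RMM_ALLOCATOR_SETTERS))
            raise ValueError(f"Unsupported library {name!r}; supported: {supported}")
        if name not in names:
            names.append(name)
    return names


def set_rmm_allocators(value: str) -> None:
    """Run the setter for every library named in the option value."""
    for name in parse_rmm_allocator_option(value):
        RMM_ALLOCATOR_SETTERS[name]()
```

Until such an option exists, the same helper could be pushed to the workers by hand with `client.run(set_rmm_allocators, "torch")`. Validating names up front means a typo in the option fails with a clear error instead of being silently ignored.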