rapidsai / rmm

RAPIDS Memory Manager
https://docs.rapids.ai/api/rmm/stable/
Apache License 2.0

[DOC] Example of RMM Python API for DDP distributed training? #1560

Closed: GHGmc2 closed this issue 1 month ago.

GHGmc2 commented 2 months ago

On a single device, we can initialize RMM with:

```python
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

rmm.reinitialize(pool_allocator=True)
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
```

What about distributed training with DDP on 32 GPUs? Is there an example for that? Thanks!

wence- commented 2 months ago

RMM has no concept of distributed memory parallelism built in, nor does it need to.

What you need to arrange is that the different ranks in your job each correctly select their current CUDA device; then `rmm.reinitialize` should "just work". If you can show an example where it doesn't, we would like to know.

You need to be aware of how to use RMM with multiple devices: https://github.com/rapidsai/rmm/?tab=readme-ov-file#multiple-devices

This is all easiest if each distributed process selects exactly one GPU.
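A minimal sketch of that one-GPU-per-process pattern, assuming the script is launched with `torchrun` (which sets `LOCAL_RANK` for each process it starts). The helper names `local_device_id` and `init_rmm_for_this_rank` are hypothetical, not part of RMM or PyTorch. The torch allocator is swapped before torch initializes CUDA, because PyTorch's `change_current_allocator` errors if the current allocator has already been used; each rank then selects its GPU and only afterwards builds an RMM pool on that one device.

```python
import os
from typing import Mapping


def local_device_id(env: Mapping[str, str] = os.environ) -> int:
    """Return the GPU index this process should own (one GPU per process).

    torchrun sets LOCAL_RANK for every process it launches; fall back to 0
    for a plain single-process run.
    """
    return int(env.get("LOCAL_RANK", 0))


def init_rmm_for_this_rank(pool_allocator: bool = True) -> int:
    """Hypothetical helper: route this rank's torch allocations through an RMM pool."""
    # Imported lazily so local_device_id() can be used without a GPU stack.
    import rmm
    import torch
    from rmm.allocators.torch import rmm_torch_allocator

    # Must run before torch initializes CUDA: PyTorch refuses to swap an
    # allocator that has already been used/initialized.
    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

    # Select this rank's GPU first, then create the RMM pool on that device only.
    device_id = local_device_id()
    torch.cuda.set_device(device_id)
    rmm.reinitialize(pool_allocator=pool_allocator, devices=device_id)
    return device_id
```

Each rank would call `init_rmm_for_this_rank()` once, early in the script and before creating the model or any CUDA tensors, and then pass the returned index to `.to(...)` and `DDP(device_ids=[...])`. With exactly one GPU per process, each process only ever needs the RMM resource for its own device.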

GHGmc2 commented 2 months ago

Thanks!

GHGmc2 commented 2 months ago

@wence- I tried to reproduce the error I came across. Can you help with it? Thanks!

```python
import os

import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

local_rank = int(os.environ.get("LOCAL_RANK", 0))
rmm.reinitialize(pool_allocator=True, devices=local_rank)
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def demo_basic(rank, world_size):
    setup(rank, world_size)
    print(f"Start running basic DDP example on rank {rank}.")

    device_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_id)

    model = ToyModel().to(device_id)
    ddp_model = DDP(model, device_ids=[device_id])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    print(f"Finish running basic DDP example on rank {rank}.")
    cleanup()


if __name__ == "__main__":
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 0))
    demo_basic(rank, world_size)
```


When I run the command below on 16 GPUs (NVIDIA-A800-SXM4-80GB) with ngc-pytorch:24.02:
```bash
DISTRIBUTED_ARGS="--nproc_per_node $GPU_NUM --nnodes $NODE_NUM --node_rank $RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT";

torchrun $DISTRIBUTED_ARGS test_rmm.py
```

I got this error:

```
2024-05-11 11:21 =================================
2024-05-11 11:21 ==== backtrace (tid:     92) ====
2024-05-11 11:21  0 0x0000000000042520 __sigaction()  ???:0
2024-05-11 11:21  1 0x00000000004ee038 cudbgMain()  ???:0
2024-05-11 11:21  2 0x0000000000333e77 cuSignalExternalSemaphoresAsync()  ???:0
2024-05-11 11:21  3 0x00000000000147fd ???()  /usr/local/cuda/lib64/libcudart.so.12:0
2024-05-11 11:21  4 0x0000000000073a93 cudaEventRecord()  ???:0
2024-05-11 11:21  5 0x000000000009d25d rmm::mr::detail::stream_ordered_memory_resource<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>, rmm::mr::detail::coalescing_free_list>::do_deallocate()  ???:0
2024-05-11 11:21  6 0x0000000000006538 deallocate()  ???:0
2024-05-11 11:21  7 0x0000000000b94094 torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete()  ???:0
2024-05-11 11:21  8 0x0000000000495110 c10::StorageImpl::~StorageImpl()  :0
2024-05-11 11:21  9 0x00000000000536d9 c10::TensorImpl::~TensorImpl()  TensorImpl.cpp:0
2024-05-11 11:21 10 0x0000000000f6dfbb std::_Sp_counted_ptr_inplace<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > >, (__gnu_cxx::_Lock_policy)2>::_M_dispose()  :0
2024-05-11 11:21 11 0x0000000000e661da std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release()  :0
2024-05-11 11:21 12 0x0000000000f3e753 c10d::ProcessGroupNCCL::WorkNCCL::~WorkNCCL()  ???:0
2024-05-11 11:21 13 0x0000000000f3e929 c10d::ProcessGroupNCCL::WorkNCCL::~WorkNCCL()  ProcessGroupNCCL.cpp:0
2024-05-11 11:21 14 0x0000000004d468f5 c10d::verify_params_across_processes()  ???:0
2024-05-11 11:21 15 0x0000000000bc6d55 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::optional<std::shared_ptr<c10d::Logger> > const&)#84}, void, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::optional<std::shared_ptr<c10d::Logger> > const&, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::optional<std::shared_ptr<c10d::Logger> > const&)#84}&&, void (*)(c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::optional<std::shared_ptr<c10d::Logger> > const&), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN()  init.cpp:0
2024-05-11 11:21 16 0x00000000003c9517 pybind11::cpp_function::dispatcher()  :0
2024-05-11 11:21 17 0x000000000015a10e PyObject_CallFunctionObjArgs()  ???:0
2024-05-11 11:21 18 0x0000000000150a7b _PyObject_MakeTpCall()  ???:0
2024-05-11 11:21 19 0x0000000000149629 _PyEval_EvalFrameDefault()  ???:0
2024-05-11 11:21 20 0x000000000015a9fc _PyFunction_Vectorcall()  ???:0
2024-05-11 11:21 21 0x000000000014326d _PyEval_EvalFrameDefault()  ???:0
2024-05-11 11:21 22 0x000000000015a9fc _PyFunction_Vectorcall()  ???:0
2024-05-11 11:21 23 0x000000000014fcbd _PyObject_FastCallDictTstate()  ???:0
2024-05-11 11:21 24 0x0000000000164a64 _PyStack_AsDict()  ???:0
2024-05-11 11:21 25 0x0000000000150a1c _PyObject_MakeTpCall()  ???:0
2024-05-11 11:21 26 0x000000000014a150 _PyEval_EvalFrameDefault()  ???:0
2024-05-11 11:21 27 0x000000000015a9fc _PyFunction_Vectorcall()  ???:0
2024-05-11 11:21 28 0x000000000014326d _PyEval_EvalFrameDefault()  ???:0
2024-05-11 11:21 29 0x000000000013f9c6 _PyArg_ParseTuple_SizeT()  ???:0
2024-05-11 11:21 30 0x0000000000235256 PyEval_EvalCode()  ???:0
2024-05-11 11:21 31 0x0000000000260108 PyUnicode_Tailmatch()  ???:0
2024-05-11 11:21 32 0x00000000002599cb PyInit__collections()  ???:0
2024-05-11 11:21 33 0x000000000025fe55 PyUnicode_Tailmatch()  ???:0
2024-05-11 11:21 34 0x000000000025f338 _PyRun_SimpleFileObject()  ???:0
2024-05-11 11:21 35 0x000000000025ef83 _PyRun_AnyFileObject()  ???:0
2024-05-11 11:21 36 0x0000000000251a5e Py_RunMain()  ???:0
2024-05-11 11:21 37 0x000000000022802d Py_BytesMain()  ???:0
2024-05-11 11:21 38 0x0000000000029d90 __libc_init_first()  ???:0
2024-05-11 11:21 39 0x0000000000029e40 __libc_start_main()  ???:0
2024-05-11 11:21 40 0x0000000000227f25 _start()  ???:0
2024-05-11 11:21 =================================
```

wence- commented 1 month ago

Can you try removing these lines:

```python
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rmm.reinitialize(pool_allocator=True, devices=local_rank)
```

Then, in the `demo_basic` function, move the `reinitialize` call so that it comes right after the two lines that set the device with torch:

```python
    device_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_id)
    rmm.reinitialize(pool_allocator=True, devices=device_id)
```

Also, can you provide information about your environment: RMM version, etc.? (If you have an RMM checkout, run the `./print_env.sh` script.)
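If an RMM checkout is not at hand, a rough Python stand-in for the version part of that report is sketched below. `module_versions` is a hypothetical helper, not an RMM utility; it only reads the `__version__` attribute that packages such as `rmm` and `torch` expose, and reports `None` for packages that cannot be imported. It does not collect the driver, CUDA, or OS details that `print_env.sh` gathers.

```python
import importlib
from typing import Dict, Iterable, Optional


def module_versions(names: Iterable[str] = ("rmm", "torch")) -> Dict[str, Optional[str]]:
    """Map each module name to its __version__ string.

    Modules that cannot be imported map to None; modules without a
    __version__ attribute map to "unknown".
    """
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None
            continue
        versions[name] = str(getattr(module, "__version__", "unknown"))
    return versions
```

Inside the training container, `print(module_versions())` would print one entry per package.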

GHGmc2 commented 1 month ago

I use the RMM that is built into the ngc-pytorch image, and found that it works (both ways) with rmm 24.2.0 from ngc-pytorch:24.04 (previously I was using rmm 23.12.0 from ngc-pytorch:24.02).

Thanks!

wence- commented 1 month ago

Ah, I suspect that the problem was that 23.12 did not have https://github.com/rapidsai/rmm/pull/1407, but 24.02 does. I'll go ahead and close.
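Because the crash went away when moving from RMM 23.12 to 24.02, a training script could check the installed RMM version before installing the RMM torch allocator. This is only a sketch: the 24.02 cutoff reflects the suspicion above rather than a confirmed changelog entry, and `rmm_version_ok` is a hypothetical helper. It assumes RAPIDS-style `YY.MM.patch` version strings such as `24.02.00`.

```python
from typing import Tuple


def rmm_version_ok(version: str, minimum: Tuple[int, int] = (24, 2)) -> bool:
    """Return True if a YY.MM.patch version string is at least `minimum`.

    Only the year and month components are compared, so "24.2.0" and
    "24.02.00" are treated the same and the patch part is ignored.
    """
    year, month = (int(part) for part in version.split(".")[:2])
    return (year, month) >= minimum
```

For example, a script could run `import rmm` followed by `assert rmm_version_ok(rmm.__version__)` before calling `change_current_allocator`. Comparing `(year, month)` tuples avoids the string-comparison pitfall where `"24.2.0" < "24.02.00"` would be decided character by character.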