```python
torch.cuda.change_current_allocator(rmm_torch_allocator)
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
device = int(os.environ['LOCAL_RANK'])
rmm.mr.set_per_device_resource(device, pool)
```
This is almost correct. However, when creating a memory resource, you should do so with the target device active. That is, you must call `cudaSetDevice(device)` before creating the pool, and only then call `set_per_device_resource`. This can be done like so:
```python
import os

import rmm

device = int(os.environ['LOCAL_RANK'])
rmm._cuda.gpu.setDevice(device)
pool = rmm.mr.PoolMemoryResource(...)
rmm.mr.set_per_device_resource(device, pool)
```
Since this is such a common pattern, the top-level `rmm.reinitialize` has some logic to handle this for you:
```python
import os

import rmm

device = int(os.environ["LOCAL_RANK"])
rmm.reinitialize(devices=device, pool_allocator=True, initial_pool_size=...)
```
This doesn't offer quite as much flexibility in how the allocator is set up, but if you just need a pool on top of a CUDA memory resource then it works fine.
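If you want to confirm what `reinitialize` installed, you can inspect the per-device resource afterwards. A minimal sketch (`rmm.mr.get_per_device_resource` is public API; the 1 GiB pool size is just the value used elsewhere in this thread):

```python
import os

import rmm

device = int(os.environ["LOCAL_RANK"])
rmm.reinitialize(devices=device, pool_allocator=True, initial_pool_size=2**30)

# Check which resource reinitialize registered for this device.
mr = rmm.mr.get_per_device_resource(device)
print(type(mr).__name__)  # expected: PoolMemoryResource
```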
We could add an interface whereby you provide a zero-argument callback to construct the pool (and `rmm.reinitialize` would arrange to call it with the correct device active), but we haven't needed it so far.
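Such an interface might look something like this sketch (entirely hypothetical: the `mr_factory` keyword and this calling convention do not exist in RMM today):

```python
import os

import rmm

device = int(os.environ["LOCAL_RANK"])

def make_pool():
    # rmm.reinitialize would call this with the target device already active,
    # so the pool is created against the right device.
    return rmm.mr.PoolMemoryResource(
        rmm.mr.CudaMemoryResource(),
        initial_pool_size=2**30,
    )

rmm.reinitialize(devices=device, mr_factory=make_pool)  # hypothetical keyword
```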
@wence- Thanks for your reply! I tried:

```python
import os

import rmm
import torch
from rmm.allocators.torch import rmm_torch_allocator

torch.cuda.change_current_allocator(rmm_torch_allocator)
device = int(os.environ['LOCAL_RANK'])
rmm._cuda.gpu.setDevice(device)
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
rmm.mr.set_per_device_resource(device, pool)
```
Unfortunately, it core-dumped.
Hmm, we fixed some bugs around stream-ordered memory resources that will be in 23.12 but are not in 23.10. It's possible that using 23.12 will fix things. Can you provide a complete example script to run, and I will try to reproduce locally.
> Hmm, we fixed some bugs around stream-ordered memory resources that will be in 23.12 but are not in 23.10. It's possible that using 23.12 will fix things. Can you provide a complete example script to run, and I will try to reproduce locally.
Let me try 23.12.
I tried this trivial code:
```python
import os

import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = int(os.environ['LOCAL_RANK'])
torch.cuda.change_current_allocator(rmm_torch_allocator)
rmm._cuda.gpu.setDevice(device)
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
rmm.mr.set_per_device_resource(device, pool)
print(torch.zeros(2, 3, device=f"cuda:{device}"))
```
When I run with `torchrun --nnodes 1 --nproc-per-node gpu test.py` I don't get any errors (both with 23.10 and 23.12).
> I tried this trivial code: […] When I run with `torchrun --nnodes 1 --nproc-per-node gpu test.py` I don't get any errors (both with 23.10 and 23.12)
Emmm, I tried this code and got some interesting results.
I ran the sample code with `torchrun --nnodes 1 --nproc-per-node 8 rmm_test.py`. It core-dumped:
```
[2023-12-11 16:55:07,277] torch.distributed.run: [WARNING]
[2023-12-11 16:55:07,277] torch.distributed.run: [WARNING] *****************************************
[2023-12-11 16:55:07,277] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2023-12-11 16:55:07,277] torch.distributed.run: [WARNING] *****************************************
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:1')
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:2')
[2023-12-11 16:55:17,290] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 3 (pid: 216181) of binary: /opt/conda/envs/rmm_dev2/bin/python3.10
Traceback (most recent call last):
  File "/opt/conda/envs/rmm_dev2/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
```
The core file:
```
#0  0x00007fd2b6b2e0b8 in ?? () from /lib64/libcuda.so.1
#1  0x00007fd2b6973fd7 in ?? () from /lib64/libcuda.so.1
#2  0x00007fd2f7621a58 in ?? () from /opt/conda/envs/rmm_dev2/lib/libcudart.so.12
#3  0x00007fd2f767901b in cudaEventRecord () from /opt/conda/envs/rmm_dev2/lib/libcudart.so.12
#4  0x00007fd2f75bcb56 in rmm::mr::detail::stream_ordered_memory_resource<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>, rmm::mr::detail::coalescing_free_list>::do_deallocate(void*, unsigned long, rmm::cuda_stream_view) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/rmm/_lib/memory_resource.cpython-310-x86_64-linux-gnu.so
#5  0x00007fd247d4e3ce in deallocate () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/rmm/_lib/torch_allocator.cpython-310-x86_64-linux-gnu.so
#6  0x00007fd2b60460e3 in torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#7  0x00007fd2b5979c06 in c10::StorageImpl::~StorageImpl() () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#8  0x00007fd272b58ca7 in c10::intrusive_ptr<c10::StorageImpl, c10::detail::intrusive_target_default_null_type<c10::StorageImpl> >::reset_() ()
   from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libc10.so
#9  0x00007fd272b50cb3 in c10::TensorImpl::~TensorImpl() () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libc10.so
#10 0x00007fd272b50e49 in c10::TensorImpl::~TensorImpl() () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libc10.so
#11 0x00007fd2a02a1a64 in at::native::isfinite(at::Tensor const&)::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const [clone .isra.0] ()
   from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#12 0x00007fd2a02a2777 in at::native::isfinite(at::Tensor const&) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#13 0x00007fd2a1228ddd in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) ()
   from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007fd2a0c982eb in at::_ops::isfinite::call(at::Tensor const&) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007fd2b5a79453 in torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#16 0x00005594adc107e6 in cfunction_call (func=0x7fd2f2f00e00, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.10.13/Objects/methodobject.c:543
```
I tried multiple times and noticed that each time only the tensors on devices 0, 1 and 2 were printed. So I tried `torchrun --nnodes 1 --nproc-per-node 3 test.py`, and it worked just fine. It works whenever the number of GPUs is under 4; with 4 or more GPUs, it core-dumps. I tried this with RMM v23.12.00, Python 3.10 and PyTorch 2.1.1.
Thanks, I'll try and reproduce on a system with more than two GPUs.
Is it possible that the active device could be changing before the `deallocate` is (implicitly) called? The error in `cudaEventRecord` makes me think that it may be trying to record an event on the wrong device. This memory resource expects the device that was active when the pool was created to be active when any call to `allocate()` or `deallocate()` is made.
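The problematic pattern would look something like this sketch (a hypothetical illustration, not code from this thread): the buffer is allocated while device 1 is active, but freed after the active device has changed.

```python
import rmm

rmm._cuda.gpu.setDevice(1)
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
rmm.mr.set_per_device_resource(1, pool)

buf = rmm.DeviceBuffer(size=256)  # allocated from the pool while device 1 is active

rmm._cuda.gpu.setDevice(0)
del buf  # deallocation now runs with device 0 active, violating the invariant
```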
> Is it possible that the active device could be changing before the `deallocate` is (implicitly) called?
I don't think either the trivial code or PyTorch would do so. And that would not explain why fewer than 4 GPUs worked.
I modified the code to:
```python
import os
import time

import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = int(os.environ['LOCAL_RANK'])
torch.cuda.change_current_allocator(rmm_torch_allocator)
rmm._cuda.gpu.setDevice(device)
print(rmm._cuda.gpu.getDevice())
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
rmm.mr.set_per_device_resource(device, pool)
a = torch.zeros(2, 3, device=f"cuda:{device}")
print(a)
print(rmm._cuda.gpu.getDevice())
del a
time.sleep(5)
print(rmm._cuda.gpu.getDevice())
```
The output:

```
(rmm_dev2) sh rmm.sh
[2023-12-12 14:33:28,740] torch.distributed.run: [WARNING]
[2023-12-12 14:33:28,740] torch.distributed.run: [WARNING] *****************************************
[2023-12-12 14:33:28,740] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2023-12-12 14:33:28,740] torch.distributed.run: [WARNING] *****************************************
6
3
7
2
1
0
4
5
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:2')
2
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
0
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:1')
1
[2023-12-12 14:33:38,751] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 236161 closing signal SIGTERM
[2023-12-12 14:33:38,751] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 236162 closing signal SIGTERM
[2023-12-12 14:33:38,751] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 236163 closing signal SIGTERM
[2023-12-12 14:33:39,029] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 3 (pid: 236164) of binary: /opt/conda/envs/rmm_dev2/bin/python3.10
```
It seems that the tensors on devices 3, 4, 5, 6 and 7 were deallocated before `print`. (Printing a GPU tensor synchronizes the CPU and GPU in PyTorch.)
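To make that synchronization explicit instead of relying on `print` (a small sketch; `torch.cuda.synchronize` is standard PyTorch API):

```python
import os

import torch

device = int(os.environ["LOCAL_RANK"])
a = torch.zeros(2, 3, device=f"cuda:{device}")
torch.cuda.synchronize(device)  # block until all queued work on this device finishes
print(a)  # the print itself no longer has to trigger the synchronization
```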
I was able to reproduce running with four GPUs, but I have yet to figure out what is going on. Debugging under gdb is difficult here because torchrun runs things in subprocesses. However, if we run gdb with `set detach-on-fork off` and `set follow-fork-mode child`, eventually we can get to the relevant process and I can get a backtrace. The next step is to build RMM in debug mode so I have some symbols to inspect.
This is what I have right now to debug; note I only need to allocate things on a single device:

```python
import os

import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = int(os.environ['LOCAL_RANK'])
if device == 3:
    torch.cuda.change_current_allocator(rmm_torch_allocator)
    rmm._cuda.gpu.setDevice(device)
    pool = rmm.mr.PoolMemoryResource(
        rmm.mr.CudaMemoryResource(),
        initial_pool_size=2**30,
    )
    rmm.mr.set_per_device_resource(device, pool)
    tensor = torch.zeros(2, 3, device=f"cuda:{device}")
    print(torch.cuda.current_device(), device, os.getpid(), flush=True)
    print(tensor, flush=True)
    del tensor
```
So my suspicion is that torch is shuffling CUDA devices out from under us in a bad way.
Thanks so much for debugging, @wence-.
OK, I have the culprit.
The signature we offer for the plug-in allocation functions is:

```c
void *allocate(size_t size, int device, cudaStream_t stream);
void deallocate(void *ptr, size_t size, cudaStream_t stream);
```
This was the original signature for the pluggable allocators when we introduced support in https://github.com/rapidsai/rmm/pull/1168 (the PyTorch side came in via https://github.com/pytorch/pytorch/pull/86786). But soon after, in https://github.com/pytorch/pytorch/pull/91398, the signatures were changed to:
```c
void *allocate(size_t size, int device, cudaStream_t stream);
void deallocate(void *ptr, size_t size, int device, cudaStream_t stream);
```
Note the change to also accept the device in the deallocate function.
So we're getting `3` (the device), interpreting it as a stream, and trying to use that in the RMM deallocation function. But of course that stream is nonsense; everyone is actually just using the default `0` stream.
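The effect of the mismatch can be shown with a toy ctypes sketch (purely illustrative, unrelated to RMM's actual binding code): when a caller using the four-argument signature invokes a callee compiled against the three-argument one, the `device` integer lands in the callee's `stream` parameter on mainstream calling conventions.

```python
import ctypes

# Callee: the old three-argument signature, deallocate(ptr, size, stream).
OldSig = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_ssize_t, ctypes.c_void_p)

def deallocate(ptr, size, stream):
    print("stream as seen by the callee:", stream)  # prints 3 -- the device id!

callee = OldSig(deallocate)

# Caller: the new four-argument signature, deallocate(ptr, size, device, stream).
NewSig = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_ssize_t,
                          ctypes.c_int, ctypes.c_void_p)
caller = NewSig(ctypes.cast(callee, ctypes.c_void_p).value)
caller(None, 24, 3, None)  # device=3 is silently read as the stream
```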
The fix is to correct the signature in RMM (I will prepare a patch).
Here is a minimal diff that will allow your code to run:
```diff
diff --git a/python/rmm/_lib/torch_allocator.pyx b/python/rmm/_lib/torch_allocator.pyx
index 12dc9fe1..2b11028c 100644
--- a/python/rmm/_lib/torch_allocator.pyx
+++ b/python/rmm/_lib/torch_allocator.pyx
@@ -15,7 +15,7 @@ cdef public void* allocate(
     return mr[0].allocate(size, stream_view)
 
 cdef public void deallocate(
-    void* ptr, ssize_t size, void* stream
+    void* ptr, ssize_t size, int device, void* stream
 ) except * with gil:
     cdef device_memory_resource* mr = get_current_device_resource()
     cdef cuda_stream_view stream_view = cuda_stream_view(
```
However, in #1407 I am trying to do a better thing, which is to use the memory resource associated with the device we are being passed, rather than just assuming that `get_current_device_resource` will return the correct resource. That needs some help from someone on the build side of things.
Can you try whether the code in #1408 works for you, @li-yi-dong?
> Can you try whether the code in #1408 works for you, @li-yi-dong?
It works pretty smoothly with my task. And RMM really outperforms the PyTorch caching allocator in terms of fragmentation.
Great, thanks! In the end we are going with the code in #1407, which I very much hope also works identically; if you could confirm, that would be wonderful.
> Great, thanks! In the end we are going with the code in #1407, which I very much hope also works identically; if you could confirm, that would be wonderful.
It works fine.
**Describe the bug**
I tried to use RMM with PyTorch. I launch my task with torchrun and set the `rmm.mr` for each device at the very beginning. But each process occupies a chunk of memory on GPU 0.

**Steps/Code to reproduce bug**
See the snippet quoted at the top of this thread.

**Expected behavior**
I expected each process launched by torchrun to use only the memory on the GPU assigned by `LOCAL_RANK`.

**Environment details**
I'm using RMM v23.10.00 (output of `print_env.sh` omitted here).