Closed · leofang closed this issue 3 years ago
@tqchen You seem to have concerns about `kDLCUDAManaged` (from https://github.com/dmlc/dlpack/issues/67#issuecomment-835845788 and https://github.com/dmlc/dlpack/pull/69#issuecomment-839843729). May I ask what you had in mind, please?
Thanks @leofang, the main question is the usage of such memory in the current framework ecosystems. While it is certainly beneficial to support this kind of memory, I believe most of the frameworks do not make use of this feature, so it might make sense to wait a bit for framework adoption before we formally add them.
@tqchen This is like a chicken-and-egg problem. Something has to start first. This RFC costs nothing, and provides the needed convenience as more people explore managed memory, so I disagree with rejecting it based merely on "there's no one using it", because you never know.
In fact, managed memory is extensively used outside the ML/RL community. From the perspective of OpenMP/OpenACC, using managed memory is often the first encouraged step when porting a CPU codebase to accelerators, and developers would gradually move toward full utilization of device memory. (We actually have a use case that relies on managed memory to make Python code talk to OpenMP code.)
From the perspective of Python Array API standard, the stakeholder projects aim at applications far beyond ML/RL. We are talking about scientific computing, data science, HPC, and beyond, all of which can be covered by the Array API to some extent. Therefore, it is actually good to have this available for easier switchover. I personally need this capability to be added to CuPy for example.
> Thanks @leofang, the main question is the usage of such memory in the current framework ecosystems.
I can confirm in RAPIDS we have the ability to use CUDA Managed Memory today and it would be nice if DLPack supported it to be able to hand managed memory off to other non-ML / non-DL things.
Just to add on to what Keith said: we've made use of managed memory in RAPIDS workflows. The advantage is that it effectively automates spilling at the hardware level (bypassing what Dask, Spark, etc. would do manually), which can be considerably faster than spilling within software frameworks. It does come with tradeoffs (such as whether libraries support it and how they opt to support it), but it is useful to have that option and does make sense for some workflows.
OK, given that there seems to be more support, I agree we should add it. Thanks everyone for chiming in.
Thanks for the support @kkraus14 @jakirkham @tqchen! Will send a PR in a moment.
Let us keep this open for about a week for more feedback, then we can merge. Thanks @leofang
Thanks, @tqchen. One minor issue we'd better address is how to fill in `device_id` in `DLDevice`. The examples in this repo seem to use 0 for `kDLCPU`. Should we apply the same rule to `kDLCUDAHost`, `kDLROCmHost`, and `kDLCUDAManaged`, and specify this requirement in the Array API standard? These device types should not have the concept of a device ID either, just like `kDLCPU`.

Alternatively, for the above scenarios we could either set `device_id` to -1 (since it's an `int`) or ask all libraries and users to ignore it.

As far as DLPack is concerned, we just need to add a comment in `dlpack.h` to specify this. If everyone agrees I'll do this in #71.
I think specifying 0 is OK for unified/managed memory, to be consistent with our previous convention.
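For illustration, NumPy already follows this convention for `kDLCPU`: its `__dlpack_device__` protocol reports the `(device_type, device_id)` pair with 0 as the placeholder device ID (a small sketch, assuming a NumPy version recent enough to support the DLPack protocol):

```python
import numpy as np

# NumPy reports its DLPack device as a (device_type, device_id) pair.
# kDLCPU is 1 in dlpack.h; the convention discussed here is to fill in
# 0 for device types that have no meaningful device ID.
dev_type, dev_id = np.arange(3).__dlpack_device__()
print(int(dev_type), int(dev_id))  # → 1 0
```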
Hi @tqchen and all, it's been a little over a week and I take it that everyone is happy with this RFC and the PR #71?
Hi @leofang, sorry to bother you. I have a question about (or beyond) `kDLCUDAHost` and `kDLCUDAManaged`:

Besides the page-locked memory allocated by `cudaHostAlloc()` and UM managed by `cudaMallocManaged()`, we are using `cudaHostRegister()` to pin and map existing host memory (maybe previously allocated by host `malloc`) as page-locked memory for GPU zero-copy access. This allows the pinned memory to be shared among multiple processes and is useful for graph neural networks (GNNs), where multiple GPUs access and sample (e.g., random walks) the large pinned graph structure. This memory is previously on device `kDLCPU` but uses `cudaHostRegister()`/`cudaHostUnregister()` to pin&map/unpin&unmap, which is quite different from `kDLCUDAHost` and `kDLCUDAManaged`. Do you have any suggestions on which device type we should use for this pinned memory? Do we need a new device type for it?
@leofang BTW,

> These device types should not have the concept of a device ID either, just like `kDLCPU`.

This is not always true, I think. `cudaHostAlloc()` has a flag `cudaHostAllocPortable`; the allocated memory is considered pinned memory by all CUDA contexts only when this flag is specified (the default). In other cases, the device ID does make sense for the `kDLCUDAHost` device.
Hi @yaox12:

> This memory is previously on device `kDLCPU` but uses `cudaHostRegister()`/`cudaHostUnregister()` to pin&map/unpin&unmap, which is quite different from `kDLCUDAHost` and `kDLCUDAManaged`. Do you have any suggestions on which device type we should use for this pinned memory? Do we need a new device type for it?

No, we don't need a new device type. Depending on how your array container handles things, one approach could be that after you register (de-register) the memory, you switch the device type to `kDLCUDAHost` (`kDLCPU`). Another approach you could take (as is done in NumPy) is to accept `kDLCUDAHost` as a valid device type for CPU processing, so that you don't have to worry about the device type as long as the memory is CPU accessible. Again, it depends on what you need and how your library/framework is implemented.
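The first approach amounts to a small piece of bookkeeping in the array container. A hypothetical sketch (the `HostBuffer` class and its `registered` flag are illustrative, not part of any library; the actual `cudaHostRegister()` call is elided):

```python
# DLPack device-type codes from dlpack.h.
kDLCPU = 1
kDLCUDAHost = 3

class HostBuffer:
    """Hypothetical host-memory container that flips its advertised
    DLPack device type when cudaHostRegister()/cudaHostUnregister()
    are applied to its allocation."""
    def __init__(self):
        self.registered = False  # set True after cudaHostRegister()

    @property
    def dlpack_device(self):
        # Report (device_type, device_id); 0 is the placeholder ID.
        return (kDLCUDAHost if self.registered else kDLCPU, 0)

buf = HostBuffer()
print(buf.dlpack_device)  # → (1, 0): plain kDLCPU host memory
buf.registered = True     # i.e., after a successful cudaHostRegister()
print(buf.dlpack_device)  # → (3, 0): now presented as kDLCUDAHost
```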
To your other question:

> > These device types should not have the concept of a device ID either, just like `kDLCPU`.
>
> This is not always true, I think. `cudaHostAlloc()` has a flag `cudaHostAllocPortable`; the allocated memory is considered pinned memory by all CUDA contexts only when this flag is specified (the default). In other cases, the device ID does make sense for the `kDLCUDAHost` device.

(`cudaHostAllocPortable` allows multiple GPUs, each with a separate CUDA context, to see the same pinned memory, so which GPU ID should we use?)

Finally, I have a question for you as an aside:
> This allows the pinned memory to be shared among multiple processes and is useful for graph neural networks (GNNs), where multiple GPUs access and sample (e.g., random walks) the large pinned graph structure.

I don't think you meant to say multiple "processes"; I think you meant multiple "threads", since a segment of page-locked memory in one process can't be seen as page-locked in other processes.
@leofang Thanks for your explanation.
> One approach could be that after you register (de-register) the memory, you switch the device type to `kDLCUDAHost` (`kDLCPU`). Another approach you could take (as is done in NumPy) is to accept `kDLCUDAHost` as a valid device type for CPU processing, so that you don't have to worry about the device type as long as the memory is CPU accessible.

Since we are accessing the registered memory from GPUs through the unified virtual address space, we prefer the former. However, we are wondering if it would cause confusion when we have both `cudaHostRegister()` and `cudaHostAlloc()` page-locked memory in our framework.
> I don't think you meant to say multiple "processes"; I think you meant multiple "threads", since a segment of page-locked memory in one process can't be seen as page-locked in other processes.

For this question, I do mean multiple "processes". We first copy the target memory to shared memory in the main process, and register it in every forked/spawned subprocess. Since `cudaHostRegister()` is an in-place method, all processes can access this page-locked and shared memory. For example, you can do this in PyTorch:
```python
import torch
import torch.multiprocessing as mp

def pin_shared(i, x):
    # Pin the shared tensor's memory in place in this subprocess.
    cudart = torch.cuda.cudart()
    r = cudart.cudaHostRegister(x.data_ptr(), x.numel() * x.element_size(), 0)
    assert x.is_shared()
    assert x.is_pinned()

if __name__ == "__main__":  # guard required for mp.spawn
    x = torch.arange(5).share_memory_()
    mp.spawn(pin_shared, (x,), nprocs=4)
```
From https://github.com/dmlc/dlpack/issues/67#issuecomment-834784656 and https://github.com/dmlc/dlpack/issues/67#issuecomment-835669619:

I am proposing to add two new device types, following the line of #67:

- `kDLROCMHost`
- `kDLCUDAManaged`

The first addition mirrors the current relation (since v0.5) between CUDA and ROCm, now that we have `kDLCUDA`, `kDLCUDAHost`, and `kDLROCM`. ROCm also provides pinned/page-locked memory, so this is legit.

The second addition is for CUDA managed/unified memory, which does not belong to either host or device but to both. It seems natural to me to have a standalone type for it. ROCm currently does not provide managed memory, so we could add it in the future once AMD implements it.
Both additions seem straightforward for me to add without any issue, as they are orthogonal to existing device types (as they should be).
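In `dlpack.h` terms, the proposal amounts to extending the `DLDeviceType` enum. A Python mirror of the relevant entries (the exact numbering shown here reflects my reading of the upstream header and should be treated as an assumption, not as normative):

```python
from enum import IntEnum

class DLDeviceType(IntEnum):
    """Subset of dlpack.h's DLDeviceType, including the two proposed
    additions (values assumed to match the upstream header)."""
    kDLCPU = 1
    kDLCUDA = 2
    kDLCUDAHost = 3      # CUDA pinned/page-locked host memory
    kDLROCM = 10
    kDLROCMHost = 11     # proposed: ROCm pinned host memory
    kDLCUDAManaged = 13  # proposed: cudaMallocManaged() unified memory

print(int(DLDeviceType.kDLCUDAManaged))  # → 13
```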
cc: @rgommers @tqchen @jakirkham @kkraus14 @kmaehashi @emcastillo @asi1024