bheisler / RustaCUDA

Rusty wrapper for the CUDA Driver API
Apache License 2.0
765 stars 58 forks source link

Idea: Mimic `std::boxed::Box` API as closely as possible #63

Open chr1sj0nes opened 2 years ago

chr1sj0nes commented 2 years ago

Apologies, this seems to have become rather long. It seemed like such a simple idea at the start!

There are four main parts to this idea:

  1. Allow DeviceBox<[T]>, making DeviceBuffer just an alias that could be deprecated in future.
  2. Make the interface safer by using MaybeUninit for uninitialized/zeroed allocations on the device.
  3. Add an allocator generic parameter (`A: DeviceAllocator`) to DeviceBox, allowing for various new types of allocation.
  4. Bonus: Add support for async allocations.

I think this proposal is entirely backwards compatible, though it does introduce some methods that are very similar to existing ones, e.g. new_uninit vs. uninitialized, new_zeroed vs. zeroed.

DeviceAllocator

// new `alloc` module

/// Abstracts over the ways device memory can be obtained and released.
///
/// Each implementor chooses how memory is allocated. It also chooses the handle
/// type (`Ptr`) that represents an allocation.
pub trait DeviceAllocator {
    /// Handle returned by `allocate`/`allocate_zeroed` and consumed by `deallocate`.
    type Ptr;

    /// Allocates `size` bytes of uninitialized memory.
    // The associated type must be written `Self::Ptr`; a bare `Ptr` does not resolve.
    fn allocate(&self, size: usize) -> CudaResult<Self::Ptr>;
    /// Allocates `size` bytes of zero-filled memory.
    ///
    /// This is a separate method so that implementations can zero the memory
    /// asynchronously instead of the caller doing it afterwards.
    fn allocate_zeroed(&self, size: usize) -> CudaResult<Self::Ptr>;
    /// Releases memory previously returned by this allocator.
    fn deallocate(&self, ptr: Self::Ptr) -> CudaResult<()>;
}

// Uses `cudaMalloc`, `cudaFree`.
// NOTE(review): this crate wraps the CUDA Driver API, so the underlying calls would
// presumably be the Driver counterparts (`cuMemAlloc` / `cuMemFree`) — confirm.
// `Global` is the default allocator for `DeviceBox` (see its `A = Global` parameter).
pub struct Global; // TODO better name?

// `Global` hands out untyped byte pointers; `DeviceBox` reinterprets them as `T`.
impl DeviceAllocator for Global {
    type Ptr = DevicePointer<u8>;
    ...
}

// Other allocators might include:
// `Unified`, `HostPinned`, `Pitched`, `Async`, `MemoryPool`, etc.

/// An owned device allocation holding a `T`, or a slice of `U` when `T = [U]`.
/// The memory is obtained from, and released through, the allocator `A`.
///
/// `T: ?Sized` is required for `DeviceBox<[T], A>` to be a valid type.
// NOTE(review): `A::Ptr` is a thin pointer, so `DeviceBox<[T], A>` will also need to
// store its element count (e.g. a `len` field) — not yet decided.
pub struct DeviceBox<T: ?Sized, A: DeviceAllocator = Global> {
    ptr: A::Ptr,
    alloc: A,
    // `T` appears in no other field. Without this marker the struct does not compile
    // (E0392). It also records that the box logically owns a `T`.
    _marker: std::marker::PhantomData<T>,
}

impl<T, A> DeviceBox<T, A> {
    pub fn new_in(x: T, alloc: A) -> DeviceBox<T, A>;
}

MaybeUninit

impl<T> DeviceBox<T, Global> {
    ...
    // Note that these methods are safe. The contents are typed `MaybeUninit<T>`, so
    // they cannot be used as a `T` until the caller (unsafely) calls `assume_init`.

    /// Allocates space for a `T` with the `Global` allocator, leaving it uninitialized.
    ///
    /// # Errors
    /// Returns the CUDA error if the allocation fails.
    pub fn new_uninit() -> CudaResult<DeviceBox<MaybeUninit<T>, Global>>;
    /// Allocates space for a `T` with the `Global` allocator, with every byte zeroed.
    ///
    /// # Errors
    /// Returns the CUDA error if the allocation or zeroing fails.
    pub fn new_zeroed() -> CudaResult<DeviceBox<MaybeUninit<T>, Global>>;
}

impl<T, A> DeviceBox<T, A> {
    ...
    pub fn new_uninit_in(alloc: A) -> DeviceBox<MaybeUninit<T>, A>;
    pub fn new_zeroed_in(alloc: A) -> DeviceBox<MaybeUninit<T>, A>;
}

impl<T, A> DeviceBox<MaybeUninit<T>, A> {
    pub unsafe fn assume_init(self) -> DeviceBox<T, A>;

    // Use this for kernel outputs, then `assume_init` after the kernel is complete.
    pub unsafe fn as_uninit_device_pointer(&mut self) -> DevicePointer<T>;
}

DeviceBox<[T]>

impl<T> DeviceBox<[T], Global> {
    pub fn new(x: &impl AsRef<[T]>) -> DeviceBox<[T], Global>;
    pub fn new_uninit_slice() -> DeviceBox<[MaybeUninit<T>], Global>;
    pub fn new_zeroed_slice() -> DeviceBox<[MaybeUninit<T>], Global>;
}

impl<T, A> DeviceBox<[T], A> {
    pub fn new_in(x: &impl AsRef<[T]>, alloc: A) -> DeviceBox<[T], A>;
    pub fn new_uninit_slice_in(alloc: A) -> DeviceBox<[MaybeUninit<T>], A>;
    pub fn new_zeroed_slice_in(alloc: A) -> DeviceBox<[MaybeUninit<T>], A>;
}

impl<T, A> DeviceBox<[MaybeUninit<T>], A> {
    pub unsafe fn assume_init(self) -> DeviceBox<[T], A>;
    pub unsafe fn as_uninit_device_pointer(&mut self) -> DevicePointer<T>;
}

Async

// Uses `cudaMallocAsync`, `cudaFreeAsync`.
// NOTE(review): the Driver API counterparts would presumably be `cuMemAllocAsync` /
// `cuMemFreeAsync` — confirm which layer this crate should call.
// Allocator whose allocations are enqueued on a particular stream.
pub struct Async<'a> {
    // Borrowed stream that allocations are issued on; it must outlive the allocator.
    stream: &'a Stream,
}

impl Async<'_> {
    pub fn on(stream: &'a Stream) -> Async<'a>;
}

// The async allocator hands out `DevicePointerAsync`, an untyped (`u8`) pointer
// that also carries the stream it was allocated on and an `is_allocated` event.
impl<'a> DeviceAllocator for Async<'a> {
    type Ptr = DevicePointerAsync<'a, u8>;
    ...
}

/// Device pointer produced by the `Async` allocator, tied to the stream it was
/// allocated on.
pub struct DevicePointerAsync<'a, T> {
    // The device address itself.
    ptr: DevicePointer<T>,
    // Stream the allocation was issued on.
    stream: &'a Stream,
    // Event for the allocation. Other streams (or the host) wait on it before using `ptr`.
    is_allocated: Event,
}

impl<T, A> DeviceBox<T, A>
where
    A: DeviceAllocator,
    A::Ptr = DevicePointerAsync<'_, T>,
{
    // If the stream matches the async pointer, return it immediately.
    // Otherwise, block `stream` on `is_allocated` event.
    pub fn as_device_pointer_on(&mut self, stream: &Stream) -> DevicePointer<T>;
    pub unsafe fn as_device_pointer_unchecked(&mut self) -> DevicePointer<T>;
}

impl<T> DeviceBox<T, Async<'_>> {
    // Wait for `is_allocated` event.
    pub fn synchronize(self) -> DeviceBox<T, Global>;
}