bheisler / RustaCUDA

Rusty wrapper for the CUDA Driver API
Apache License 2.0

Figure out a way to make the context API truly safe #2

Open bheisler opened 5 years ago

bheisler commented 5 years ago

Maybe someone can find a way to make the Context API more safe than it is. I haven't been able to think of anything so far.

rusch95 commented 5 years ago

My advice would be to check out some of the CUDA wrapper implementations for other languages. PyCUDA seems to be the most popular non-C/C++ binding set, and I'm sure whatever the Haskell binding is doing will be too safe.

Plus, PyCUDA has some really nifty wrappers that simplify the API and could be used for inspiration.

ctrl-z-9000-times commented 4 years ago

Hello,

I have an idea for how to make a better Context API. This is inspired by the Python keyword with.

with_context!(ctx, {
    // Code here uses the context ...
})

This macro would expand to:

{
    // Push the context onto the thread-local context stack.
    let ret_val = { /* the given code block */ };
    // Pop the context off the thread-local context stack.
    ret_val // a `return` here would exit the enclosing function, so yield the value instead
}

Inside of the block, the user can use the global context. The context will always be cleaned up at the end.

This same technique is also applicable for Devices.
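To make the push/pop idea above concrete, here is a minimal sketch that uses only the standard library. Everything in it is a stand-in chosen for illustration: `MockContext`, `push`, `pop`, `current`, and `PopOnDrop` are hypothetical names, and the thread-local `Vec` only imitates the driver's per-thread context stack. The one deliberate difference from the expansion above is that the pop happens in a `Drop` impl, so it also runs if the block panics or returns early.

```rust
use std::cell::RefCell;

// Hypothetical stand-in for a CUDA context handle; not RustaCUDA's `Context`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MockContext(pub u32);

thread_local! {
    // Stand-in for the driver's per-thread context stack.
    static STACK: RefCell<Vec<MockContext>> = RefCell::new(Vec::new());
}

pub fn push(ctx: MockContext) {
    STACK.with(|s| s.borrow_mut().push(ctx));
}

pub fn pop() -> Option<MockContext> {
    STACK.with(|s| s.borrow_mut().pop())
}

pub fn current() -> Option<MockContext> {
    STACK.with(|s| s.borrow().last().copied())
}

// Pops the top of the stack when dropped, including during unwinding.
pub struct PopOnDrop;

impl Drop for PopOnDrop {
    fn drop(&mut self) {
        pop();
    }
}

// `push` and `PopOnDrop` must be in scope wherever the macro is invoked.
macro_rules! with_context {
    ($ctx:expr, $body:block) => {{
        push($ctx);
        let _guard = PopOnDrop;
        // The block's value becomes the value of the whole expression.
        $body
    }};
}
```

With these definitions, `with_context!(MockContext(1), { ... })` evaluates to the value of its block, and nested invocations restore the outer context when the inner block ends.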

ctrl-z-9000-times commented 4 years ago

Another thing that could help manage contexts: CUDA keeps a ref-count for every context, exposed through cuCtxAttach and cuCtxDetach. https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#context

Using this ref-count could make your UnownedContext obsolete?
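As a rough analogy for the ref-count idea, the sketch below uses `Arc` from the standard library rather than any CUDA call. `SharedContext`, `RawContext`, and `DESTROYED` are hypothetical names made up for this example: cloning a handle plays the role of attaching, dropping one plays the role of detaching, and the underlying object is destroyed only when the last handle goes away. Whether the driver's own count could replace `UnownedContext` this way is an open question, not something this sketch settles.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Counts destroyed mock contexts so the effect of the last drop is observable.
pub static DESTROYED: AtomicUsize = AtomicUsize::new(0);

// Hypothetical stand-in for the underlying driver context.
struct RawContext;

impl Drop for RawContext {
    // Runs exactly once, when the last handle is dropped (the "destroy" step).
    fn drop(&mut self) {
        DESTROYED.fetch_add(1, Ordering::SeqCst);
    }
}

// Cloning plays the role of "attach"; dropping a clone plays the role of "detach".
#[derive(Clone)]
pub struct SharedContext(Arc<RawContext>);

impl SharedContext {
    pub fn new() -> Self {
        SharedContext(Arc::new(RawContext))
    }

    // Number of live handles that refer to the same underlying context.
    pub fn usage_count(&self) -> usize {
        Arc::strong_count(&self.0)
    }
}
```

Here every handle is an owner, so no separate non-owning type is needed to refer to a context that someone else created.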

charles-r-earp commented 4 years ago

I like where ctrl-z is going with with_context!, but a macro can only expand to calls to public methods. What about using a special guard object to do the pre/post actions?

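use std::ops::Deref;
use std::sync::{Arc, LockResult, Mutex, MutexGuard, PoisonError};
use rustacuda::context::{Context, ContextStack, CurrentContext};
use rustacuda::error::CudaResult;
use rustacuda::memory::{DeviceBuffer, DeviceSlice};

/// Wrap to impl Send
#[doc(hidden)]
pub struct SendContext(Context);

// SAFETY (assumed): the wrapped context is only used while holding the Mutex in SyncContext.
unsafe impl Send for SendContext {}

impl SendContext {
    fn as_context(&self) -> &Context {
        &self.0
    }
}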

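pub struct SyncContext {
    context: Mutex<SendContext>
}

impl SyncContext {
    // Maybe have own lock result that includes cuda result, to avoid panicking?
    pub fn lock(&self) -> LockResult<ContextGuard<'_>> {
        match self.context.lock() {
            Ok(guard) => {
                // Borrow the context instead of moving the guard into set_current.
                CurrentContext::set_current(guard.as_context())
                    .unwrap();
                Ok(ContextGuard { context: guard })
            },
            Err(poison) => {
                // Recover the guard so the caller can still decide what to do.
                let guard = poison.into_inner();
                CurrentContext::set_current(guard.as_context())
                    .unwrap();
                Err(PoisonError::new(ContextGuard { context: guard }))
            }
        }
    }
}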

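pub struct ContextGuard<'a> {
    context: MutexGuard<'a, SendContext>
}

impl<'a> Drop for ContextGuard<'a> {
    fn drop(&mut self) {
        // Assumption: the raw handles can be compared via ContextHandle::get_inner.
        use rustacuda::context::ContextHandle;
        let current = CurrentContext::get_current()
            .unwrap();
        if current.get_inner() == self.context.as_context().get_inner() {
            ContextStack::pop()
                .unwrap();
        }
        else {
            // either panic or do nothing
        }
    }
}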

impl<'a> Deref for ContextGuard<'a> {
    type Target = Context;
    fn deref(&self) -> &Context {
        self.context.as_context()
    }
}

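struct SyncDeviceBuffer<T> {
    buffer: DeviceBuffer<T>,
    // SyncContext is not generic; share it through an Arc, as in the constructor.
    context: Arc<SyncContext>
}

impl<T> SyncDeviceBuffer<T> {
    pub unsafe fn uninitialized(context: Arc<SyncContext>, size: usize) -> CudaResult<Self> {
        let buffer = {
            // A poisoned lock cannot be turned into a CudaError with `?`, so recover the guard.
            let _c = context.lock().unwrap_or_else(PoisonError::into_inner);
            DeviceBuffer::uninitialized(size)?
        }; // the guard is dropped here, which releases the borrow of `context`
        Ok(Self {
            buffer,
            context
        })
    }
}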

impl<T> Deref for SyncDeviceBuffer<T> {
    // Target is the pointee type; deref already returns a reference to it.
    type Target = DeviceSlice<T>;
    fn deref(&self) -> &Self::Target {
        &*self.buffer
    }
}
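
For readers who want to try the guard-object pattern without a GPU, here is a minimal sketch that depends only on the standard library. It is not RustaCUDA code: `MockContext`, `current_id`, and the thread-local `CURRENT` cell are hypothetical stand-ins for the driver's current context. Two choices differ from the snippet above and are assumptions of this sketch: the guard remembers the previously current context and restores it on drop, and a poisoned mutex is recovered rather than reported.

```rust
use std::cell::Cell;
use std::ops::Deref;
use std::sync::{Mutex, MutexGuard, PoisonError};

// Hypothetical stand-in for a CUDA context; not RustaCUDA's `Context`.
#[derive(Debug)]
pub struct MockContext {
    pub id: u32,
}

thread_local! {
    // Stand-in for the driver's notion of the current context on this thread.
    static CURRENT: Cell<Option<u32>> = Cell::new(None);
}

pub fn current_id() -> Option<u32> {
    CURRENT.with(|c| c.get())
}

pub struct SyncContext {
    context: Mutex<MockContext>,
}

pub struct ContextGuard<'a> {
    context: MutexGuard<'a, MockContext>,
    previous: Option<u32>,
}

impl SyncContext {
    pub fn new(id: u32) -> Self {
        SyncContext { context: Mutex::new(MockContext { id }) }
    }

    // Pre-action: take the lock, then make the context current.
    pub fn lock(&self) -> ContextGuard<'_> {
        // This sketch recovers from poisoning instead of reporting it.
        let context = self.context.lock().unwrap_or_else(PoisonError::into_inner);
        let previous = CURRENT.with(|c| c.replace(Some(context.id)));
        ContextGuard { context, previous }
    }
}

impl<'a> Drop for ContextGuard<'a> {
    // Post-action: restore whatever was current before the lock was taken.
    // The mutex itself is released afterwards, when the inner guard is dropped.
    fn drop(&mut self) {
        CURRENT.with(|c| c.set(self.previous));
    }
}

impl<'a> Deref for ContextGuard<'a> {
    type Target = MockContext;
    fn deref(&self) -> &MockContext {
        &self.context
    }
}
```

Holding the guard keeps both the lock and the "current" marker in place, and letting it go out of scope undoes both in that order.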