google / wasefire

Secure firmware framework focusing on developer experience
https://google.github.io/wasefire/
Apache License 2.0

Casting handler_data back into a closure. #460

Open lukeyeh opened 4 months ago

lukeyeh commented 4 months ago

I was playing around with some of the applet API functions, namely the ones that take callbacks. There seems to be a common pattern where we pass a function pointer called handler_func, along with a handler_data pointer, down into the platform. For example, in timer we can see here,

https://github.com/google/wasefire/blob/a0110a85f50b6f8cd65d7ccc41658f8797c1f5f8/crates/prelude/src/timer.rs#L56-L61

that we first box the closure supplied by the user, cast the resulting pointer to a *const u8, and then have a function on the Timer struct that later casts it back into a closure that we can invoke:

https://github.com/google/wasefire/blob/a0110a85f50b6f8cd65d7ccc41658f8797c1f5f8/crates/prelude/src/timer.rs#L104-L105

I guess this is similar to the pattern in C++ where you reinterpret_cast<intptr_t> something like a pointer to some user_data
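For reference, the pattern described above can be sketched in isolation as follows. This is a minimal illustration, not the prelude's actual code: `Params`, `register`, `trampoline`, and `trampoline_demo` are hypothetical names, and the platform call is replaced by invoking the function pointer directly. The key point is that the callback is generic over the closure type `F`, so each instantiation knows which concrete type to cast the thin pointer back to.

```rust
use std::cell::Cell;
use std::rc::Rc;

/// Hypothetical stand-in for the parameters handed to the platform.
pub struct Params {
    pub handler_func: extern "C" fn(*const u8),
    pub handler_data: *const u8,
}

/// Boxes the closure and erases its type behind a thin `*const u8`.
/// The box is intentionally leaked here; real code would need to reclaim it
/// with `Box::from_raw` using the same `F` once the handler is unregistered.
pub fn register<F: Fn() + 'static>(handler: F) -> Params {
    let handler_data = Box::into_raw(Box::new(handler)) as *const u8;
    Params { handler_func: trampoline::<F>, handler_data }
}

/// Monomorphized once per closure type, so it can cast back to `F`.
extern "C" fn trampoline<F: Fn()>(data: *const u8) {
    // SAFETY: `data` was produced by `register::<F>` and is never freed.
    let handler = unsafe { &*(data as *const F) };
    handler();
}

/// Simulates the platform invoking the callback twice.
pub fn trampoline_demo() -> u32 {
    let count = Rc::new(Cell::new(0u32));
    let seen = count.clone();
    let params = register(move || seen.set(seen.get() + 1));
    (params.handler_func)(params.handler_data);
    (params.handler_func)(params.handler_data);
    count.get()
}
```

Calling `trampoline_demo()` from a `main` function runs the closure through the erased pointer twice and returns the number of invocations.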

I am, however, having trouble casting a closure that has been boxed and cast to a *const u8 back into the original closure.

Here is what I have:

impl Future for OneShotTimerFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.done.get() {
            return Poll::Ready(());
        }
        let handler_func = Self::call;
        let handler = || {
            debug!("in handler");
            self.done.set(true);
            cx.waker().wake_by_ref();
        };
        let handler = Box::into_raw(Box::new(handler));
        let handler_data = handler as *const u8;
        let params = api::allocate::Params { handler_func, handler_data };
        let id = convert(unsafe { api::allocate(params) }).unwrap();
        let params =
            api::start::Params { id, mode: Oneshot as usize, duration_ms: self.duration_ms };
        convert_unit(unsafe { api::start(params) }).unwrap();
        Poll::Pending
    }
}

impl OneShotTimerFuture {
    pub fn new(duration_ms: usize) -> Self {
        OneShotTimerFuture { duration_ms, done: false.into() }
    }

    extern "C" fn call(data: *const u8) {
      // What to do here?
    }
}

I've tried to use core::mem::transmute, but couldn't get anywhere with that and keep hitting some form of "cannot transmute between different size" compiler error. It's probably because all my solutions attempt to cast from a *mut u8 (a thin pointer) to a *mut dyn Fn() (a fat pointer that also carries a vtable pointer).
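The size mismatch can be checked directly, and one possible workaround that keeps the callback non-generic is to box the trait object a second time, so that the outer pointer refers to a sized `Box<dyn Fn()>` and is therefore thin. The sketch below is an illustration under assumptions: `pointer_widths`, `erase`, `free`, and `double_box_demo` are hypothetical names, and the closure is owned (`'static`). A closure that borrows `self` or `cx` from inside `poll` would dangle once `poll` returns, so this approach alone would not make the snippet above sound.

```rust
use std::cell::Cell;
use std::mem::size_of;
use std::rc::Rc;

/// Returns the widths of a thin pointer and of a trait-object pointer.
/// The latter also stores a vtable pointer, hence the transmute error.
pub fn pointer_widths() -> (usize, usize) {
    (size_of::<*const u8>(), size_of::<*const dyn Fn()>())
}

/// Double boxing: the outer box points at a sized `Box<dyn Fn()>`,
/// so the resulting raw pointer is thin and can be cast to `*const u8`.
pub fn erase(handler: Box<dyn Fn()>) -> *const u8 {
    Box::into_raw(Box::new(handler)) as *const u8
}

/// Non-generic callback that recovers the inner trait object.
pub extern "C" fn call(data: *const u8) {
    // SAFETY: `data` must come from `erase` and must not have been freed.
    let handler = unsafe { &*(data as *const Box<dyn Fn()>) };
    handler();
}

/// Reclaims the allocation made by `erase`.
///
/// # Safety
/// `data` must come from `erase` and must not be used afterwards.
pub unsafe fn free(data: *const u8) {
    drop(unsafe { Box::from_raw(data as *mut Box<dyn Fn()>) });
}

/// Round-trips an owned closure through the erased pointer.
pub fn double_box_demo() -> bool {
    let done = Rc::new(Cell::new(false));
    let flag = done.clone();
    let data = erase(Box::new(move || flag.set(true)));
    call(data);
    unsafe { free(data) };
    done.get()
}
```

The cost of this design is one extra allocation and one extra indirection per call, in exchange for a single `call` function that works for every closure type.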

I see that most of the prelude functions utilize the Handler trait, which I assume allows some of the types to align nicely

https://github.com/google/wasefire/blob/a0110a85f50b6f8cd65d7ccc41658f8797c1f5f8/crates/prelude/src/timer.rs#L29

and cast back to that later:

https://github.com/google/wasefire/blob/a0110a85f50b6f8cd65d7ccc41658f8797c1f5f8/crates/prelude/src/timer.rs#L104

but my struct here doesn't need to be generic on different types of "handlers" so I'm not sure if that's appropriate here. Any help would be appreciated. Thanks in advance!

lukeyeh commented 4 months ago

My only working solution is to store the closure, wrapped in a box in a struct. Is there a better way to do this?

struct Data<'a>
{
    f: Box<dyn Fn() + 'a>,
}

impl Future for OneShotTimerFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.done.get() {
            return Poll::Ready(());
        }
        let handler_func = Self::call;
        let handler = Data { 
            f: Box::new(|| {
                self.done.set(true);
                cx.waker().wake_by_ref();
            }), 
        };
        let handler = Box::into_raw(Box::new(handler));
        let handler_data = handler as *const u8;
        let params = api::allocate::Params { handler_func, handler_data };
        let id = convert(unsafe { api::allocate(params) }).unwrap();
        let params =
            api::start::Params { id, mode: Oneshot as usize, duration_ms: self.duration_ms };
        convert_unit(unsafe { api::start(params) }).unwrap();
        Poll::Pending
    }
}

impl OneShotTimerFuture {
    pub fn new(duration_ms: usize) -> Self {
        OneShotTimerFuture { duration_ms, done: false.into() }
    }

    extern "C" fn call(data: *const u8) {
        let ptr = unsafe { &*(data as *const Data) };
        (ptr.f)();
    }
}

ia0 commented 4 months ago

I hope I understood the problem correctly, in which case a part of the solution would be to do something more like serial::Listener than timer::Timer. Let me try to give an overview of a solution.

pub struct Timer {
    id: usize,
    state: &'static Cell<TimerState>,
}

enum TimerState {
    Idle,
    Sleep(&'static RefCell<SleepState>), // using the same terminology as tokio::time
    // Interval(&'static Interval),
}

// We use an exclusive lifetime to the timer to avoid users accidentally using the timer twice concurrently.
// They can still leak this object but that's not an accident then.
// Alternatively, we could just return an error if something is running.
pub struct Sleep<'a> {
    state: &'static RefCell<SleepState>,
    timer: &'a mut Timer,
}

enum SleepState {
    Done,
    Running,  // not yet polled
    Awaiting(Waker),  // polled
}

impl Timer {
    pub fn new() -> Self {
        // leak a TimerState::Idle
        // allocate timer with func=call and data=state
    }

    // We take exclusive reference to the timer even if we don't need mutability, simply to avoid double usage.
    pub fn sleep(&mut self, duration: Duration) -> Sleep {
        self.idle();
        // start the timer
        // leak a SleepState::Running
        // update the timer state
    }

    // Drops any running timer and goes back to the idle state.
    fn idle(&self) {
        // if Sleep, it means the Sleep object was leaked, it should be fine to just unleak the SleepState and drop it
        // We should also stop the timer if running
    }

    extern "C" fn call(data: *const u8) {
        // match on the state
        // nothing to do for Idle (this could happen with spurious calls)
        // for Sleep, wake the waker if any, and in all cases move to Done
    }
}

// Should call self.idle(), unregister, and unleak&drop the timer state.
impl Drop for Timer { ... }

impl<'a> Future for Sleep<'a> {
    fn poll(...) -> Poll {
        // check the SleepState
        // if Done, then Ready
        // if Running, then clone the Waker and move to Awaiting
        // if Awaiting, it means the future was polled twice before being awoken, just update the Waker in place
    }
}

// It might be enough to call timer.idle().
impl Drop for Sleep { ... }
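To make the core idea of the outline concrete, here is a reduced, self-contained sketch of just the `SleepState` part: the data passed to the platform is a pointer to leaked shared state rather than a closure, and the callback only wakes a stored `Waker`. It is written against `std` so it can run on a host, it omits `Timer`, `TimerState`, and the `Drop` implementations, and it simulates the platform by calling `call` by hand. `CountingWaker` and `sleep_demo` are hypothetical helpers that exist only for the demonstration.

```rust
use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

pub enum SleepState {
    Done,
    Running,         // not yet polled
    Awaiting(Waker), // polled
}

pub struct Sleep {
    state: &'static RefCell<SleepState>,
}

impl Sleep {
    /// Leaks the shared state and returns the future together with the
    /// thin pointer that would be given to the platform as handler_data.
    pub fn new() -> (Sleep, *const u8) {
        let state: &'static RefCell<SleepState> =
            Box::leak(Box::new(RefCell::new(SleepState::Running)));
        (Sleep { state }, state as *const RefCell<SleepState> as *const u8)
    }
}

/// Callback: wakes the waker if any, and in all cases moves to Done.
pub extern "C" fn call(data: *const u8) {
    // SAFETY: `data` comes from `Sleep::new`, whose state is leaked.
    let state = unsafe { &*(data as *const RefCell<SleepState>) };
    if let SleepState::Awaiting(waker) = state.replace(SleepState::Done) {
        waker.wake();
    }
}

impl Future for Sleep {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut state = self.state.borrow_mut();
        if matches!(*state, SleepState::Done) {
            return Poll::Ready(());
        }
        // Running or Awaiting: remember the most recent waker.
        *state = SleepState::Awaiting(cx.waker().clone());
        Poll::Pending
    }
}

/// Test-only waker that counts how many times it was woken.
struct CountingWaker(AtomicUsize);

impl Wake for CountingWaker {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
}

/// Polls once, fires the callback by hand, then polls again.
pub fn sleep_demo() -> (bool, bool, usize) {
    let wakes = Arc::new(CountingWaker(AtomicUsize::new(0)));
    let waker = Waker::from(wakes.clone());
    let mut cx = Context::from_waker(&waker);
    let (mut sleep, data) = Sleep::new();
    let pending = Pin::new(&mut sleep).poll(&mut cx).is_pending();
    call(data); // stands in for the platform firing the timer
    let ready = Pin::new(&mut sleep).poll(&mut cx).is_ready();
    (pending, ready, wakes.0.load(Ordering::SeqCst))
}
```

Because the callback receives a pointer to `'static` state instead of a closure that borrows `self` and `cx`, nothing it touches can dangle after `poll` returns. The leaked state is never reclaimed in this sketch; the outline above handles that in `idle` and the `Drop` implementations.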

There are other variants possible and it's not yet clear to me what would be best. But from a high-level point of view, I think the API should kind of look like this:

pub struct Timer;
pub struct Sleep;  // whether there is a lifetime parameter is to be decided
pub struct Interval;
impl Timer {
    pub fn new() -> Self;
    pub fn sleep(&self, duration: Duration) -> Sleep; // whether taking &mut self is to be decided
    pub fn interval(&self, duration: Duration) -> Interval;
}
impl Future for Sleep {}
impl Interval {
    pub async fn tick(&mut self);
}

Maybe we should provide the following convenience functions (implemented using the API above):

pub fn sleep(duration: Duration) -> Sleep;
pub fn interval(duration: Duration) -> Interval;

However, they might need to return slightly different types because we need to store the timer somewhere and know we own it.

lukeyeh commented 4 months ago

Thanks for the pointers and being patient with me! Not sure why I was so fixated on sticking the waker.wake() in the closure and using that as the data!

ia0 commented 4 months ago

No worries. Feel free to close if you don't have follow-up questions.