tracel-ai / cubecl

Multi-platform high-performance compute language extension for Rust.
https://burn.dev
Apache License 2.0

Support closures in kernels #114

Open RianGoossens opened 1 month ago

RianGoossens commented 1 month ago

Currently it's not possible to write a closure like so:

#[cube(launch_unchecked)]
fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if ABSOLUTE_POS < input.len() {
        // GELU: x * (1 + erf(x / sqrt(2))) / 2; the `+ 1` belongs outside the erf call.
        let f = |x: F| x * (F::erf(x / F::new(2.0f32.sqrt())) + F::new(1.0)) / F::new(2.0);
        output[ABSOLUTE_POS] = f(input[ABSOLUTE_POS]);
    }
}

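I'm not sure whether this is even possible, though I could see it working in limited circumstances, for example when the closure is bound to a local variable and only called directly. I haven't written many proc macros, so there may be a reason this is infeasible, but one way to tackle it could be to inline the closure body at each call site during macro expansion.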
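To make the inlining idea concrete, here is a rough sketch in plain host-side Rust. It is not CubeCL code and says nothing about how the `#[cube]` macro actually works. It only shows a GELU closure next to the hand-inlined form that an expansion could, in principle, produce. The names `erf_approx`, `gelu_with_closure` and `gelu_inlined` are made up for this illustration, and `erf_approx` is a textbook polynomial approximation standing in for `F::erf`, since stable `std` has no `erf`.

```rust
// Hypothetical stand-in for `F::erf`: Abramowitz-Stegun formula 7.1.26,
// accurate to roughly 1.5e-7 in exact arithmetic.
fn erf_approx(x: f32) -> f32 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    // Horner evaluation of the degree-5 polynomial in t.
    let mut poly = 1.061405429;
    poly = poly * t - 1.453152027;
    poly = poly * t + 1.421413741;
    poly = poly * t - 0.284496736;
    poly = poly * t + 0.254829592;
    poly *= t;
    sign * (1.0 - poly * (-x * x).exp())
}

// What the user would like to write: a non-capturing closure bound to a
// local and called directly.
fn gelu_with_closure(input: &[f32], output: &mut [f32]) {
    let f = |x: f32| x * (erf_approx(x / 2.0f32.sqrt()) + 1.0) / 2.0;
    for (out, &value) in output.iter_mut().zip(input) {
        *out = f(value);
    }
}

// What an inlining expansion could produce: the call `f(value)` is replaced
// by the closure body, with the parameter bound to the argument first.
fn gelu_inlined(input: &[f32], output: &mut [f32]) {
    for (out, &value) in output.iter_mut().zip(input) {
        let x = value; // closure parameter bound to the call argument
        *out = x * (erf_approx(x / 2.0f32.sqrt()) + 1.0) / 2.0;
    }
}
```

Calling both functions on the same input slice should give identical outputs, which is the property an inlining pass would need to preserve. Closures that capture variables, get passed to other functions, or get stored somewhere would need more care than this substitution handles.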