tracel-ai / cubecl

Multi-platform high-performance compute language extension for Rust.
https://burn.dev
Apache License 2.0

Add optional output to autotuned operations #73

Closed wingertge closed 4 weeks ago

wingertge commented 1 month ago

Adds an optional output type to tuner executions. The type is specified at the call site of execute rather than on the Tuner itself, because the Tuner is static and cannot depend on the generic types passed to the tensor operation function. This is the only way to make things like JitTensor<R, E, 4> work as output types. The new output type defaults to (), so existing tuners need no changes.
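As a rough sketch of the shape this takes (the trait and type names below are illustrative assumptions, not the actual cubecl API), the output type can live as a defaulted generic parameter on the operation trait and be resolved at the execute call site:

// Illustrative sketch only, not the actual cubecl API: the output type is a
// generic parameter defaulting to `()`, so operation sets that return nothing
// keep compiling unchanged.
pub trait AutotuneOperation<Output = ()> {
    fn execute(self: Box<Self>) -> Output;
}

pub struct Tuner;

impl Tuner {
    // `Output` is inferred at the call site; the static `Tuner` itself carries
    // no generic parameters from the caller.
    pub fn execute<Output>(&self, operation: Box<dyn AutotuneOperation<Output>>) -> Output {
        operation.execute()
    }
}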

Testing

All unit tests pass, and the existing tuners in burn compile and behave exactly as before. I also used the new output type to autotune different conv2d algorithms, and it works as expected, with all tests passing.

nathanielsimard commented 1 month ago

@wingertge Can you provide an example of how the output type can be used? I'm having trouble figuring out why we need an output type. The implementation looks good though.

wingertge commented 1 month ago

Can you provide an example of how the output type can be used? I'm having trouble figuring out why we need an output type.

The way this is currently done in existing tuned operations is to initialize an output tensor up front and pass it to the kernel:

let output = init_matmul_output(&lhs, &rhs);

static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();

// The output tensor is pre-allocated and moved into the operation set; the
// selected kernel is expected to write its result into it.
TUNER.execute(
    &JitTuneId::new::<R>(&lhs.device),
    &client,
    Box::new(MatmulAutotuneOperationSet::new(lhs, rhs, output.clone())),
);

output

This works fine as long as the operation you're passing it to ignores the can_mut flag. However, any postprocessing on the output tensor through functions that do respect can_mut (e.g. float_matmul, float_slice_assign) will copy the tensor contents, and the result is no longer written to the pre-allocated output tensor. This makes things like im2col impossible without circumventing the backend and manually launching the low-level kernel for each operation. Mutable references don't work either, because operations must have a 'static lifetime.
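As a toy illustration of the copy-on-write behaviour described above (the types here are hypothetical stand-ins, not the burn or cubecl API): an operation that respects can_mut writes into a fresh copy whenever mutation is not allowed, so the caller's pre-allocated handle never receives the result.

use std::rc::Rc;

// Hypothetical stand-in for a tensor handle with a mutability flag.
#[derive(Clone)]
struct Handle {
    data: Rc<Vec<f32>>,
    can_mut: bool,
}

// An operation that respects `can_mut`: when the handle may not be mutated, it
// clones the buffer and writes into the copy, leaving the original untouched.
fn postprocess(mut out: Handle) -> Handle {
    if !out.can_mut {
        out.data = Rc::new((*out.data).clone());
        out.can_mut = true;
    }
    // ... results are written into `out.data`, which may now be a copy ...
    out
}

If the autotuned operation can only communicate its result through the handle it was given, the copy produced here is lost to the caller; letting the operation return its output instead sidesteps the problem.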

The new Output type lets us instead return the output tensor from whatever postprocessing operations we need, without unsafe workarounds like swapping the inner handles. Algorithms can then just return their output as normal:

/// Executes autotune on conv2d operations
pub fn conv2d_autotune<R: JitRuntime, E: FloatElement + Element, I: IntElement>(
    input: JitTensor<R, E, 4>,
    weights: JitTensor<R, E, 4>,
    bias: Option<JitTensor<R, E, 1>>,
    options: ConvOptions<2>,
) -> JitTensor<R, E, 4> {
    let client = input.client.clone();

    static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!("conv2d");

    // The output tensor produced by the selected algorithm is returned
    // directly from `execute`.
    TUNER.execute(
        &JitTuneId::new::<R>(&input.device),
        &client,
        Box::new(Conv2dOperationsSet::<R, E, I>::new(
            input, weights, bias, options,
        )),
    )
}

This gives a lot more implementation flexibility.

nathanielsimard commented 4 weeks ago

@wingertge Awesome, thanks a lot for the detailed explanation. I agree that this simplifies some workflows.