Closed — wingertge closed this 4 weeks ago
@wingertge Can you provide an example of how the output type can be used? I'm having trouble figuring out why we need an output type. The implementation looks good though.
> Can you provide an example of how the output type can be used? I'm having trouble figuring out why we need an output type.
The way this is currently done in existing tuned operations is to initialize an output tensor and pass it to the kernel:

```rust
let output = init_matmul_output(&lhs, &rhs);

static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();
TUNER.execute(
    &JitTuneId::new::<R>(&lhs.device),
    &client,
    Box::new(MatmulAutotuneOperationSet::new(lhs, rhs, output.clone())),
);

output
```
This works fine as long as the operation you pass it to ignores the `can_mut` flag. However, doing any postprocessing on the output tensor using functions that do respect `can_mut` (e.g. `float_matmul`, `float_slice_assign`) will copy the tensor contents, so the result is no longer written to the output tensor. This makes things like `im2col` impossible without circumventing the backend and manually launching the low-level kernel for each operation. Mutable references don't work either, because operations must have a `'static` lifetime.
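To illustrate why the copy breaks the pattern, here is a minimal, self-contained sketch of copy-on-write output semantics. The `Handle`, `can_mut`, and `scale_respecting_can_mut` names are hypothetical stand-ins, not burn's actual types: the point is only that an op which respects `can_mut` writes into a fresh buffer when the output is aliased, so the handle the caller kept is never updated.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical model of a tensor handle with copy-on-write semantics.
#[derive(Clone)]
struct Handle {
    data: Rc<RefCell<Vec<f32>>>,
}

impl Handle {
    fn new(data: Vec<f32>) -> Self {
        Self { data: Rc::new(RefCell::new(data)) }
    }

    // Mimics `can_mut`: in-place mutation is only allowed when the
    // underlying buffer is not aliased by another handle.
    fn can_mut(&self) -> bool {
        Rc::strong_count(&self.data) == 1
    }
}

// An op that respects `can_mut`: if the output is aliased, it silently
// copies the buffer and writes the result into the copy instead.
fn scale_respecting_can_mut(output: Handle, factor: f32) -> Handle {
    let target = if output.can_mut() {
        output
    } else {
        Handle::new(output.data.borrow().clone()) // silent copy
    };
    for x in target.data.borrow_mut().iter_mut() {
        *x *= factor;
    }
    target
}

fn main() {
    let output = Handle::new(vec![1.0, 2.0]);
    // The caller keeps a clone to return later, as in the tuner pattern.
    let kept = output.clone();
    let result = scale_respecting_can_mut(output, 2.0);
    // The postprocessed values went into a copy, not the kept handle:
    println!("{:?}", kept.data.borrow());   // still [1.0, 2.0]
    println!("{:?}", result.data.borrow()); // [2.0, 4.0]
}
```

Because `kept` aliased the buffer, `can_mut` returned false and the scaled values landed in a copy, which is exactly why the preallocated output tensor ends up stale.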
`Output` allows us to instead return the output tensor from whatever postprocessing operations we need, without unsafe workarounds like swapping the inner handles. Algorithms can then just return their output as normal:
```rust
/// Executes autotune on conv2d operations
pub fn conv2d_autotune<R: JitRuntime, E: FloatElement + Element, I: IntElement>(
    input: JitTensor<R, E, 4>,
    weights: JitTensor<R, E, 4>,
    bias: Option<JitTensor<R, E, 1>>,
    options: ConvOptions<2>,
) -> JitTensor<R, E, 4> {
    let client = input.client.clone();

    static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!("conv2d");

    TUNER.execute(
        &JitTuneId::new::<R>(&input.device),
        &client,
        Box::new(Conv2dOperationsSet::<R, E, I>::new(
            input, weights, bias, options,
        )),
    )
}
```
This gives a lot more implementation flexibility.
@wingertge Awesome, thanks a lot for the detailed explanation. I agree that this simplifies some workflows.
Adds an optional output type to tuner executions. The type is defined at the call site of `execute` and not on the `Tuner` itself, because the `Tuner` is static and cannot depend on the generic types passed to the tensor operation function. This is the only way to make things like `JitTensor<R, E, 4>` work as output types. The new output type defaults to `()`, making changes to existing tuners unnecessary.

Testing
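As a rough sketch of how a `()` default keeps existing tuners source-compatible (the trait and method names here are hypothetical, not the actual tuner API): the operation set trait takes a generic output parameter defaulting to `()`, and `execute` is generic over it, so old implementations and call sites compile unchanged while new ones can return a value.

```rust
// Hypothetical sketch of the defaulting pattern, not the real tuner API.
trait AutotuneOperationSet<Out = ()> {
    fn fastest(self: Box<Self>) -> Out;
}

struct Tuner;

impl Tuner {
    // `execute` is generic over the output; operation sets that use the
    // default `Out = ()` behave exactly as before.
    fn execute<Out>(&self, set: Box<dyn AutotuneOperationSet<Out>>) -> Out {
        set.fastest()
    }
}

// An existing operation set: no output type, compiles unchanged.
struct OldSet;
impl AutotuneOperationSet for OldSet {
    fn fastest(self: Box<Self>) {}
}

// A new operation set that returns its result directly.
struct NewSet(Vec<f32>);
impl AutotuneOperationSet<Vec<f32>> for NewSet {
    fn fastest(self: Box<Self>) -> Vec<f32> {
        self.0
    }
}

fn main() {
    let tuner = Tuner;
    tuner.execute(Box::new(OldSet)); // returns `()`, as before
    let out = tuner.execute(Box::new(NewSet(vec![1.0, 2.0])));
    println!("{:?}", out); // [1.0, 2.0]
}
```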
All unit tests pass, and the existing tuners in `burn` compile fine and work exactly the same. I also tested the new output type to autotune different `conv2d` algorithms, and it works as expected, with all tests passing.