Open pfultz2 opened 2 months ago
> And this could be extended for variadic parameters by using recursion:
Actually this is not needed. `auto_preload` will do this for multiple parameters, and it's set up as an arg transformer. We should probably make the slicing an arg transformer as well, so we can just pass it the current pointwise function and we don't need another pointwise overload for doing preloading (we can also then reuse this for other kernels if we want).
To improve performance for transpose kernels we should load the transposed inputs into LDS directly, and then read from LDS instead. We have a function like `preload_copy` which will do this preloading with a vectorized copy, but this will preload the entire tensor. Instead we should only preload a slice. Ideally this slice will be across the transposed axis, so in the case of NHWC we would slice across the C axis (i.e. axis 1), which will give us slices of CHW. The slice may contain more than one channel for the case where HW is very small, as we want to make sure each workgroup has enough work to do.
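As a rough illustration of the "more than one channel per slice" idea, here is a minimal host-side sketch. The helper name `channels_per_slice` and the `min_elements` threshold are assumptions made for this example and are not existing MIGraphX functions; it also ignores the LDS capacity limit, which a real version would have to respect.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of MIGraphX): choose how many channels go
// into one slice so that a workgroup gets at least `min_elements` elements,
// while keeping the slice size an exact divisor of the channel count `c`.
inline std::size_t channels_per_slice(std::size_t c, std::size_t hw, std::size_t min_elements)
{
    assert(c > 0 && hw > 0 && min_elements > 0);
    // Smallest channel count whose HW elements reach the threshold.
    std::size_t n = (min_elements + hw - 1) / hw;
    // Never take more channels than the tensor has.
    if(n >= c)
        return c;
    // Round up to the next divisor of c so every slice has the same size.
    while(c % n != 0)
        n++;
    return n;
}
```

For example, with C = 64, HW = 49 and a target of 256 elements, this sketch picks 8 channels per slice; when HW is already large (say 1024) it falls back to a single channel.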
The `preload_copy` function takes a boolean which enables preloading for the tensor. It uses a lambda callback with the new tensor so the LDS can be statically allocated and its lifetime will last as long as the tensor view. We should be able to reuse this function to do the preloading; we just need to do the slicing.

Functions like `make_slice` and `should_preload_transposed_slice` will need to be implemented. And this could be extended for variadic parameters by using recursion.

Currently, `preload_copy` will only work with packed tensors because it uses the original tensor shape to read the LDS buffer. We can make packed be a restriction at first, but we could get unpacked support working by tweaking `preload_copy` in the future.

Since we are reading a slice per workgroup, the global calculation will need to be updated.
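To make the slicing and the updated global calculation a bit more concrete, here is a minimal host-side sketch. The `view4` struct is a stand-in invented for this example, not the real kernel tensor view, and both `make_slice` and `sliced_global` are hypothetical signatures. It assumes one workgroup per slice and slicing along a single axis only.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a strided 4-d tensor view (not the MIGraphX type).
struct view4
{
    std::array<std::size_t, 4> lens;
    std::array<std::size_t, 4> strides;
    std::size_t offset; // element offset into the underlying buffer
};

// Hypothetical make_slice: keep `size` entries of `axis` starting at `start`.
// The strides are unchanged; only the offset and the sliced length change.
inline view4 make_slice(view4 v, std::size_t axis, std::size_t start, std::size_t size)
{
    assert(axis < v.lens.size());
    assert(size > 0 && start + size <= v.lens[axis]);
    v.offset += start * v.strides[axis];
    v.lens[axis] = size;
    return v;
}

// Assumed launch rule: one workgroup of `local` threads per slice, so the
// global size is (number of slices along `axis`) * local.
inline std::size_t sliced_global(const view4& v,
                                 std::size_t axis,
                                 std::size_t size,
                                 std::size_t local)
{
    assert(axis < v.lens.size());
    assert(size > 0 && v.lens[axis] % size == 0);
    return (v.lens[axis] / size) * local;
}
```

For an NHWC buffer viewed as NCHW with lens `{1, 8, 2, 2}` and strides `{32, 1, 16, 8}`, slicing two channels starting at channel 4 moves the offset to element 4 and leaves the strides untouched. Inside the kernel, the workgroup index would select `start`, and the resulting slice is what would be handed to `preload_copy`.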
Since we are preloading transposed inputs we should prefer that the inputs are transposed rather than the outputs. This might require some adjustment to the permutation that we pass to `normalize_permutation`.

Currently we call into the `pointwise` function to do the pointwise, but we probably want to make a `transposed_pointwise` function to handle the case where we want to do the preloading, so it's handled automatically.
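For the decision to be handled automatically, something has to decide per input whether the sliced preload applies. Below is a minimal host-side sketch of such a predicate. The `shape4` struct and the exact rule are assumptions for illustration: it treats an input as a candidate when it is packed (the initial restriction mentioned above) and its strides are not in non-increasing order, which is used here as a simple proxy for "transposed".

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical shape description (not the MIGraphX shape type).
struct shape4
{
    std::array<std::size_t, 4> lens;
    std::array<std::size_t, 4> strides;
};

// Transposed proxy: strides read back to front are not sorted ascending.
inline bool is_transposed(const shape4& s)
{
    return !std::is_sorted(s.strides.rbegin(), s.strides.rend());
}

// Packed: the memory footprint in elements equals the number of elements.
inline bool is_packed(const shape4& s)
{
    std::size_t elements = 1;
    std::size_t space    = 1;
    for(std::size_t i = 0; i < s.lens.size(); i++)
    {
        assert(s.lens[i] > 0);
        elements *= s.lens[i];
        space += (s.lens[i] - 1) * s.strides[i];
    }
    return elements == space;
}

// Hypothetical predicate: preload a slice only for packed, transposed inputs.
inline bool should_preload_transposed_slice(const shape4& s)
{
    return is_packed(s) && is_transposed(s);
}
```

A `transposed_pointwise` wrapper could evaluate this for each input and forward the result as the boolean that `preload_copy` already takes, falling back to the plain `pointwise` path when no input qualifies.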