ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Preload transposed inputs in LDS for pointwise kernels #3172

Open pfultz2 opened 2 months ago

pfultz2 commented 2 months ago

To improve performance for transpose kernels, we should load the transposed inputs directly into LDS and then read from LDS instead of global memory. We already have preload_copy, which does this preloading with a vectorized copy, but it preloads the entire tensor.

Instead, we should only preload a slice. Ideally this slice will be across the transposed axis, so in the case of NHWC we would slice across the C axis (i.e., axis 1), which gives us slices of CHW. A slice may contain more than one channel when HW is very small, since we want to make sure each workgroup has enough work to do.
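To picture why LDS helps, here is a host-side C++ sketch (not MIGraphX code) of the simplest case, assuming a slice is the full CHW block of one image in an NHWC tensor. That block is one contiguous run of C*H*W elements, so it can be copied with plain contiguous loads, while logical CHW reads from global memory would jump by C elements between neighbouring w positions. The helpers nhwc_offset, preload_image and read_chw are hypothetical, and the std::vector stands in for the LDS buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Offset of logical element (n, c, h, w) in an NHWC buffer.
// For logical (n, c, h, w) lens, NHWC strides are {H*W*C, 1, W*C, C}.
std::size_t nhwc_offset(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                        std::size_t C, std::size_t H, std::size_t W)
{
    return n * H * W * C + h * W * C + w * C + c;
}

// The CHW slice for image n is one contiguous block, so this copy is a plain
// contiguous (vectorizable) read. The returned vector stands in for LDS.
std::vector<float> preload_image(const std::vector<float>& nhwc, std::size_t n,
                                 std::size_t C, std::size_t H, std::size_t W)
{
    std::size_t size = C * H * W;
    assert((n + 1) * size <= nhwc.size());
    auto first = nhwc.begin() + static_cast<std::ptrdiff_t>(n * size);
    return std::vector<float>(first, first + static_cast<std::ptrdiff_t>(size));
}

// Read logical (c, h, w) from the preloaded slice, keeping the NHWC strides.
float read_chw(const std::vector<float>& slice, std::size_t c, std::size_t h,
               std::size_t w, std::size_t C, std::size_t W)
{
    return slice[h * W * C + w * C + c];
}
```

The transposed access pattern is the same as before; it just hits the local copy instead of global memory.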

The preload_copy function takes a boolean template parameter that enables preloading for the tensor. It passes the new tensor to a lambda callback, so the LDS can be statically allocated and stays alive for as long as the tensor view is in use. We should be able to reuse this function for the preloading; we just need to do the slicing:

template <index_int Axis, index_int N, class T>
__device__ auto auto_preload_tranposed(index idx, T x)
{
    return [=](auto f) {
        auto y = make_slice<Axis, N>(x);
        preload_copy<should_preload_transposed_slice<T, Axis>()>(idx, y)(f);
    };
}
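The callback style is what ties the buffer's lifetime to the view. A minimal host-side sketch of the pattern, assuming a hypothetical preload_copy_sketch and a toy view type (a local std::array stands in for the statically allocated __shared__ buffer; this is not MIGraphX's preload_copy):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Minimal non-owning view standing in for a tensor view.
template <class T>
struct view
{
    const T* data;
    std::size_t size;
    T operator[](std::size_t i) const { return data[i]; }
};

// Hypothetical analogue of preload_copy: when Preload is true, copy into a
// fixed-size local buffer and call f with a view over that buffer. The buffer
// only lives while f runs, which is why the new view is handed to a callback
// instead of being returned.
template <bool Preload, std::size_t Capacity, class T>
auto preload_copy_sketch(view<T> x)
{
    return [=](auto f) {
        if constexpr(Preload)
        {
            assert(x.size <= Capacity);
            std::array<T, Capacity> buffer{};
            for(std::size_t i = 0; i < x.size; ++i)
                buffer[i] = x[i];
            return f(view<T>{buffer.data(), x.size});
        }
        else
        {
            // No preloading: pass the original view through unchanged.
            return f(x);
        }
    };
}
```

Returning a view over `buffer` from the outer function would leave it dangling; the callback keeps every use inside the buffer's scope.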

Functions such as make_slice and should_preload_transposed_slice still need to be implemented. This could also be extended to variadic parameters using recursion:

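template <index_int Axis, index_int N, class T, class... Ts>
__device__ auto auto_preload_tranposed(index idx, T x, Ts... xs)
{
    return [=](auto f) {
        // Axis and N cannot be deduced from the arguments, so pass them explicitly
        auto_preload_tranposed<Axis, N>(idx, x)([=](auto y) {
            // Preload the remaining arguments, then call f with all of them in order
            auto_preload_tranposed<Axis, N>(idx, xs...)([=](auto... ys) {
                f(y, ys...);
            });
        });
    };
}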
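The nested-callback recursion can be checked on the host with a stand-in transform. In this sketch, with_transformed is a hypothetical name that just doubles an int where the real code would slice and preload; the point is that the callbacks nest so that f receives every transformed argument in the original order:

```cpp
#include <cassert>

// Hypothetical single-argument transform with the same callback shape as
// auto_preload_tranposed: it hands its result to f.
template <class T>
auto with_transformed(T x)
{
    return [=](auto f) { return f(x * 2); };
}

// Variadic version by recursion: transform the first argument, then the
// rest, then call f with all transformed arguments. With one argument the
// non-variadic overload above is the better match, which ends the recursion.
template <class T, class... Ts>
auto with_transformed(T x, Ts... xs)
{
    return [=](auto f) {
        return with_transformed(x)([=](auto y) {
            return with_transformed(xs...)([=](auto... ys) { return f(y, ys...); });
        });
    };
}
```

For example, `with_transformed(1, 2, 3)(f)` ends up calling `f(2, 4, 6)`.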

Currently, preload_copy only works with packed tensors because it uses the original tensor shape to index the LDS buffer. We can require packed tensors at first, and add unpacked support later by tweaking preload_copy.
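To see which slices would pass a packed restriction, here is a small sketch of a packed check: no zero (broadcast) strides, and the element count equals the memory span the strides cover. The function names are hypothetical, not MIGraphX's shape API. With NHWC strides, a single image's CHW block is packed, but a subset of channels is not, because the strided layout leaves gaps that an LDS buffer indexed with the original strides would presumably have to cover.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Number of logical elements: product of the lens.
template <std::size_t N>
std::size_t elements(const std::array<std::size_t, N>& lens)
{
    std::size_t n = 1;
    for(std::size_t l : lens)
        n *= l;
    return n;
}

// Memory span touched by the view: 1 + sum((len - 1) * stride).
template <std::size_t N>
std::size_t element_space(const std::array<std::size_t, N>& lens,
                          const std::array<std::size_t, N>& strides)
{
    std::size_t space = 1;
    for(std::size_t i = 0; i < N; ++i)
    {
        assert(lens[i] > 0);
        space += (lens[i] - 1) * strides[i];
    }
    return space;
}

// Sketch of a packed check: no broadcast strides, and no gaps in the span.
template <std::size_t N>
bool is_packed_sketch(const std::array<std::size_t, N>& lens,
                      const std::array<std::size_t, N>& strides)
{
    for(std::size_t s : strides)
        if(s == 0)
            return false;
    return elements(lens) == element_space(lens, strides);
}
```

For lens {1, 2, 2, 2} with NHWC strides {12, 1, 6, 3} (two of three channels), the view has 8 elements but spans 11.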

Since each workgroup reads one slice, the global calculation (the launch's global work size) will need to be updated.
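One way the sizing could work, as a host-side simulation under stated assumptions: one workgroup per slice, slices evenly dividing the tensor, and threads striding through their slice by the local size. global_for_slices and for_each_slice_element are hypothetical names, not MIGraphX's launch helpers.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sizing: one workgroup per slice, local_size threads each.
std::size_t global_for_slices(std::size_t total_elements, std::size_t slice_elements,
                              std::size_t local_size)
{
    assert(slice_elements > 0 && total_elements % slice_elements == 0);
    return (total_elements / slice_elements) * local_size;
}

// Simulate the launch: workgroup g owns elements
// [g * slice_elements, (g + 1) * slice_elements), and thread l visits
// l, l + local_size, ... within that slice.
template <class F>
void for_each_slice_element(std::size_t total_elements, std::size_t slice_elements,
                            std::size_t local_size, F f)
{
    std::size_t groups = global_for_slices(total_elements, slice_elements, local_size) / local_size;
    for(std::size_t g = 0; g < groups; ++g)
        for(std::size_t l = 0; l < local_size; ++l)
            for(std::size_t i = l; i < slice_elements; i += local_size)
                f(g * slice_elements + i);
}
```

Every element is still visited exactly once. With a 3-element slice and 4 threads per group, one thread in each group does nothing, which is the situation that putting more work into each slice is meant to avoid.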

Since we are preloading transposed inputs, we should prefer that the inputs, rather than the outputs, are transposed. This might require some adjustment to the permutation that we pass to normalize_permutation.
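A small sketch of why the choice of permutation matters, assuming a stride-based permutation (axes ordered by decreasing stride). find_permutation_sketch and is_transposed_relative_to are hypothetical and are not MIGraphX's normalize_permutation. If the kernel iterates in the output's order, an NHWC input is the transposed tensor and can be preloaded; if it iterated in the input's order, the NCHW output would be the transposed one, and writes cannot be preloaded.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <numeric>

// Axes ordered from largest to smallest stride; stable_sort keeps ties in
// their original axis order.
template <std::size_t N>
std::array<std::size_t, N> find_permutation_sketch(const std::array<std::size_t, N>& strides)
{
    std::array<std::size_t, N> perm{};
    std::iota(perm.begin(), perm.end(), std::size_t{0});
    std::stable_sort(perm.begin(), perm.end(),
                     [&](std::size_t a, std::size_t b) { return strides[a] > strides[b]; });
    return perm;
}

// A tensor is transposed relative to the iteration order when its own stride
// ordering differs from that order.
template <std::size_t N>
bool is_transposed_relative_to(const std::array<std::size_t, N>& strides,
                               const std::array<std::size_t, N>& order)
{
    return find_permutation_sketch(strides) != order;
}
```

For NHWC strides {12, 1, 6, 3} the permutation is {0, 2, 3, 1}, while NCHW strides {12, 4, 2, 1} give the identity.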

Currently we call into the pointwise function to do the pointwise operation, but we probably want a transposed_pointwise function that handles the preloading case automatically.

pfultz2 commented 2 months ago

And this could be extended for variadic parameters by using recursion:

Actually, this is not needed: auto_preload already does this for multiple parameters, and it is set up as an arg transformer. We should probably make the slicing an arg transformer as well, so we can just pass it to the current pointwise function and won't need another pointwise overload for preloading (we could then also reuse it for other kernels if we want).
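A host-side sketch of the arg-transformer idea, to show how slicing could compose with preloading in front of an unchanged pointwise function. Everything here is hypothetical and only mirrors the callback shape: transform_args_sketch, slice_half (a stand-in for "one slice per workgroup"), preload_sketch (a copy standing in for LDS) and pointwise_sketch are not MIGraphX's transform_args, auto_preload or pointwise.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using vec = std::vector<int>;

// Chain transformers left to right. A transformer maps args... to a function
// that takes a continuation k and calls k(new_args...).
template <class F>
auto transform_args_sketch(F f)
{
    return f;
}

template <class F, class G, class... Gs>
auto transform_args_sketch(F f, G g, Gs... gs)
{
    return [=](auto... xs) {
        return [=](auto k) {
            return f(xs...)([=](auto... ys) { return transform_args_sketch(g, gs...)(ys...)(k); });
        };
    };
}

// Hypothetical slicing transformer: keep the half of each argument selected by group.
auto slice_half(std::size_t group)
{
    return [=](auto... xs) {
        return [=](auto k) {
            auto half = [=](const vec& v) {
                auto first = v.begin() + static_cast<std::ptrdiff_t>(group * v.size() / 2);
                return vec(first, first + static_cast<std::ptrdiff_t>(v.size() / 2));
            };
            return k(half(xs)...);
        };
    };
}

// Hypothetical preloading transformer: copy each argument into local storage.
auto preload_sketch()
{
    return [](auto... xs) {
        return [=](auto k) { return k(vec(xs)...); };
    };
}

// The pointwise function itself is unchanged: it applies op to whatever
// arguments the transformer chain hands it.
template <class Transformer>
auto pointwise_sketch(Transformer t)
{
    return [=](auto op, const vec& a, const vec& b) {
        return t(a, b)([=](const vec& x, const vec& y) {
            vec r(x.size());
            for(std::size_t i = 0; i < x.size(); ++i)
                r[i] = op(x[i], y[i]);
            return r;
        });
    };
}
```

For example, `pointwise_sketch(transform_args_sketch(slice_half(1), preload_sketch()))(add, a, b)` with a = {1, 2, 3, 4} and b = {10, 20, 30, 40} yields {33, 44}. Adding or removing a transformer changes only the chain, not pointwise_sketch.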