anicusan / AcceleratedKernels.jl

Cross-architecture parallel algorithms for Julia's GPU backends, from a unified KernelAbstractions.jl codebase. Targets Intel oneAPI, AMD ROCm, Apple Metal, Nvidia CUDA.
https://anicusan.github.io/AcceleratedKernels.jl/
MIT License

Support for `dims` kwarg #6

Open pxl-th opened 1 week ago

pxl-th commented 1 week ago

Base methods such as accumulate! and mapreduce support the dims keyword argument. Is there a plan to add such support here? We could then replace the corresponding kernels in AMDGPU/CUDA with the AK implementation.
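
For reference, the Base behaviour being requested, shown on a plain CPU array for illustration:

A = rand(Float32, 4, 8)
sum(A; dims=2)                          # 4×1 result, reduced along the second dimension
mapreduce(abs2, +, A; dims=1)           # 1×8 result
accumulate!(+, similar(A), A; dims=2)   # running sums along each row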

anicusan commented 1 week ago

Hi Anton, sorry for the late reply - I've been thinking about the best parallelisation approach for n-dimensional reductions. The lazy way would be to simply run the same 1D reduction steps over each index combination except the dims one being reduced: while that uses the same amount of shared memory as the 1D case, its synchronisation overhead increases with the number of elements. Not great, and it would not be performant for big arrays where the reduced dimension is small.
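
To make the trade-off concrete, here is a minimal plain-Julia CPU sketch of that lazy approach for a matrix reduced along dims=1 - the names and structure are illustrative only, not the planned implementation:

# Reduce along dims=1 by running an independent 1D reduction per column.
# On a GPU each slice would be a separate pass with its own synchronisation,
# which is what makes this slow when there are many small slices.
function lazy_reduce_dims1(op, A::AbstractMatrix; init)
    out = similar(A, 1, size(A, 2))
    for j in axes(A, 2)
        out[1, j] = reduce(op, view(A, :, j); init)
    end
    return out
end

lazy_reduce_dims1(+, rand(Float32, 1000, 16); init = 0.0f0)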

I am trying to implement a parallelisation approach where the reduction operation done by each thread runs over all index permutations bar dims (at the expense of more shared memory), but this requires shared memory whose size depends on the specific dimension-wise sizes of the input array - and as those are runtime values, they cannot be used to define `@private Int (runtime_length,)`. I need to think a bit more about shared memory sizes...

Still, the 1D and n-dimensional cases would follow separate codepaths for maximum performance - the n-dimensional case simply needs more scaffolding / memory / indices than the 1D one, and no device-to-host copying (which is needed if dims is not specified).

If you want to, you can forward the 1D case to AK while I wrestle with shared memory. I will make the functions accept N-dimensional arrays that are operated on over their linear indices, like Julia Base reductions when dims is not specified.
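
Roughly, the forwarding could look like the sketch below - the AK.mapreduce entry point and keywords are assumptions to check against the exported API, and the wrapper name is just for illustration, not AMDGPU's actual override:

import AcceleratedKernels as AK
using AMDGPU

function ak_mapreduce(f, op, A::ROCArray; init, dims = nothing)
    if dims === nothing
        # linear-index (no dims) case: forward to AK (assumed signature)
        return AK.mapreduce(f, op, vec(A); init)
    else
        # dims-wise case: keep the existing backend kernels for now
        return Base.mapreduce(f, op, A; init, dims)
    end
end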

pxl-th commented 1 week ago

Thanks for working on this!

If you want to use runtime values, you can wrap them in Val before passing to the kernel:

using KernelAbstractions, AMDGPU

@kernel function ker(x, ::Val{runtime_length}) where {runtime_length}
    # runtime_length is a compile-time constant here, so it can size private memory
    m = @private Int (runtime_length,)
    ...
end

ker(ROCBackend())(x, Val(size(x, 2)); ndrange = size(x, 1))  # ndrange value is illustrative

However, every time you pass a different value the kernel will be recompiled, so this is probably not a good fit if the sizes change a lot.