pxl-th opened 1 week ago
Hi Anton, sorry for the late reply - I've been thinking about the best parallelisation approach for n-dimensional reductions. The lazy way would be to simply run the same reduction steps over each index permutation bar the `dims` one over which we're running the reduction; while this uses the same amount of shared memory as the 1D case, its synchronisation overhead grows with the number of elements. Not great, and it would not be performant for big arrays where the reduced dimension is small.
I am trying to implement a parallelisation approach where the reduction operation done by each thread runs over all index permutations bar `dims` (at the expense of more shared memory), but this requires shared memory whose size depends on the specific dimension-wise sizes of the input array - these are runtime values, and hence cannot be used to define `@private Int (runtime_length,)`. I need to think a bit more about shared memory sizes...
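For reference, the dims-wise semantics being parallelised can be sketched as a serial CPU loop: for every index combination bar the reduced dimension, fold the operator over that slice. This is only an illustrative reference implementation (the name `reduce_dim` is made up here), not the GPU kernel itself:

```julia
# Serial reference: reduce `x` over dimension `dim` with binary op `op`.
function reduce_dim(op, x::AbstractArray{T,N}, dim::Int, init) where {T,N}
    # Output keeps every dimension except `dim`, which collapses to size 1.
    outsize = ntuple(d -> d == dim ? 1 : size(x, d), N)
    out = fill(init, outsize)
    for I in CartesianIndices(out)      # all index permutations bar `dim`
        acc = init
        for j in axes(x, dim)           # walk along the reduced dimension
            J = Base.setindex(Tuple(I), j, dim)
            acc = op(acc, x[CartesianIndex(J)])
        end
        out[I] = acc
    end
    return out
end
```

In the parallel version, each thread would own one such index permutation, which is why the per-workgroup memory requirement scales with runtime sizes.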
Still, the 1D and n-dimensional cases would follow separate codepaths for maximum performance - the n-dimensional case simply needs more scaffolding / memory / indices than the 1D one, and no device-to-host copying (which is needed if `dims` is not specified).
If you want to, you can forward the call for the 1D case to AK while I wrestle with shared memory. I will make the functions accept N-dimensional arrays and operate over linear indices, like Julia Base reductions without a specified `dims`.
Thanks for working on this!
If you want to use runtime values, you can wrap them in `Val` before passing to the kernel:

```julia
@kernel function ker(x, ::Val{runtime_length}) where {runtime_length}
    m = @private Int (runtime_length,)
    ...
end

ker(ROCBackend())(x, Val(size(x, 2)))
```
However, every time you pass a different value it will recompile the kernel, so it's probably not a good fit if the values change a lot.
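To make the recompilation point concrete: each distinct value wrapped in `Val` is a distinct type, so any function dispatching on `Val{N}` is specialised (and compiled) once per value. A minimal illustration, with a made-up function `f`:

```julia
# Dispatch on the value lifted into the type domain.
f(::Val{N}) where {N} = N + 1

# Val(8) and Val(16) are different *types*, not just different values,
# so f is compiled separately for each of them.
println(typeof(Val(8)))                      # Val{8}
println(typeof(Val(8)) === typeof(Val(16)))  # false
println(f(Val(8)))                           # 9
```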
Base methods such as `accumulate!` and `mapreduce` have support for the `dims` kwarg. Is there a plan for adding such support here? We could then replace other kernels from AMDGPU/CUDA with the AK implementation.
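For reference, the Base behaviour such an implementation would need to match (standard Julia semantics, shown here just as a reminder):

```julia
A = [1 2 3; 4 5 6]

# Without dims: reduce over linear indices, returning a scalar.
println(sum(A))                # 21
println(mapreduce(abs2, +, A)) # 91

# With dims: reduce along the given dimension, which is kept with size 1.
println(sum(A; dims=1))        # 1×3 Matrix: [5 7 9]
println(sum(A; dims=2))        # 2×1 Matrix: [6; 15]

# accumulate! writes running reductions along a dimension into a buffer.
B = similar(A)
accumulate!(+, B, A; dims=2)
println(B)                     # [1 3 6; 4 9 15]
```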