I noticed that our sum reduction kernel is awkwardly written: the extension to size1 > 1 was implemented incorrectly. Each work group allocates shared memory proportional to work_group_size * size1, which is unnecessary and overflows shared memory when size1 is too large.
The root cause is that the kernel loops over n_rows when size1 > 1, whereas it is enough to apply the same single-row kernel to each row in turn. That makes the shared memory allocation constant with respect to size1, and the code reads much better overall since the operator no longer needs to be abstracted.
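
To illustrate the idea, here is a rough CPU-side model in Python. It is not the kernel from this change. It assumes that size1 is the number of rows and that the work group size is a power of two. The names reduce_row, sum_rows and scratch_slots are hypothetical and used only for this sketch. The point it shows is that the scratch buffer standing in for shared memory has work_group_size slots no matter how many rows are reduced.

```python
def reduce_row(row, work_group_size):
    """Model one work group's tree reduction over a single row."""
    # Scratch buffer standing in for shared memory: one slot per work item,
    # independent of the number of rows.
    scratch = [0.0] * work_group_size

    # Each work item accumulates a strided slice of the row.
    for local_id in range(work_group_size):
        for i in range(local_id, len(row), work_group_size):
            scratch[local_id] += row[i]

    # Tree reduction (assumes work_group_size is a power of two).
    stride = work_group_size // 2
    while stride > 0:
        for local_id in range(stride):
            scratch[local_id] += scratch[local_id + stride]
        stride //= 2
    return scratch[0]


def sum_rows(matrix, work_group_size=4):
    """Apply the same single-row reduction to each row in turn."""
    return [reduce_row(row, work_group_size) for row in matrix]


def scratch_slots(work_group_size, size1, per_row=True):
    """Scratch slots per work group: per-row scheme vs. looping over rows."""
    return work_group_size if per_row else work_group_size * size1
```

For instance, under these assumptions a work group of 256 items reducing 64 rows would need 256 * 64 = 16384 slots when looping over rows inside the kernel, but only 256 slots when the single-row reduction is applied row by row.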