Open magcius opened 1 year ago
Don't forget about Thread Group PrefixSum and PrefixCountBits. Those are immensely useful and would be great to have as primitives that are guaranteed to be implemented efficiently for all the different wave sizes out there.
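For reference, here is a minimal CPU-side sketch of what group-wide prefix operations might compute. It assumes exclusive semantics over the flattened thread index, mirroring how WavePrefixSum and WavePrefixCountBits behave within a wave. The function names are hypothetical and not part of any existing HLSL API.

```python
from typing import List


def thread_group_prefix_sum(values: List[int]) -> List[int]:
    """Exclusive prefix sum over the flattened thread index (assumed semantics)."""
    out: List[int] = []
    running = 0
    for v in values:
        out.append(running)  # thread i sees the sum of threads 0..i-1
        running += v
    return out


def thread_group_prefix_count_bits(flags: List[bool]) -> List[int]:
    """Count of lower-indexed threads whose flag is true (assumed semantics)."""
    return thread_group_prefix_sum([1 if f else 0 for f in flags])
```

A typical use is stream compaction: each thread whose flag is true writes its element to the slot given by its prefix count.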
I think this would be a great addition personally. 👍🏻
One possible "issue" that would need to be specified is how these ops would interact with synchronization primitives. i.e. After a ThreadGroupActiveSum call, should we stipulate that a GroupMemoryBarrierWithGroupSync is needed before the result is consumed? If the synchronization is implicit, would it be understood that sequenced calls to various ThreadGroup* operations insert barriers as necessary? I'm inclined to suggest starting with an explicit approach as opposed to automatic barrier insertion (which would rely on barrier coalescing downstream in the PSO compilation pipeline).
One more addition that would be great to see on the list would be a group equivalent to WaveActiveMatch, presumably named ThreadGroupActiveMatch.
My personal vote would be for automatic barrier insertion, as far as how the various calls work, with the implementation then free to strip redundant barriers that may result from serial calls to various ThreadGroup functions. As the degree of internal synchronization needed to implement one of these functions is likely implementation specific, it seems a little strange to demand a faux external synchronization primitive (other than perhaps to imply to the user that such syncs may internally take place). It also means that if an IHV pursues some mechanism that allows direct support for ThreadGroupXXX functionality that internally doesn't require groupshared syncs, they aren't artificially penalized for it (the most obvious case is where the thread group size maps directly to the wave size, and such barriers would contribute nothing).
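To make the barrier question concrete, here is a small CPU-side model of the implicit option, a sketch rather than a proposed implementation. Python threads stand in for the threads of one group, a list stands in for groupshared scratch, and `threading.Barrier` stands in for GroupMemoryBarrierWithGroupSync. `GroupContext`, `thread_group_active_sum`, and `run_group` are hypothetical names.

```python
import threading
from typing import List, Tuple


class GroupContext:
    """Shared state for one simulated thread group."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.scratch = [0] * size              # models groupshared scratch
        self.barrier = threading.Barrier(size)  # models a group-wide sync


def thread_group_active_sum(ctx: GroupContext, tid: int, value: int) -> int:
    ctx.scratch[tid] = value
    ctx.barrier.wait()   # every write must land before any thread reads
    total = sum(ctx.scratch)
    ctx.barrier.wait()   # every read must finish before scratch is reused
    return total


def run_group(values_a: List[int], values_b: List[int]) -> List[Tuple[int, int]]:
    """Each simulated thread issues two back-to-back group sums."""
    size = len(values_a)
    ctx = GroupContext(size)
    results: List[Tuple[int, int]] = [(0, 0)] * size

    def body(tid: int) -> None:
        first = thread_group_active_sum(ctx, tid, values_a[tid])
        second = thread_group_active_sum(ctx, tid, values_b[tid])
        results[tid] = (first, second)

    threads = [threading.Thread(target=body, args=(t,)) for t in range(size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

In this model the second barrier in each call exists only to protect scratch reuse by a following call. That is the kind of barrier an implementation could strip or merge when it can prove it redundant, and under the explicit option the same two waits would instead be written by the caller.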
> My personal vote would be for automatic barrier insertion
I'm not entirely opposed unless it means that it's much harder for drivers/IHVs to provide the functionality. I suppose my inclination is weighted based more on wanting the feature at all and trying to make things easier for the implementer, but as a user, for sure I would prefer the synchronization to be implicit.
Is your feature request related to a problem? Please describe.
It can be difficult when trying to make subgroups work across a variety of different hardware with different wave sizes.
Describe the solution you'd like
For a number of operations, like min and max, it would be nice if I could take their values across the entire thread group, rather than just across a single subgroup. For cases where the workgroup size <= the wave size (that is, the whole thread group fits in a single wave), the driver could implement this with a single wave operation.
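As a rough illustration of how such an operation could be lowered, the sketch below models a two-level max: one wave-level reduction per wave, followed by a cross-wave combine that is skipped when the whole group fits in a single wave. It assumes each wave covers a contiguous range of thread indices, which the APIs do not necessarily guarantee. `wave_active_max` here is a plain Python stand-in for WaveActiveMax, and `thread_group_active_max` is a hypothetical name.

```python
from typing import List, Tuple


def wave_active_max(lanes: List[int]) -> int:
    """Stand-in for a wave-level max over the lanes of one wave."""
    return max(lanes)


def thread_group_active_max(values: List[int], wave_size: int) -> Tuple[int, int]:
    """Return (group max, number of waves) for one simulated thread group."""
    if wave_size <= 0:
        raise ValueError("wave_size must be positive")
    # Assumption: waves are contiguous slices of the flattened thread index.
    waves = [values[i:i + wave_size] for i in range(0, len(values), wave_size)]
    partials = [wave_active_max(w) for w in waves]
    if len(partials) == 1:
        # Whole group fits in one wave: no cross-wave step is needed.
        return partials[0], 1
    # Cross-wave combine; on a GPU this step would go through groupshared memory.
    return max(partials), len(partials)
```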
Not all wave operations have a natural or clean mapping to thread group operations. To start with, I might expect a limited number of thread group operations, like:
* ThreadGroupActiveCountBits()
* ThreadGroupActiveSum()
* ThreadGroupActiveBitAnd()
* ThreadGroupActiveBitOr()
* ThreadGroupActiveBitXor()
* ThreadGroupActiveMin()
* ThreadGroupActiveMax()
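The intended semantics of these operations can be written down as a small CPU-side reference model. This is a sketch under the assumption that each one reduces over the active threads of the group, as the Wave* versions do over active lanes. The helper names and the `active` mask parameter are illustrative only.

```python
import operator
from functools import reduce
from typing import Callable, Dict, List

# Binary operator behind each proposed reduction (names taken from the list above).
REDUCTIONS: Dict[str, Callable[[int, int], int]] = {
    "ThreadGroupActiveSum": operator.add,
    "ThreadGroupActiveBitAnd": operator.and_,
    "ThreadGroupActiveBitOr": operator.or_,
    "ThreadGroupActiveBitXor": operator.xor,
    "ThreadGroupActiveMin": min,
    "ThreadGroupActiveMax": max,
}


def thread_group_reduce(name: str, values: List[int], active: List[bool]) -> int:
    """Reduce the values of active threads; every active thread gets this result."""
    live = [v for v, a in zip(values, active) if a]
    if not live:
        raise ValueError("no active threads in the group")
    return reduce(REDUCTIONS[name], live)


def thread_group_active_count_bits(flags: List[bool], active: List[bool]) -> int:
    """Count active threads whose flag is true."""
    return sum(1 for f, a in zip(flags, active) if a and f)
```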
Describe alternatives you've considered
This is already possible today by using groupshared memory and memory operations, but allocating groupshared memory isn't free, and while some drivers might optimize away groupshared, it would be nice to get a guarantee.
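For context, the groupshared approach typically looks like a tree reduction with a barrier between steps. The sketch below models that pattern on the CPU: the list plays the role of a groupshared array, and the counter tracks how many group-wide barriers a shader version would need. It assumes a power-of-two group size; `groupshared_tree_reduce` is a hypothetical name.

```python
from typing import Callable, List, Tuple


def groupshared_tree_reduce(
    values: List[int], op: Callable[[int, int], int]
) -> Tuple[int, int]:
    """Return (reduced value, barrier count) for a power-of-two sized group."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("group size must be a non-zero power of two")
    scratch = list(values)  # models one groupshared slot per thread
    barriers = 1            # sync after every thread has written its slot
    stride = n // 2
    while stride > 0:
        # Threads with tid < stride combine their slot with a partner slot.
        for tid in range(stride):
            scratch[tid] = op(scratch[tid], scratch[tid + stride])
        barriers += 1       # sync before the next, narrower step
        stride //= 2
    return scratch[0], barriers
```

The cost visible here is one scratch slot per thread plus roughly log2(group size) barriers, which is the overhead a built-in operation could hide or reduce.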
I think it would also improve the experience for beginners, who could then care more about workgroups and less about subgroups. It feels odd that this is missing.
Additional context
This dovetails a bit with #83, which is about improving the groupshared mechanism to allow scoped groupshared.
While SPIR-V does, in theory, have group and subgroup instructions that support scoping to the Workgroup level, this is currently not supported in the Vulkan profile.
My general take is that it's supposed to be in a user-space library like mine, and we already have implementations of what you ask for in progress: https://github.com/Devsh-Graphics-Programming/Nabla/pull/519
Also, @magcius, as you love to point out, subgroup != workgroup, so unless you pressure the DX and Vulkan specs to ensure that subgroups cannot straddle workgroups (or rely on UB), the official implementation of the above primitives would not be able to leverage subgroup operations (which are now ubiquitous on desktop) without incurring UB.
Not only would you get a slower implementation without UB (every iteration/step would be a Blelloch, creating 2x the work, instead of a Kogge-Stone intra-subgroup scan), but you'd also get much lower occupancy because you'd need to reserve far more groupshared memory (where you'd otherwise use subgroup ops to shuffle during log2(subgroup_size) iterations of the upsweep or downsweep in the Blelloch scan).
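To illustrate the difference between the two scans named here, the following sketch implements both on the CPU and counts synchronized steps. It is a generic textbook formulation, not the Nabla implementation, and it assumes a power-of-two input length for the Blelloch variant. Note that it counts steps (each of which would need a barrier when done through groupshared memory), not arithmetic operations.

```python
from typing import List, Tuple


def kogge_stone_inclusive_scan(values: List[int]) -> Tuple[List[int], int]:
    """Inclusive scan in log2(n) steps; every element is touched each step."""
    data = list(values)
    n = len(data)
    steps = 0
    offset = 1
    while offset < n:
        prev = list(data)  # all lanes read old values before any lane writes
        for i in range(offset, n):
            data[i] = prev[i] + prev[i - offset]
        offset *= 2
        steps += 1
    return data, steps


def blelloch_exclusive_scan(values: List[int]) -> Tuple[List[int], int]:
    """Exclusive scan: an upsweep pass, then a downsweep pass, log2(n) steps each."""
    data = list(values)
    n = len(data)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a non-zero power of two")
    steps = 0
    stride = 1
    while stride < n:  # upsweep: build partial sums up the tree
        for i in range(2 * stride - 1, n, 2 * stride):
            data[i] += data[i - stride]
        stride *= 2
        steps += 1
    data[n - 1] = 0    # clear the root before the downsweep
    stride = n // 2
    while stride >= 1:  # downsweep: push prefixes back down the tree
        for i in range(2 * stride - 1, n, 2 * stride):
            left = data[i - stride]
            data[i - stride] = data[i]
            data[i] += left
        stride //= 2
        steps += 1
    return data, steps
```

For eight elements, the Kogge-Stone variant finishes in 3 steps while the Blelloch variant needs 6, which matches the "2x" factor mentioned above when it is read as a step count.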
Note that unless you want meh performance, it would be best to template on the subgroup size you control, and I guess that for this to be built into HLSL it would have to be compilable up front.
> One possible "issue" that would need to be specified is how these ops would interact with synchronization primitives. i.e. After a ThreadGroupActiveSum call, should we stipulate that a GroupMemoryBarrierWithGroupSync is needed before the result is consumed? If the synchronization is implicit, would it be understood that sequenced calls to various ThreadGroup* operations insert barriers as necessary?
Excellent point, with the implementation being in "user space" you'd get to choose if this happens or not, what scratch groupshared memory can be aliased, etc.
I agree with @yuriy-odonnell-epic that this should include prefix sums, and I'd also like to throw in ThreadGroupAtomic*, which executes an atomic once per thread group, for the sake of completeness and orthogonality. This would compile down to the usual if (get_thread_id() == 0) .. but without having to pollute an input with a thread group ID etc., which may very well no longer be necessary if you have thread-group wide operations.
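A tiny CPU-side model of the "once per thread group" idea: every simulated thread reaches the call, but only the thread with index 0 issues the atomic. `AtomicCounter` and `dispatch` are hypothetical names used only for this sketch, and the counter is a plain object rather than a real atomic.

```python
class AtomicCounter:
    """Stand-in for a device-memory counter; tracks how many atomics were issued."""

    def __init__(self) -> None:
        self.value = 0
        self.ops = 0

    def add(self, amount: int) -> None:
        self.value += amount
        self.ops += 1


def dispatch(num_groups: int, group_size: int, counter: AtomicCounter) -> None:
    """Simulate a dispatch where each group bumps the counter exactly once."""
    for _group_id in range(num_groups):
        for thread_id in range(group_size):
            # The pattern a ThreadGroupAtomic* call would presumably lower to.
            if thread_id == 0:
                counter.add(1)
```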