rdspring1 closed this 1 week ago
!test
how will we handle partial vectorization?
Do you mean when the tensor is not 16B aligned? You can overcopy with TMA, cp.async, or regular LDG + STS.
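For concreteness, here is a rough hand-written sketch of what the two cp.async paths look like; this is illustrative CUDA, not nvFuser's generated code, and the helper name is made up. The point is that the copy size is baked into the instruction itself:

```cpp
#include <cstdint>

// Sketch only: per-thread cp.async copies at two vector widths. The transfer
// size is encoded in the instruction (4, 8, or 16 bytes; the .cg variant is
// 16B-only), so the aligned and unaligned paths need different generated code.
__device__ void copyToSmem(void* smemDst, const void* gmemSrc, int bytes) {
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smemDst));
  if (bytes == 16) {        // 8 fp16 elements: fully vectorized path
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
                 :: "r"(smem), "l"(gmemSrc));
  } else if (bytes == 8) {  // 4 fp16 elements: partially vectorized path
    asm volatile("cp.async.ca.shared.global [%0], [%1], 8;\n"
                 :: "r"(smem), "l"(gmemSrc));
  }
  // A real kernel would follow with cp.async.commit_group / cp.async.wait_group.
}
```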
Yeah, exactly. So if we had K=60 as the inner dimension of each of the operands, the Ampere scheduler needs to handle them differently when generating the kernel, since we can only do 4-element reads for the cp.async call instead of 8-element reads. But I don't see where that kind of alignment analysis comes in when using TMA; will TMA handle misaligned boxes dynamically, using the same compiled kernel as for fully-aligned inputs?
EDIT: is this computed on the host side in the TMA descriptor?
TMA should automatically handle the K=60 case by zero-filling the out-of-bounds accesses. If the tensor is not 16B aligned, TMA will fail and you need to fall back to regular LDG + STS accesses.
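Since the host-side question came up: here is a minimal sketch of how that is expressed in the TMA descriptor. This is illustrative driver-API code, not nvFuser's implementation; the shapes, the padded row pitch, and the function name are all assumptions.

```cpp
#include <cuda.h>
#include <cstdint>

// Illustrative sketch, not nvFuser code: a 2D fp16 operand with inner extent
// K = 60, rows padded to a 64-element pitch. globalDim and boxDim are separate
// fields, so the 64-wide box legally extends past column 59 and TMA zero-fills
// the out-of-bounds tail on every load -- no alignment branch in the kernel.
CUtensorMap makeOperandDescriptor(void* gmemPtr /* must be 16B-aligned */,
                                  uint64_t M) {
  CUtensorMap desc{};
  uint64_t globalDim[2]      = {60, M};  // {inner K, outer M} in elements
  // Strides of dims 1..rank-1, in bytes; each must be a multiple of 16, which
  // is exactly why a dense 60 * 2 = 120B row pitch makes TMA unusable here.
  uint64_t globalStrides[1]  = {64 * sizeof(uint16_t)};
  uint32_t boxDim[2]         = {64, 8};  // box may exceed the tensor extent
  uint32_t elementStrides[2] = {1, 1};

  cuTensorMapEncodeTiled(
      &desc, CU_TENSOR_MAP_DATA_TYPE_FLOAT16, /*tensorRank=*/2, gmemPtr,
      globalDim, globalStrides, boxDim, elementStrides,
      CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
      CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);  // OOB elements read back as zero
  return desc;
}
```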
!test
Looks like you just need to guard `AmpereMatmulBroadcastBatch`. I noticed I needed this in #3278 but I was too lazy to merge that upstream to this PR for you. https://github.com/NVIDIA/Fuser/pull/3278/files#diff-64fc4e7bfbc5b9f95ac3dc5823bd99b683b048926805c13310ce6a8ef8032289R147-R148
!test
This PR modifies `schedulePrologues` to use TMA loads to move mma operands to shared memory. Stacked on https://github.com/NVIDIA/Fuser/pull/3324 and https://github.com/NVIDIA/Fuser/pull/3310.

Details:
- Use the `CpAsyncBulkTensorTile` `LoadStoreOp` to load the mma operands into shared memory.
- Replace the `LdMatrix` operation with a basic `set`.
- Add `scheduleOperandSmemStores` to apply swizzling to avoid bank conflicts.
- Refactor `swizzleSharedMemory` by moving the analysis component to a separate function named `analyzeSwizzleSharedMemory`.
- Add a `tmaSwizzleSharedMemory` function that uses `analyzeSwizzleSharedMemory` and then finds the appropriate TMA swizzle format.
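For reviewers unfamiliar with the hardware side, the last step presumably reduces to mapping the inner shared-memory tile's byte width onto one of the three swizzle modes TMA supports. A hypothetical sketch, with the function name and selection rule assumed rather than taken from the PR:

```cpp
#include <cuda.h>

// Hypothetical sketch (not the PR's code): TMA hardware supports exactly three
// swizzle patterns plus none, so picking a format amounts to matching the
// byte width of the inner shared-memory tile row.
CUtensorMapSwizzle pickTmaSwizzle(int innerTileBytes) {
  if (innerTileBytes % 128 == 0) return CU_TENSOR_MAP_SWIZZLE_128B;
  if (innerTileBytes % 64 == 0)  return CU_TENSOR_MAP_SWIZZLE_64B;
  if (innerTileBytes % 32 == 0)  return CU_TENSOR_MAP_SWIZZLE_32B;
  return CU_TENSOR_MAP_SWIZZLE_NONE;
}
```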