Open jacobhinkle opened 6 days ago
!test --diff
!build
The two remaining "failures" are tests with IDs of IterType::VectorComponent. These are created by an operation similar to how Broadcast IDs are created in a BroadcastOp, so, as in the broadcast case, we should not introduce a predicate for them either.
Now that I've verified there are no unexpected cases hitting this code path, I'm going to remove the NVF_THROW and run a codediff. To do that, I'll need to temporarily disable the test due to a bug in codediff with new or removed tests in the sharded binary tests.
!test --diff
I have re-enabled the test now. To summarize:
lower_utils::getIndexIDs utility.
BroadcastOp, which introduces new IterType::Broadcast IDs, and for ops like view(DataType dtype), which can introduce new IterType::VectorComponent IDs. Neither of these cases requires predication, so returning early does not change the functionality.
!test
Looks like you're still making some changes. Is this ready for review or should we wait?
I made a couple of changes based on some downstream testing in #3416. Let me run the CI tests again to verify that everything is working, then I'll ping you for another review.
!test
@naoyam Tests pass on Ampere and I verified that the new test passes on Hopper, so I think this is ready for review now.
Could you give some examples of generated code before and after the predicate elimination? Please also include ID exprs of related tensors.
I added the example from the added test to the description of this PR. There are currently no other known examples, since this PR does not affect predication in the only two other relevant cases:
1. IterType::Broadcast domains in the logical domain of the output are similar to the consumer ID iblockIdx.x19{( ceilDiv(i4, 256) )} in the mma example: they are not mapped to any ID in the input tensor, yet they appear in the loop domain of the consumer and they do map to a producer logical domain (i.e. they are not themselves "loop broadcasts"). This case is covered by the following pre-existing check: https://github.com/NVIDIA/Fuser/blob/3266b9d21cb82272fe6e766b71fb9a9f298de833/csrc/device_lower/analysis/predicate_elimination.cpp#L219-L223
2. If we call view(tv0, DataType::Float) and tv0 is a ComplexFloat, the output has a new logical inner dimension of type IterType::VectorComponent and extent 2. This also does not map to any producer logical ID, but it is unlikely that this new size-2 dimension would be parallelized.
In the VectorComponent case, we could skip the early return if consumer_id->isVectorComponent(), just to be sure we do not change any current behavior.
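To make the extent-2 VectorComponent dimension concrete, here is a small host-side C++ sketch. It is not nvFuser code, and the helper names viewComplexAsFloat and viewElement are hypothetical. It relies only on the C++ standard's guarantee that contiguous std::complex<float> values can be read as interleaved float pairs (real, imaginary), so a ComplexFloat tensor of shape [n] reinterpreted as Float has logical shape [n, 2].

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical helper modeling view(tv0, DataType::Float) on host data.
// Input: a ComplexFloat "tensor" of shape [n], stored contiguously.
// Output: the same values read as Float with logical shape [n, 2]; the new
// innermost dimension of extent 2 plays the role of the VectorComponent ID.
std::vector<float> viewComplexAsFloat(
    const std::vector<std::complex<float>>& tv0) {
  // The standard guarantees reinterpret_cast<const float*>(p)[2 * i] is the
  // real part of p[i] and [2 * i + 1] is its imaginary part.
  const float* flat = reinterpret_cast<const float*>(tv0.data());
  return std::vector<float>(flat, flat + 2 * tv0.size());
}

// Element (i, c) of the [n, 2] view: c == 0 is real, c == 1 is imaginary.
float viewElement(const std::vector<float>& viewed, std::size_t i,
                  std::size_t c) {
  assert(c < 2);  // the VectorComponent dimension has extent 2
  return viewed[i * 2 + c];
}
```

The new dimension has no counterpart in the producer's logical domain, which is why it never maps to a producer logical ID.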
Thanks. I'm concerned about getIndexIDs, since it won't work if tensors have loop-promoted IDs, where we would need to look at the whole fusion to get a full loop promotion map.
I wonder why the existing predicate checker can reason about the other producer, T5, but not T4. Which part of the check sets the needs_predicate_ flag before this PR?
Now that we can define MmaOp with unbroadcasted inputs (see #3391), it is possible to have ops for which some consumer loop IDs are not used at all for indexing some producers.
For example, the output of MmaOp has three logical dimensions, M, N, and K. These are scheduled by splitting, merging, and swizzling, so in the end the consumer loop domain can contain things like a split of the N dimension into two other IterDomains. Now if we look at the producer A, it has logical size [M, K], so there is no N dimension at all. Our current predicate elimination pass places a predicate on this operation when the N dimension is symbolic and we can't prove that the producer is parallelized the same way as this consumer in this dimension. However, since N cannot affect the indexing of the producer A, which has no N dimension, we should skip checking these IterDomains.
This PR does this by performing a BFS from the collection of consumer root IDs that map to producer logical IDs to the consumer leaf domain. Only IDs along that path are checked using the existing conditions.
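The traversal described above can be sketched on a toy IterDomain graph. Everything here is hypothetical and simplified (integer IDs, a flat list of Split/Merge-style transforms) rather than nvFuser's actual BFS utilities, and the rule that an output is reached when any of its inputs is reached is an assumption made for this sketch. The scenario mirrors the MmaOp example: consumer roots M=0, N=1, K=2; producer A maps only M and K; N and M are each split into two loop IDs.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <unordered_set>
#include <vector>

// Toy transform (hypothetical): Split is {in} -> {outer, inner},
// Merge is {outer, inner} -> {out}.
struct Transform {
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// Forward BFS from the consumer root IDs that map to producer logical IDs.
// Assumption: an output is reached if ANY input is reached, since a merged
// ID's index still encodes the reached component.
std::unordered_set<int> reachableFromMappedRoots(
    const std::vector<int>& mapped_roots,
    const std::vector<Transform>& transforms) {
  std::unordered_set<int> visited(mapped_roots.begin(), mapped_roots.end());
  std::queue<int> frontier;
  for (int id : mapped_roots) frontier.push(id);
  while (!frontier.empty()) {
    const int id = frontier.front();
    frontier.pop();
    for (const Transform& t : transforms) {
      if (std::find(t.inputs.begin(), t.inputs.end(), id) == t.inputs.end()) {
        continue;  // this transform does not consume the current ID
      }
      for (int out : t.outputs) {
        if (visited.insert(out).second) frontier.push(out);
      }
    }
  }
  return visited;
}

// Only loop IDs on a path from a mapped root get the existing checks.
std::vector<int> loopIdsToCheck(const std::vector<int>& loop_ids,
                                const std::unordered_set<int>& reachable) {
  std::vector<int> result;
  for (int id : loop_ids) {
    if (reachable.count(id) != 0) result.push_back(id);
  }
  return result;
}
```

With transforms Split(N=1) -> {3, 4} and Split(M=0) -> {5, 6} and loop domain {5, 6, 3, 4, 2}, only {5, 6, 2} are checked; the N-derived IDs 3 and 4 are skipped.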
Detailed example
In the test included in this PR, we have shared memory operand tensors that are scheduled like this
Notice that in T4_s the loop broadcasts bblockIdx.x33{1} and bS34{256} are not derived from the logical domain. Instead, they are both products of a Split involving an original "loop broadcast", although this is not currently shown in fusion->printTransforms():
In the predicate elimination pass with T4_s as producer and T2_l as consumer, the consumer ID iblockIdx.x19{( ceilDiv(i4, 256) )} would normally map to a logical broadcast ID in T4_s, but with these loop domain broadcasts we do not have such a mapping. Before this PR, that would cause predication. This PR notices that iblockIdx.x19{( ceilDiv(i4, 256) )} is not actually used for indexing the producer T4_s, so we do not need to worry about out-of-bounds accesses in this dimension.
Without this PR, if we remove the check at https://github.com/NVIDIA/Fuser/blob/3266b9d21cb82272fe6e766b71fb9a9f298de833/csrc/device_lower/analysis/predicate_elimination.cpp#L34-L37, then we generate the following code:
After this PR, the predicate around the wgmma call is removed and the assertOnWarpOps check can be restored.
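Why an ID like iblockIdx.x19 can be ignored for the producer can be checked with a toy loop nest. This sketch is not the generated kernel: it assumes, following the MmaOp description, a row-major producer with logical shape [M, K] and a consumer loop nest in which N is split by 256 into ceilDiv(N, 256) outer iterations (the role of the blockIdx.x-bound ID) and 256 inner ones. All names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// ceilDiv as it appears in printed extents such as ceilDiv(i4, 256).
constexpr int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Hypothetical row-major offset into producer A with logical shape [M, K].
// No N-derived loop index is an argument: such IDs never index A.
int64_t producerOffsetA(int64_t m, int64_t k, int64_t K) { return m * K + k; }

struct OutOfBoundsCounts {
  int64_t consumer_n_overshoot = 0;  // iterations with n >= N
  int64_t producer_oob_reads = 0;    // reads of A outside [0, M * K)
};

OutOfBoundsCounts countOutOfBounds(int64_t M, int64_t N, int64_t K,
                                   int64_t tile) {
  OutOfBoundsCounts counts;
  for (int64_t m = 0; m < M; ++m) {
    // n_outer stands in for the blockIdx.x-bound ID of extent ceilDiv(N, tile).
    for (int64_t n_outer = 0; n_outer < ceilDiv(N, tile); ++n_outer) {
      for (int64_t n_inner = 0; n_inner < tile; ++n_inner) {
        const int64_t n = n_outer * tile + n_inner;
        for (int64_t k = 0; k < K; ++k) {
          if (n >= N) ++counts.consumer_n_overshoot;
          const int64_t off = producerOffsetA(m, k, K);
          if (off < 0 || off >= M * K) ++counts.producer_oob_reads;
        }
      }
    }
  }
  return counts;
}
```

For M=2, N=300, K=3 and a tile of 256, the loop nest covers n in [0, 512), so 212 values of n overshoot N for every (m, k) pair, yet none of the reads of A leave its bounds. The sketch only concerns producer reads.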