This uses #3391. I modified the HSH_NT_128BSwizzle test by removing the broadcast dimensions in the inputs. The goal is to schedule this fusion successfully on Hopper without incurring any wasteful smem->register round trips before the mma instruction.
Approach
I have not modified any of the utility functions. Instead, I made the following changes:
1. Modify the inputs to remove the broadcast dimensions.
2. Define an AxisMapping to use these 2D inputs in fusedMultiplySum.
3. Do a loop-domain broadcast using tv->broadcast(pos) for the smem operands.
4. Replace the 2-way transform propagation from the mma result tv2 with a 1-way propagation to the output tv3.
5. Manually transform the smem operands to mimic what is done to tv2, but with a different dimension ordering.
This way, the broadcast dimensions are still present in the operands even though they don't exist in their root or logical domains.
Steps 4 and 5 are necessary because transform propagation doesn't seem to handle the loop-domain broadcast seamlessly yet.
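To illustrate the idea behind step 3, here is a minimal, self-contained sketch of a broadcast that exists only in the loop domain. It is a toy model, not nvFuser code: ToyAxis, ToyTensor, makeToyTensor, loopBroadcast, and loopNames are hypothetical names made up for this example, and the axis names and ordering are illustrative rather than the ones used in the test.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy model (NOT nvFuser's IterDomain/TensorView).
struct ToyAxis {
  std::string name;
  bool is_broadcast;
};

struct ToyTensor {
  std::vector<ToyAxis> logical; // axes the fusion input declares
  std::vector<ToyAxis> loop;    // axes the schedule iterates over
};

// Build a tensor whose loop domain starts as a copy of its logical domain.
ToyTensor makeToyTensor(const std::vector<std::string>& names) {
  ToyTensor tv;
  for (const std::string& n : names) {
    tv.logical.push_back(ToyAxis{n, false});
  }
  tv.loop = tv.logical;
  return tv;
}

// Insert a broadcast axis at position pos in the loop domain only.
// The logical domain is deliberately left untouched.
void loopBroadcast(ToyTensor& tv, std::size_t pos, const std::string& name) {
  assert(pos <= tv.loop.size());
  tv.loop.insert(
      tv.loop.begin() + static_cast<std::ptrdiff_t>(pos), ToyAxis{name, true});
}

// Comma-separated loop-domain axis names, for inspection.
std::string loopNames(const ToyTensor& tv) {
  std::string out;
  for (std::size_t i = 0; i < tv.loop.size(); ++i) {
    if (i > 0) {
      out += ",";
    }
    out += tv.loop[i].name;
  }
  return out;
}
```

In this toy model, calling loopBroadcast on a 2D operand leaves its logical domain 2D while its loop domain gains a third, broadcast axis. That is the property the real smem operands need in order to be scheduled like tv2.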
Status
Currently, this test fails to inline the smem operands tv0c and tv1c beyond the first dimension, which causes a failure when we try to set up circular buffering. Removing the tv*c->circularBuffer() calls only leads to a different error: building ComputeAtMap fails with a complaint that the disjoint set for the manually created broadcast ID is requested but does not exist. inlineMost chooses position 1 instead of 3 for the ca_pos. I tried calling computeAt manually to set it to 3 but hit errors. I might need to change my approach or figure out how to inline these manually created broadcast dimensions.
Stacked on #3410, #3414, and #3416