NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

[WIP] Schedule Hopper MMA without input broadcasts #3406

Open jacobhinkle opened 1 week ago

jacobhinkle commented 1 week ago

Stacked on #3410, #3414, and #3416

This builds on #3391. I modified the HSH_NT_128BSwizzle test by removing the broadcast dimensions from its inputs. The goal is to schedule this fusion successfully on Hopper without incurring any wasteful smem->register round trips before the mma instruction.

Approach

I have not modified any of the utility functions. Instead, I made the following changes:

  1. Modify the inputs to remove the broadcast dimensions
  2. Define an AxisMapping to use these 2D inputs in fusedMultiplySum
  3. Do a loop-domain broadcast using tv->broadcast(pos) for the smem operands.
  4. Replace the 2-way transform propagation from the mma result tv2 with a 1-way propagation to the output tv3.
  5. Manually transform the smem operands to mimic what is done to tv2, but with a different dimension ordering.
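Steps 1 and 2 can be illustrated with a small standalone C++ model of how 2D operands feed a 3D (M, N, K) contraction. `ToyAxisMapping`, `Matrix2D`, `loadMapped`, and `referenceMultiplySum` are hypothetical names invented for this sketch; they are not nvFuser's `AxisMapping` or `fusedMultiplySum`. The model only assumes that each MMA dimension either comes from one axis of a 2D operand or is absent from it, in which case that operand is implicitly broadcast along it.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical toy axis mapping (not nvFuser's type). For each MMA
// dimension in the order {M, N, K}, store the operand axis that supplies
// it, or -1 when the operand has no such dimension.
struct ToyAxisMapping {
  std::array<int64_t, 3> a_axes;  // operand A axis for M, N, K
  std::array<int64_t, 3> b_axes;  // operand B axis for M, N, K
};

// Dense row-major 2D operand with no broadcast dimensions.
struct Matrix2D {
  int64_t rows;
  int64_t cols;
  std::vector<float> data;
  float at(int64_t r, int64_t c) const {
    return data[static_cast<std::size_t>(r * cols + c)];
  }
};

// Load the operand element that contributes to MMA coordinate (m, n, k).
// A dimension mapped to -1 is ignored, which behaves like a broadcast.
float loadMapped(const Matrix2D& t, const std::array<int64_t, 3>& axes,
                 const std::array<int64_t, 3>& mnk) {
  int64_t idx[2] = {0, 0};
  for (std::size_t d = 0; d < 3; ++d) {
    if (axes[d] >= 0) {
      assert(axes[d] < 2);
      idx[axes[d]] = mnk[d];
    }
  }
  return t.at(idx[0], idx[1]);
}

// Reference result: out[m][n] = sum over k of A(m, k) * B(n, k),
// with both operands addressed through the mapping.
std::vector<float> referenceMultiplySum(const Matrix2D& a, const Matrix2D& b,
                                        const ToyAxisMapping& map,
                                        int64_t M, int64_t N, int64_t K) {
  std::vector<float> out(static_cast<std::size_t>(M * N), 0.0f);
  for (int64_t m = 0; m < M; ++m)
    for (int64_t n = 0; n < N; ++n)
      for (int64_t k = 0; k < K; ++k)
        out[static_cast<std::size_t>(m * N + n)] +=
            loadMapped(a, map.a_axes, {m, n, k}) *
            loadMapped(b, map.b_axes, {m, n, k});
  return out;
}
```

With A stored as [K, M] and B as [K, N] (an assumed layout, not necessarily the one in HSH_NT_128BSwizzle), the mapping `{{1, -1, 0}, {-1, 1, 0}}` says M comes from axis 1 of A, N from axis 1 of B, K from axis 0 of both, and each operand simply lacks the other's output dimension.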

This way, the broadcast dimensions are still present in the operands even though they don't exist in their root or logical domains.
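To make the logical/loop distinction concrete, here is a toy C++ model. It assumes only that a loop-domain broadcast adds an iteration domain to the loop domain while leaving the logical domain unchanged. `ToyTensor`, `ToyIterDomain`, and `broadcastLoop` are hypothetical stand-ins for a TensorView and `tv->broadcast(pos)`, not nvFuser code.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy iteration domain: a name plus a broadcast flag.
struct ToyIterDomain {
  std::string name;
  bool is_broadcast;
};

// Hypothetical toy tensor with separate logical and loop domains.
struct ToyTensor {
  std::vector<ToyIterDomain> logical;  // dimensions the fusion definition sees
  std::vector<ToyIterDomain> loop;     // dimensions the schedule iterates over

  explicit ToyTensor(std::vector<ToyIterDomain> ids)
      : logical(ids), loop(std::move(ids)) {}

  // Insert a broadcast axis into the loop domain only; the logical
  // domain keeps its original (broadcast-free) dimensions.
  void broadcastLoop(std::size_t pos, const std::string& name) {
    assert(pos <= loop.size());
    loop.insert(loop.begin() + static_cast<std::ptrdiff_t>(pos),
                ToyIterDomain{name, true});
  }
};
```

After `broadcastLoop`, the operand can be iterated like a 3D [K, M, N] tensor while its logical domain still describes a 2D [K, M] input.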

Steps 4 and 5 are necessary because transform propagation doesn't seem to handle the loop-domain broadcast seamlessly yet.
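One way to see why manually replaying the schedule (step 5) can work despite the different dimension ordering is to key each transform on the axis it targets rather than on its position. The sketch below is a hypothetical toy in C++, not nvFuser's transform replay or propagation machinery. The axis names, tile sizes (64, 256, 16), and the [K, M, N] starting order for the operand are assumptions chosen for illustration, and the rule that splitting a broadcast axis yields two broadcast axes is a simplification of this toy.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy loop domain: named axes with a broadcast flag and extent.
struct Axis {
  std::string name;
  bool is_broadcast;
  int64_t extent;
};
using LoopDomain = std::vector<Axis>;

struct SplitOp {
  std::string axis;  // name of the axis to split
  int64_t factor;    // extent of the inner axis
};

// Return the position of the axis called `name`.
std::size_t findAxis(const LoopDomain& d, const std::string& name) {
  for (std::size_t i = 0; i < d.size(); ++i) {
    if (d[i].name == name) return i;
  }
  throw std::invalid_argument("axis not found: " + name);
}

// Split an axis into "<name>.o" and "<name>.i". In this toy, splitting a
// broadcast axis yields two broadcast axes of extent 1.
void applySplit(LoopDomain& d, const SplitOp& op) {
  const std::size_t i = findAxis(d, op.axis);
  const Axis a = d[i];
  const int64_t outer =
      a.is_broadcast ? 1 : (a.extent + op.factor - 1) / op.factor;
  const int64_t inner = a.is_broadcast ? 1 : op.factor;
  d[i] = Axis{a.name + ".o", a.is_broadcast, outer};
  d.insert(d.begin() + static_cast<std::ptrdiff_t>(i) + 1,
           Axis{a.name + ".i", a.is_broadcast, inner});
}

// Reorder the domain so its axes follow `order`, looked up by name.
void reorderByName(LoopDomain& d, const std::vector<std::string>& order) {
  assert(order.size() == d.size());
  LoopDomain reordered;
  for (const std::string& n : order) reordered.push_back(d[findAxis(d, n)]);
  d = std::move(reordered);
}

// Apply the same splits and final ordering to any domain, regardless of
// its starting axis order.
void replaySchedule(LoopDomain& d, const std::vector<SplitOp>& splits,
                    const std::vector<std::string>& order) {
  for (const SplitOp& s : splits) applySplit(d, s);
  reorderByName(d, order);
}
```

Because every split and the final reorder are looked up by name, an operand that starts as [K, M, bN] ends up with the same loop structure as an MMA output that starts as [M, N, K], with its N axes remaining broadcast.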

Status

Currently, this test fails to inline the smem operands tv0c and tv1c beyond the first dimension, which causes a failure when we try to set circular buffering. Removing the tv*c->circularBuffer() calls just leads to an error while building the ComputeAtMap, which complains that the disjoint set for the manually created broadcast ID was requested but does not exist. inlineMost chooses position 1 instead of 3 for the ca_pos. I tried calling computeAt manually to set this to 3 but hit errors. I might need to change my approach or figure out how to inline these manually created broadcast dimensions.
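A heavily simplified, hypothetical model of the inlining symptom: assume the largest usable inline position is the length of the longest prefix of the producer's loop axes that all have a mapped counterpart. Under that assumption, a manually created broadcast ID with no disjoint set at loop position 1 caps the position at 1 even though three outer axes were meant to be inlined. The loop ordering and axis names here are assumptions, and this is not nvFuser's actual `inlineMost` logic.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Toy rule (assumption): inlining proceeds through a prefix of loop axes
// that each have a mapping to the consumer and stops at the first
// unmapped axis.
std::size_t toyMaxInlinePosition(const std::vector<std::string>& producer_loop,
                                 const std::set<std::string>& mapped_axes) {
  std::size_t pos = 0;
  while (pos < producer_loop.size() &&
         mapped_axes.count(producer_loop[pos]) > 0) {
    ++pos;
  }
  return pos;
}
```

If the broadcast axis "N.o" is missing from the mapping, this toy returns 1; once it is mapped, it returns 3, matching the position the schedule intends.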