ROCm / rocMLIR


[DO NOT SQUASH] Handle non-zero-preserving input fusions, make read_into track validity #1629

Open krzysz00 opened 3 weeks ago

krzysz00 commented 3 weeks ago

Apologies for the mega-commit (this one also includes a cherry-pick from upstream).

Previously, input-fusing an elementwise function like (x) => (x + 4) produced incorrect results: the 4 was added not only to the elements of the input value but also to the padding values introduced during tiling, which are supposed to stay zero.
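A minimal Python model of the failure (illustrative only; `pad_to_tile`, `buggy_fused_dot`, and `reference_dot` are hypothetical helpers, not rocMLIR code). It assumes a K tile of 4 padded with zeros and, for illustration, applies the fused (x) => (x + 4) to both operands of a dot product, so neither operand's zero padding cancels the spurious terms.

```python
# Hypothetical model of the bug: pad a K tile with zeros, then input-fuse a
# non-zero-preserving function onto the padded tile.
TILE = 4

def pad_to_tile(values, tile=TILE):
    """Append zeros so len(values) is a multiple of `tile` (models tiling padding)."""
    return values + [0] * (-len(values) % tile)

def add4(x):
    # The fused elementwise function (x) => (x + 4); note add4(0) == 4, not 0.
    return x + 4

def buggy_fused_dot(a, b, fn):
    """Apply fn AFTER padding, so each padding slot becomes fn(0) instead of 0."""
    a_tile = [fn(x) for x in pad_to_tile(a)]
    b_tile = [fn(x) for x in pad_to_tile(b)]
    return sum(x * y for x, y in zip(a_tile, b_tile))

def reference_dot(a, b, fn):
    """Intended semantics: fn applies only to real elements; padding never exists."""
    return sum(fn(x) * fn(y) for x, y in zip(a, b))

a = [1, 2, 3, 4, 5, 6]
b = [1, 1, 1, 1, 1, 1]
print(reference_dot(a, b, add4))    # 225
print(buggy_fused_dot(a, b, add4))  # 257: two padding slots each add 4 * 4
```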

To fix this, we:

  1. Make threadwise_read_into optionally return a boolean vector of validities, recording whether each element in the register tile actually came from memory.
  2. Give threadwise_read_into a dynamic validities argument: the validity data from the coordinate transforms is ANDed with these validity vectors, which allows applying validities that don't come from coordinate transforms.
  3. Update all the threadwise_read_into operations that can have input fusions on them to generate a validity record.
  4. If a linalg.generic is being tiled by a threadwise_read_into, that read records validity (that is, it came from an input fusion), and the result of the elementwise function on all-zero inputs isn't zero, we propagate the validities.
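Steps 1, 2, and 4 can be modeled in a few lines of Python. This is a hypothetical sketch, not rocMLIR code: `read_into_sketch` and `needs_validity_propagation` are invented names, and the coordinate transforms are reduced to a list of memory offsets in which `None` marks an out-of-bounds (padding) slot.

```python
def read_into_sketch(memory, indices, dynamic_validities=None):
    """Hypothetical model of threadwise_read_into with validity tracking.

    indices[i] is the memory offset for register slot i as computed by the
    coordinate transforms, or None if the transforms mark it out of bounds.
    Returns (registers, validities).
    """
    registers, validities = [], []
    for i, idx in enumerate(indices):
        # Validity as determined by the (modeled) coordinate transforms.
        valid = idx is not None and 0 <= idx < len(memory)
        if dynamic_validities is not None:
            # Step 2: AND in validities that don't come from coordinate transforms.
            valid = valid and dynamic_validities[i]
        # Invalid slots hold 0, like padding.
        registers.append(memory[idx] if valid else 0)
        # Step 1: record whether this slot actually came from memory.
        validities.append(valid)
    return registers, validities

def needs_validity_propagation(fn, num_inputs):
    """Step 4: validities only matter if fn is not zero-preserving."""
    return fn(*([0] * num_inputs)) != 0
```

For example, reading `[10, 20, 30]` through offsets `[0, 1, 2, None]` yields registers `[10, 20, 30, 0]` with validities `[True, True, True, False]`. `needs_validity_propagation(lambda x: x + 4, 1)` is `True`, while a multiply is zero-preserving, so `needs_validity_propagation(lambda x, y: x * y, 2)` is `False`.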

That is, we make the threadwise_read_into ops that read each input to the generic record their validities. Then, after the generic runs, we do a register -> register threadwise_read_into (which would ordinarily be a memcpy()) whose dynamic validities are taken from those reads, re-applying the zero mask.
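A sketch of the re-masking step, again as hypothetical Python rather than rocMLIR's implementation: `apply_generic_with_mask` is an invented name, register tiles are plain lists, and combining the validities of several inputs with an AND is an assumption, since the description does not spell out how multiple inputs are merged.

```python
def apply_generic_with_mask(input_tiles, input_validities, fn):
    """Hypothetical sketch: run a fused elementwise fn, then re-apply the zero mask.

    input_tiles[k][i]: register value of input k at slot i (0 for padding)
    input_validities[k][i]: True if that value actually came from memory
    """
    num_slots = len(input_tiles[0])
    # The generic runs on every register slot, padding slots included.
    outputs = [fn(*(tile[i] for tile in input_tiles)) for i in range(num_slots)]
    # Zero-preserving functions leave padding at 0, so no re-masking is needed.
    if fn(*([0] * len(input_tiles))) == 0:
        return outputs
    # Assumption: an output slot is real only if every input read for it was valid.
    combined = [all(v[i] for v in input_validities) for i in range(num_slots)]
    # Models the register -> register read: a plain copy, except that the
    # dynamic validities force padding slots back to 0.
    return [out if ok else 0 for out, ok in zip(outputs, combined)]
```

With (x) => (x + 4), a tile `[1, 2, 0, 0]` whose last two slots are padding comes out as `[5, 6, 0, 0]` instead of `[5, 6, 4, 4]`.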

Note: While I was here, I noticed that non-xdlops gemms were producing two sets of threadwise_read_into ops, so I fixed that, which broke a few tests.

codecov[bot] commented 3 weeks ago

Codecov Report

Attention: Patch coverage is 83.67347% with 32 lines in your changes missing coverage. Please review.

Project coverage is 78.00%. Comparing base (9ef8698) to head (0435394). Report is 4 commits behind head on develop.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp | 82.97% | 7 Missing and 9 partials |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 41.17% | 6 Missing and 4 partials |
| ...Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp | 90.19% | 4 Missing and 1 partial |
| ...ialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | 96.00% | 1 Missing |
Additional details and impacted files

```diff
@@             Coverage Diff             @@
##           develop    #1629      +/-   ##
===========================================
- Coverage    78.07%   78.00%   -0.07%
===========================================
  Files           98       98
  Lines        26252    26369     +117
  Branches      3731     3752      +21
===========================================
+ Hits         20495    20570      +75
- Misses        4291     4317      +26
- Partials      1466     1482      +16
```

| Flag | Coverage Δ |
|---|---|
| mfma | `78.00% <83.67%> (-0.07%)` |

Flags with carried forward coverage won't be shown.
