NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Use SimplifyingIrBuilder for split and merge extents #3309

Closed naoyam closed 1 week ago

naoyam commented 3 weeks ago

Currently, when we split an iter domain, the resulting extent is ceilDiv(i0, s), where i0 is the extent of the input iter domain and s is the split factor. With this PR, the extent is no longer always a ceilDiv expression: when the case is simple enough for SimplifyingIrBuilder to understand, it is simplified. One common example is a split by 1, which is often done for unswitch. We would just get i0 instead of ceilDiv(i0, 1). The same applies to merge.
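To illustrate the difference, here is a minimal toy sketch in Python. It is not the nvFuser API: `Sym`, `CeilDiv`, `Mul`, `plain_ceil_div`, `simplifying_ceil_div`, and `simplifying_mul` are hypothetical names made up for this example, and the only simplifications modeled are constant folding and the identity for 1. It assumes split factors are positive.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Sym:
    """A symbolic extent such as i0 (hypothetical stand-in for a Val)."""
    name: str


@dataclass(frozen=True)
class CeilDiv:
    lhs: "Expr"
    rhs: "Expr"


@dataclass(frozen=True)
class Mul:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[int, Sym, CeilDiv, Mul]


def plain_ceil_div(a: Expr, b: Expr) -> Expr:
    """Always build a ceilDiv node, even for trivial cases."""
    return CeilDiv(a, b)


def simplifying_ceil_div(a: Expr, b: Expr) -> Expr:
    """Build a ceilDiv node only when no trivial simplification applies."""
    if isinstance(a, int) and isinstance(b, int):
        return -(-a // b)  # fold constants; assumes b > 0
    if b == 1:
        return a  # ceilDiv(x, 1) is just x
    return CeilDiv(a, b)


def simplifying_mul(a: Expr, b: Expr) -> Expr:
    """Merge extents multiply; drop a factor of 1 and fold constants."""
    if isinstance(a, int) and isinstance(b, int):
        return a * b
    if a == 1:
        return b
    if b == 1:
        return a
    return Mul(a, b)


i0 = Sym("i0")
print(plain_ceil_div(i0, 1))        # extent of a split by 1 without simplification
print(simplifying_ceil_div(i0, 1))  # extent of the same split with simplification
```

In this toy model, a split by 1 gives back the input extent itself, a split by any other symbolic or non-trivial factor still produces a `CeilDiv` node, and static shapes fold down to plain integers.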

This should most likely have no impact on actual kernel performance, as many of these expressions are simplified and hoisted anyway. However, the printMath output could be greatly simplified, especially when resize is used with static shapes.

On the other hand, as mentioned in https://github.com/NVIDIA/Fuser/pull/3309#discussion_r1824753545, we would no longer be able to infer which ID operations were used just by looking at the extent. I personally think that wouldn't be too concerning, and it looks like @jacobhinkle and @zasdfgbnm have the same opinion.

naoyam commented 3 weeks ago

!build

naoyam commented 2 weeks ago

Some of the failures occurred because SimplifyingIrBuilder was only used in IterDomain::split and IterDomain::merge, but not in transform_rfactor and transform_replay, where Split and Merge are created directly without going through the IterDomain functions. This meant that a split of i0 by 1, for example, was sometimes simplified to just i0 but not always. Since the results are still exactly mapped, an exact group can contain something like {i0, ceilDiv(i0, 1), ...}. Suppose i1 = ceilDiv(i0, 1). We then get the equality relationship i0 == i1, which is used in replaceSymbolicSizes. In this case, it can result in a recursive definition as follows. The definition of i1,

i1 = ceilDiv(i0, 1)

is going to be examined by the mutator. While we disable replacing outputs, the use of i0 in ceilDiv may be replaced by i1 since the replacement map may have i0 -> i1. As a result, the definition of i1 becomes:

i1 = ceilDiv(i1, 1)

To avoid this, vals should never be replaced with dependent vals. We should probably add some assertions.
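The failure mode and the kind of assertion suggested here can be sketched with a small toy model in Python. This is not nvFuser code: expressions are plain tuples such as `("ceilDiv", "i0", 1)`, symbols are strings, and `depends_on`, `substitute`, and `check_replacement_map` are hypothetical helpers invented for this illustration. The sketch assumes the definitions are acyclic before any replacement is applied.

```python
def depends_on(val, target, definitions):
    """Return True if `val` is `target` or is (transitively) defined in terms of it."""
    if val == target:
        return True
    if isinstance(val, tuple):
        # val is an expression: (op_name, arg0, arg1, ...)
        return any(depends_on(arg, target, definitions) for arg in val[1:])
    if val in definitions:
        return depends_on(definitions[val], target, definitions)
    return False


def substitute(expr, replacement):
    """Replace the inputs of `expr` according to the replacement map."""
    if isinstance(expr, tuple):
        return (expr[0],) + tuple(substitute(arg, replacement) for arg in expr[1:])
    return replacement.get(expr, expr)


def check_replacement_map(replacement, definitions):
    """Reject any entry old -> new where `new` depends on `old`."""
    for old, new in replacement.items():
        if depends_on(new, old, definitions):
            raise ValueError(f"{old} -> {new}: replacement depends on the replaced val")


definitions = {"i1": ("ceilDiv", "i0", 1)}

# With i0 -> i1, rewriting the inputs of i1's definition makes it self-referential.
bad_map = {"i0": "i1"}
print(substitute(definitions["i1"], bad_map))

# The opposite direction is fine: i0 has no definition that mentions i1.
good_map = {"i1": "i0"}
check_replacement_map(good_map, definitions)
```

Calling `check_replacement_map(bad_map, definitions)` raises a `ValueError` in this model, because `i1` is defined in terms of `i0`. Choosing the replacement direction so that the replacement never depends on the replaced val avoids the recursive definition.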

naoyam commented 2 weeks ago

!test

naoyam commented 2 weeks ago

Stacked on top of #3344, although nothing depends on #3344.

naoyam commented 2 weeks ago

I forgot to update the reference code used for checking the generated code. Will ping once done.

naoyam commented 2 weeks ago

!test

naoyam commented 2 weeks ago

!test

naoyam commented 2 weeks ago

!test

naoyam commented 1 week ago

Ready to review again. @jacobhinkle @zasdfgbnm

naoyam commented 1 week ago

!test