benvanik opened this issue 1 week ago (status: Open)
Hey @benvanik, I would like to work on this. Could you assign it to me?
I would like to go with the second approach you pointed out:
> Another approach would be to insert a `util.unreachable` op in the default region of the `scf.index_switch` and then have a CFG cleanup that walks back from any block with that to a predecessor and fixes up the terminator to make it go the other direction. That also would be applicable to a lot of other parts of the codebase and is a good lowering target for assume ops/int range analysis.
I will implement this as a match-and-rewrite pattern along the same lines as `IndexSwitchToIfPattern` and `MergeIndexSwitchPattern` in the commit here.
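A minimal sketch of what such a rewrite could do at the `scf.index_switch` level, written against a hypothetical toy IR in Python rather than the real MLIR C++ pattern API (`IndexSwitch`, `promote_case_to_default` and `inline_if_no_cases` are illustrative names, not IREE or MLIR code). The assumption is that once the default region holds only `util.unreachable`, the selector is guaranteed to hit some case, so one case region can take over as the default and the switch shrinks by one case; a switch left with no cases simply runs its default region.

```python
from dataclasses import dataclass
from typing import List, Optional

# Toy IR (hypothetical): a region is modeled as a list of op names.
UNREACHABLE = "util.unreachable"


@dataclass
class IndexSwitch:
    """Toy stand-in for scf.index_switch: case values, one region per case, and a default region."""
    cases: List[int]
    case_regions: List[List[str]]
    default_region: List[str]


def promote_case_to_default(op: IndexSwitch) -> Optional[IndexSwitch]:
    """Match-and-rewrite sketch: if the default region only holds util.unreachable,
    the selector must hit one of the cases, so the last case region can become the default.
    Returns None when the pattern does not apply (like a failed match)."""
    if op.default_region != [UNREACHABLE] or not op.cases:
        return None
    return IndexSwitch(
        cases=op.cases[:-1],
        case_regions=op.case_regions[:-1],
        default_region=op.case_regions[-1],
    )


def inline_if_no_cases(op: IndexSwitch) -> Optional[List[str]]:
    """A switch with no cases always runs its default region, so that region can be inlined."""
    return op.default_region if not op.cases else None


# Example: a single-variant switch whose default is known to be unreachable.
sw = IndexSwitch(cases=[0], case_regions=[["dispatch"]], default_region=[UNREACHABLE])
print(inline_if_no_cases(promote_case_to_default(sw)))  # ['dispatch']
```

Picking the last case is an arbitrary choice in this sketch; any case would do, because with an unreachable default every selector value that can actually occur already maps to some case.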
I will create some synthetic test cases like the ones mentioned in the issue and the commit referenced above, but if you have a test case from an ML model that you would like me to look at, feel free to include it.
When multi-targeting on a single device, each dispatch gets a switch on the chosen executable variant. When we emit the `scf.index_switch` (which later becomes `cf.switch`), we know that at load time we verified at least one of the cases will be reached. There's currently no way to communicate this to those ops, though, so we end up generating some pretty terrible code.

This is what 99.9% of the switches look like (as the workgroup counts and parameters often don't differ):
This lowers into cf:
Which eventually gets folded due to the duplication:
That `cf.switch` default case is not reachable, so the op is no longer required and we should be folding it away:

Which will go away in block simplification:
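A minimal sketch of this fold followed by a simple form of block simplification, on a hypothetical toy CFG in Python (`Switch`, `Branch`, `fold_switch`, `remove_dead_blocks` and the block names are illustrative, not MLIR or IREE APIs). It assumes the compiler has somehow been told which blocks are unreachable, which is exactly the information the issue says cannot be expressed today.

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, Union

# Hypothetical toy CFG: a successor is (block name, branch operands) and each block is
# represented only by its terminator. Plain strings stand for terminators without successors.
Successor = Tuple[str, Tuple[str, ...]]


@dataclass
class Switch:
    """Toy stand-in for a cf.switch terminator."""
    default: Successor
    cases: Dict[int, Successor]


@dataclass
class Branch:
    """Toy stand-in for a cf.br terminator."""
    dest: Successor


Terminator = Union[Switch, Branch, str]


def fold_switch(term: Switch, unreachable_blocks: Set[str]) -> Union[Switch, Branch]:
    """If the default successor is known to be unreachable, stop branching to it."""
    if term.default[0] not in unreachable_blocks or not term.cases:
        return term
    targets = set(term.cases.values())
    if len(targets) == 1:
        # Every live case goes to the same place: the switch is no longer required.
        return Branch(dest=targets.pop())
    # Otherwise let one live case (the largest value) take over the default edge.
    last = max(term.cases)
    rest = {value: succ for value, succ in term.cases.items() if value != last}
    return Switch(default=term.cases[last], cases=rest)


def successors(term: Terminator) -> List[str]:
    if isinstance(term, Branch):
        return [term.dest[0]]
    if isinstance(term, Switch):
        return [term.default[0]] + [succ[0] for succ in term.cases.values()]
    return []  # e.g. "return" or "util.unreachable"


def remove_dead_blocks(entry: str, blocks: Dict[str, Terminator]) -> Dict[str, Terminator]:
    """Simplified block simplification: keep only blocks reachable from the entry block."""
    live: Set[str] = set()
    worklist = [entry]
    while worklist:
        name = worklist.pop()
        if name not in live:
            live.add(name)
            worklist.extend(successors(blocks[name]))
    return {name: term for name, term in blocks.items() if name in live}


blocks: Dict[str, Terminator] = {
    "^entry": Switch(default=("^fail", ()), cases={0: ("^dispatch", ("%wg",)), 1: ("^dispatch", ("%wg",))}),
    "^dispatch": "return",
    "^fail": "util.unreachable",
}
blocks["^entry"] = fold_switch(blocks["^entry"], {"^fail"})
blocks = remove_dead_blocks("^entry", blocks)
print(blocks["^entry"], sorted(blocks))  # Branch(dest=('^dispatch', ('%wg',))) ['^dispatch', '^entry']
```

The fold only rewrites the terminator; removing the now predecessor-less failure block is left to the separate reachability sweep, mirroring the split between folding and block simplification described above.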
There are a few approaches we could take. One would be to change the scf/cf ops to avoid the need for defaults (perhaps only in cases where there are no returns): we'd just emit the op with no default region. That would help all code that emits these ops and already knows at emission time that a case will be taken, but it has caveats if the information only becomes available later on (via const eval, IPO, specialization, etc.).
Another approach would be to insert a `util.unreachable` op in the default region of the `scf.index_switch`
and then have a CFG cleanup that walks back from any block containing that op to a predecessor and fixes up the predecessor's terminator to make it go the other direction. That would also be applicable to a lot of other parts of the codebase and is a good lowering target for assume ops/int range analysis.

If we did `util.unreachable`, we could also support optimizations with the new `util.assume.int` ops. In this case, for example, if int range analysis or something else added an assume op:

a pass could insert the `util.unreachable`:

And then that'll get removed later on.
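A minimal sketch of the kind of pass described in the last paragraph, again on a hypothetical toy representation (a dict from case value to region name) rather than real IR. `prune_switch_with_range` and the inclusive `[lo, hi]` selector range are assumptions for illustration; the actual `util.assume.int` op may encode its range differently.

```python
from typing import Dict, Tuple

UNREACHABLE = "util.unreachable"


def prune_switch_with_range(
    cases: Dict[int, str], default: str, lo: int, hi: int
) -> Tuple[Dict[int, str], str]:
    """Given an assumed inclusive selector range [lo, hi] (hypothetically taken from an
    assume op or int range analysis), drop cases that can never be taken and replace the
    default with util.unreachable when every value in the range has its own case."""
    assert lo <= hi, "expected a non-empty range"
    live = {value: region for value, region in cases.items() if lo <= value <= hi}
    # Case values are unique dict keys, so the range is fully covered exactly when
    # the number of live cases equals the number of values in [lo, hi].
    fully_covered = len(live) == hi - lo + 1
    return live, (UNREACHABLE if fully_covered else default)


# Example: three variants, but the selector is assumed to be in [0, 1].
cases, default = prune_switch_with_range({0: "variant_a", 1: "variant_b", 7: "variant_c"}, "fail", 0, 1)
print(cases, default)  # {0: 'variant_a', 1: 'variant_b'} util.unreachable
```

Once the default has been replaced with `util.unreachable`, the same cleanup that handles the multi-targeting case can remove it, which is why a single unreachable marker can serve both sources of information.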