Open jacobhinkle opened 1 year ago
"The scheduler that accepted the original segment would"
Do you want to complete this sentence?
I agree with you that "replace before segmentation" is too hard -- it's hard to write a shouldRedefine function that reasons about how downstream schedulers (e.g. whether Persistent accepts the two segmented Fusions) will treat them. Even if we manage to write one, it's probably going to be as complicated as the actual segmentation.
Regarding "replace after segmentation", can we undo a segmentation instead of worrying about merging segmented fusions? It certainly isn't as comprehensive, but it seems to be enough to cover the two use cases you listed.
This is a proposal for a new pass that occurs just after the pre-segmentation optimization passes (which do not use runtime info), but before or possibly during segmentation. The primary uses initially will be to translate Welfords into two-pass varmean, and to translate matmuls into BMMs in order to implement two-kernel split-K.
Motivation
Segmentation currently includes a pass called TranslateApplicableWelford. The purpose of that class is to translate Welford ops that are present in the fusion into "two-pass variance" ops; i.e. compute the mean, subtract, square, sum. This is advantageous due to the lower instruction count compared to Welford, in cases where the inputs are held in smem. That pass currently works by copying the fusion, doing the translation, then checking whether the persistent heuristic can schedule the resulting fusion.
For two-kernel split-K, we need to rfactor the K dimension and insert a segmentSet between the mma and the subsequent reduction.
In each case, a heuristic is responsible for deciding whether to perform the translation: the reduction scheduler for TranslateApplicableWelford and the matmul scheduler for two-kernel split-K.
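To make the Welford versus two-pass distinction concrete, here is a minimal scalar sketch in plain C++. It is not nvFuser IR and not what TranslateApplicableWelford emits; the names MeanVar, welford, and twoPass are hypothetical, and population variance (dividing by n) is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mean and population variance of a sequence (hypothetical helper type).
struct MeanVar {
  double mean;
  double var;
};

// Welford's single-pass update: every element needs a division and two
// dependent updates, but the data is read only once.
MeanVar welford(const std::vector<double>& x) {
  assert(!x.empty());
  double mean = 0.0;
  double m2 = 0.0;  // running sum of squared deviations from the mean
  std::size_t n = 0;
  for (double v : x) {
    ++n;
    double delta = v - mean;
    mean += delta / static_cast<double>(n);
    m2 += delta * (v - mean);
  }
  return {mean, m2 / static_cast<double>(n)};
}

// Two-pass form: compute the mean, then subtract, square, and sum.
// The data is read twice, which is why it is attractive mainly when the
// inputs are already resident in fast memory.
MeanVar twoPass(const std::vector<double>& x) {
  assert(!x.empty());
  const double n = static_cast<double>(x.size());
  double sum = 0.0;
  for (double v : x) {
    sum += v;
  }
  const double mean = sum / n;
  double sq = 0.0;
  for (double v : x) {
    const double d = v - mean;
    sq += d * d;
  }
  return {mean, sq / n};
}
```

Both functions agree on the result; the difference is the per-element work versus the number of passes over the data.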
Possible approaches
Replace before segmentation
We could add the following methods to SchedulerEntry:
bool shouldRedefine(const std::vector<Expr*>& exprs)
void redefine(std::vector<Expr*>& exprs)
and associated dispatch functions in registry.cpp. At the beginning of segment() we can then propose redefinitions and perform them, just as we currently do with proposeHeuristics.
Possible issues with this approach
Since we are redefining the graph, subsequent redefinitions may depend on previous ones. This is not yet the case since these two proposed passes do not directly interact, but it could in the future. In that case we'll need to be careful with ordering and we'll probably need to loop until saturation.
As I've stated it so far, this would apply only to the complete Fusion and would occur before segmentation. However, for simplicity's sake the heuristics need to look at individual segments. There we hit a chicken and egg problem: we need to redefine the Fusion to alter how it gets segmented, but we might need to look at the segmented fusion to make that decision.
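The two proposed hooks and the "loop until saturation" concern can be sketched with toy types. Everything below is hypothetical: ToyExpr stands in for Expr* (an op is just its name), ToyScheduler and ToyRewriteScheduler stand in for SchedulerEntry and its subclasses, and redefineUntilSaturation and its max_rounds guard are illustrative names rather than anything in nvFuser.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for nvFuser's Expr*: an op is represented by its name only.
using ToyExpr = std::string;

// Hypothetical interface mirroring the two proposed methods.
struct ToyScheduler {
  virtual ~ToyScheduler() = default;
  virtual bool shouldRedefine(const std::vector<ToyExpr>& exprs) const = 0;
  virtual void redefine(std::vector<ToyExpr>& exprs) const = 0;
};

// Toy scheduler that rewrites one op name into a fixed sequence of ops.
struct ToyRewriteScheduler : ToyScheduler {
  ToyExpr from;
  std::vector<ToyExpr> to;

  ToyRewriteScheduler(ToyExpr f, std::vector<ToyExpr> t)
      : from(std::move(f)), to(std::move(t)) {}

  bool shouldRedefine(const std::vector<ToyExpr>& exprs) const override {
    for (const ToyExpr& e : exprs) {
      if (e == from) {
        return true;
      }
    }
    return false;
  }

  void redefine(std::vector<ToyExpr>& exprs) const override {
    std::vector<ToyExpr> out;
    for (const ToyExpr& e : exprs) {
      if (e == from) {
        for (const ToyExpr& r : to) {
          out.push_back(r);
        }
      } else {
        out.push_back(e);
      }
    }
    exprs = std::move(out);
  }
};

// Applies redefinitions until no scheduler requests one. max_rounds guards
// against rewrites that never settle. Returns the number of rounds in which
// at least one redefinition was performed.
int redefineUntilSaturation(
    const std::vector<const ToyScheduler*>& schedulers,
    std::vector<ToyExpr>& exprs,
    int max_rounds = 16) {
  int rounds = 0;
  while (rounds < max_rounds) {
    bool changed = false;
    for (const ToyScheduler* s : schedulers) {
      if (s->shouldRedefine(exprs)) {
        s->redefine(exprs);
        changed = true;
      }
    }
    if (!changed) {
      break;
    }
    ++rounds;
  }
  return rounds;
}
```

In this sketch the scheduler order inside a round is fixed by the vector, which is exactly where the ordering concern would show up once redefinitions start to interact. Note that the sketch only sees the whole op list, so it does not address the chicken-and-egg problem of needing per-segment information.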
Replace after segmentation
Instead of (or in addition to?) changing the definition of the Fusion before segmentation, we could instead segment just as we normally do (removing the current translateApplicableWelford), but provide a mechanism to resegment an already-accepted segment. The scheduler that accepted the original segment would
Possible issues with this approach
It might not be easy to merge segments in this approach. This could lead to oversegmentation. Consider a case where the original Fusion is segmented into A -> B. Then the scheduler for segment A redefines that segment and requests it to be resegmented, resulting in A -> C -> B. It might be that the new segment C could be fused with B, but we haven't asked to resegment B. Does this mean we should merge all neighboring segments during resegmentation? This example affects only one neighbor but in general this might mean we need to resegment the entire Fusion...
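The A -> C -> B example can be played out on a toy linear chain of segments. This is only a sketch under strong assumptions: real segmented fusions form a DAG rather than a chain, and Segment, Chain, CanMerge, splice, and mergeAdjacent are hypothetical names, with CanMerge standing in for "some scheduler accepts the merged segment".

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

using Segment = std::vector<std::string>;  // op names in one segment
using Chain = std::vector<Segment>;        // linear chain of segments
using CanMerge = std::function<bool(const Segment&, const Segment&)>;

// Replaces chain[idx] with `pieces`, leaving every other segment untouched.
// This models resegmenting a single, already-accepted segment.
Chain splice(const Chain& chain, std::size_t idx, const Chain& pieces) {
  assert(idx < chain.size());
  const auto pos = static_cast<std::ptrdiff_t>(idx);
  Chain out(chain.begin(), chain.begin() + pos);
  out.insert(out.end(), pieces.begin(), pieces.end());
  out.insert(out.end(), chain.begin() + pos + 1, chain.end());
  return out;
}

// One greedy left-to-right pass that merges adjacent segments whenever
// canMerge accepts the pair. It walks the whole chain, not just the
// neighbors of the resegmented segment.
Chain mergeAdjacent(Chain chain, const CanMerge& canMerge) {
  std::size_t i = 0;
  while (i + 1 < chain.size()) {
    if (canMerge(chain[i], chain[i + 1])) {
      chain[i].insert(
          chain[i].end(), chain[i + 1].begin(), chain[i + 1].end());
      chain.erase(chain.begin() + static_cast<std::ptrdiff_t>(i) + 1);
    } else {
      ++i;
    }
  }
  return chain;
}
```

With splice alone, resegmenting A into two pieces leaves three segments even if the new piece C could fuse with B. Recovering that fusion requires a merge pass that also visits B, which is the widening scope the paragraph above worries about.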