Open liqiangxl opened 1 year ago
Isn't it necessary to check the combined scheduler before the inner and outer schedulers? Otherwise, would it be ever used?
Step-1, refactor ScheduleHeuristic::Persistent to ScheduleHeuristic::xPersistent, where x is one of {Inner, Outer, InnerOuter}
(1) Create 3 schedulers. Each scheduler calls its corresponding getxPersistentHeuristics and schedulePersistentKernel. Take innerPersistent for example:
(2) Create a namespace PersistentSchedulerChecker that stores the common functions used by the canScheduleCompileTime and canScheduleRunTime defined in these 3 schedulers.
Step-2, Define canScheduleCompileTime and canScheduleRunTime for each scheduler.
(1) Define canScheduleCompileTime for each scheduler.
InnerPersistent and OuterPersistent perform the same compile-time checks. Both call the same wrapper function, commonCompileTimeCheck, which is defined in PersistentSchedulerChecker using utility functions.
InnerOuterPersistent needs additional checks. It is implemented using utility functions in PersistentSchedulerChecker and normalization_scheduler_utils.
(2) Define canScheduleRunTime for each scheduler.
Similar to canScheduleCompileTime.
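As a rough illustration of the Step-2 layout, the sketch below shows two schedulers sharing one wrapped check while the third composes the same utilities with an extra condition. FusionSummary, its fields, hasPersistentBuffer, and the inner_outer_extras_ok flag are hypothetical placeholders invented for this sketch; the real checks inspect the fusion itself, and their exact signatures are not shown in this thread.

```cpp
#include <cassert>

enum class ReductionType { Inner, Outer, InnerOuter, None };

// Hypothetical, simplified summary of a fusion (an assumption of this sketch).
struct FusionSummary {
  ReductionType reduction_type;
  bool has_persistent_buffer;  // persistent schedulers need a persistent buffer
  bool inner_outer_extras_ok;  // placeholder for the additional InnerOuter checks
};

namespace PersistentSchedulerChecker {
// Small utilities shared by all three schedulers.
inline bool hasPersistentBuffer(const FusionSummary& f) {
  return f.has_persistent_buffer;
}
inline bool checkReductionType(const FusionSummary& f, ReductionType expected) {
  return f.reduction_type == expected;
}
// Wrapped check used unchanged by the inner and outer schedulers.
inline bool commonCompileTimeCheck(const FusionSummary& f, ReductionType expected) {
  return hasPersistentBuffer(f) && checkReductionType(f, expected);
}
} // namespace PersistentSchedulerChecker

struct InnerPersistentScheduler {
  static bool canScheduleCompileTime(const FusionSummary& f) {
    return PersistentSchedulerChecker::commonCompileTimeCheck(f, ReductionType::Inner);
  }
};

struct OuterPersistentScheduler {
  static bool canScheduleCompileTime(const FusionSummary& f) {
    return PersistentSchedulerChecker::commonCompileTimeCheck(f, ReductionType::Outer);
  }
};

struct InnerOuterPersistentScheduler {
  // Reuses the shared utilities and adds its own (placeholder) condition.
  static bool canScheduleCompileTime(const FusionSummary& f) {
    namespace checker = PersistentSchedulerChecker;
    return checker::hasPersistentBuffer(f) &&
        checker::checkReductionType(f, ReductionType::InnerOuter) &&
        f.inner_outer_extras_ok;
  }
};
```

With this shape, a fusion summarized as {ReductionType::Inner, true, false} is accepted only by InnerPersistentScheduler, and a run-time check could follow the same pattern.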
Isn't it necessary to check the combined scheduler before the inner and outer schedulers? Otherwise, would it be ever used?
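There is a check on the reduction type. The inner scheduler only accepts fusions that contain only inner reductions, and the outer scheduler only accepts fusions that contain only outer reductions, so a fusion with both kinds is rejected by both and reaches the combined scheduler.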
// Rejects the heuristic unless the fusion's reduction type is the one it handles.
bool checkReductionType(
    const std::vector<TensorView*>& reduction_tvs,
    ScheduleHeuristic heuristic) {
  auto reduction_type =
      reduction_scheduler_utils::getReductionType(reduction_tvs);
  auto expected_type =
      reduction_scheduler_utils::mapScheduleHeuristicToReductionType(heuristic);
  if (reduction_type != expected_type) {
    scheduler_debug_utils::canScheduleRejectReason(
        heuristic, "ReductionType and heuristic doesn't match.");
    return false;
  }
  return true;
}
where
ReductionType mapScheduleHeuristicToReductionType(ScheduleHeuristic sh) {
  switch (sh) {
    case ScheduleHeuristic::InnerPersistent:
      return ReductionType::Inner;
    case ScheduleHeuristic::OuterPersistent:
      return ReductionType::Outer;
    case ScheduleHeuristic::InnerOuterPersistent:
      return ReductionType::InnerOuter;
    default:
      return ReductionType::None;
  }
}
and
// Classifies the fusion by which kinds of reductions it contains.
ReductionType getReductionType(const std::vector<TensorView*>& reduction_tvs) {
  bool is_inner_reduction = false;
  bool is_outer_reduction = false;
  for (auto tv : reduction_tvs) {
    if (scheduler_utils::isFastestDimReduction(tv)) {
      is_inner_reduction = true;
    } else {
      is_outer_reduction = true;
    }
  }
  if (is_inner_reduction && is_outer_reduction) {
    return ReductionType::InnerOuter;
  } else if (is_inner_reduction) {
    return ReductionType::Inner;
  } else if (is_outer_reduction) {
    return ReductionType::Outer;
  } else {
    return ReductionType::None;
  }
}
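To make the answer concrete, here is a small self-contained model of the three functions above. It is only a sketch: MockTensorView, the ScheduleHeuristic::None sentinel as used here, and firstAccepting are stand-ins invented for illustration, and the real code operates on TensorView* and reports reject reasons. It suggests that, because the three reduction types are mutually exclusive, the order in which the schedulers are tried does not change which one accepts a given fusion.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for TensorView: records only whether the reduction
// is over the fastest (innermost) dimension.
struct MockTensorView {
  bool is_fastest_dim_reduction;
};

enum class ReductionType { Inner, Outer, InnerOuter, None };

// None is used in this sketch as "no scheduler accepted".
enum class ScheduleHeuristic {
  None,
  InnerPersistent,
  OuterPersistent,
  InnerOuterPersistent
};

ReductionType getReductionType(const std::vector<MockTensorView>& reduction_tvs) {
  bool has_inner = false;
  bool has_outer = false;
  for (const MockTensorView& tv : reduction_tvs) {
    if (tv.is_fastest_dim_reduction) {
      has_inner = true;
    } else {
      has_outer = true;
    }
  }
  if (has_inner && has_outer) return ReductionType::InnerOuter;
  if (has_inner) return ReductionType::Inner;
  if (has_outer) return ReductionType::Outer;
  return ReductionType::None;
}

ReductionType mapScheduleHeuristicToReductionType(ScheduleHeuristic sh) {
  switch (sh) {
    case ScheduleHeuristic::InnerPersistent:
      return ReductionType::Inner;
    case ScheduleHeuristic::OuterPersistent:
      return ReductionType::Outer;
    case ScheduleHeuristic::InnerOuterPersistent:
      return ReductionType::InnerOuter;
    default:
      return ReductionType::None;
  }
}

bool checkReductionType(
    const std::vector<MockTensorView>& reduction_tvs,
    ScheduleHeuristic heuristic) {
  return getReductionType(reduction_tvs) ==
      mapScheduleHeuristicToReductionType(heuristic);
}

// Tries the heuristics in the given order and returns the first that accepts.
ScheduleHeuristic firstAccepting(
    const std::vector<MockTensorView>& reduction_tvs,
    const std::vector<ScheduleHeuristic>& order) {
  for (ScheduleHeuristic sh : order) {
    if (checkReductionType(reduction_tvs, sh)) {
      return sh;
    }
  }
  return ScheduleHeuristic::None;
}
```

For a fusion with one inner and one outer reduction, firstAccepting returns InnerOuterPersistent whether the combined heuristic is listed first or last.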
Step-3, Define get{Inner, Outer, InnerOuter}PersistentHeuristics
(1) InnerPersistent and OuterPersistent use the same parameters to call {inner,outer}PersistentHeuristic. Both call the same wrapper function, PersistentHeuristicsHelper::getCommonHeuristicParams.
(2) InnerOuterPersistent needs a special calculation of the buffer size and the intermediate tensor data type.
Step-4, Define schedule{Inner, Outer, InnerOuter}PersistentKernel
(1) InnerPersistent and OuterPersistent share the same process.
(2) InnerOuter is already a standalone function.
Step-5: refactor scheduleReductionTV
scheduleReductionTV is used by all kinds of reductions/normalizations and can be split into:
(1) scheduleInnerPersistentTV for the inner persistent kernel.
(2) scheduleOuterGridPersistentTV for the outer grid persistent kernel.
(3) scheduleOuterBlockPersistentTV for the outer block persistent kernel.
(4) scheduleInnerOuterPersistentTV for the innerOuter persistent kernel.
(5) scheduleReductionTV for reductions.
Step-6: refactor ReductionParams
(1) Like scheduleReductionTV, ReductionParams is used by all kinds of reductions/normalizations, so it can also be split into 5 classes.
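One way to picture Steps 5 and 6 together is to pair each of the five scheduling routines with its own parameter class, so that passing the wrong kind of parameters fails to compile. The sketch below is only an illustration under that assumption: the five struct names and all of their fields are hypothetical, and the routines return a label instead of transforming a tensor.

```cpp
#include <cassert>
#include <string>

// Hypothetical parameter classes, one per scheduling routine. The fields are
// placeholders invented for this sketch.
struct InnerPersistentParams { int persistent_batch = 1; };
struct OuterGridPersistentParams { int grid_dim = 1; };
struct OuterBlockPersistentParams { int block_dim = 1; };
struct InnerOuterPersistentParams { int inner_batch = 1; int outer_batch = 1; };
struct PlainReductionParams { bool cross_block = false; };

// Each routine accepts only its matching parameter class. The bodies return
// a label so the selected path can be observed.
inline std::string scheduleInnerPersistentTV(const InnerPersistentParams&) {
  return "inner persistent";
}
inline std::string scheduleOuterGridPersistentTV(const OuterGridPersistentParams&) {
  return "outer grid persistent";
}
inline std::string scheduleOuterBlockPersistentTV(const OuterBlockPersistentParams&) {
  return "outer block persistent";
}
inline std::string scheduleInnerOuterPersistentTV(const InnerOuterPersistentParams&) {
  return "innerOuter persistent";
}
inline std::string scheduleReductionTV(const PlainReductionParams&) {
  return "reduction";
}
```

A call such as scheduleInnerPersistentTV(OuterGridPersistentParams{}) would be rejected by the compiler, which is the main benefit this split is meant to show.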
The persistent scheduler, ScheduleHeuristic::Persistent, is used to schedule fusions with both reductions and persistent buffers. It has evolved into a set of interconnected yet notably distinct strategies: innerPersistent, outerPersistent, and innerOuterPersistent. Fusions with reductions that don't require persistent buffers are scheduled by ScheduleHeuristic::Reduction, which is not a topic of this refactor, although the method used here applies.
The function checkCanSchedule uses both canScheduleCompileTime and canScheduleRunTime to determine whether a fusion can be scheduled. If it can, the heuristic is computed by getPersistentHeuristics, and the fusion is then scheduled via schedulePersistentKernel.
Each of these functions determines the type of reduction (inner, outer, or innerOuter) and applies a different strategy depending on it. Instead of cramming these diverse heuristics into one comprehensive scheduler, each could operate as an independent scheduler: InnerPersistent, OuterPersistent, and InnerOuterPersistent.
By doing this, the complexity of the persistent scheduler would be significantly reduced. A potential downside of this approach is that the reduction type is checked three times when the innerOuterPersistentScheduler turns out to be the appropriate choice. This can be avoided by using specific cues from the fusion to dispatch directly to the appropriate scheduler instead of looping through every scheduler.
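The cost difference between looping and direct dispatch can be sketched as follows. This is a simplified model, not the actual dispatch code: MockTensorView, the global counter, chooseByLooping, and chooseDirectly are all hypothetical names introduced here, and the only cue used for direct dispatch is the reduction type itself.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for TensorView (an assumption of this sketch).
struct MockTensorView {
  bool is_fastest_dim_reduction;
};

enum class ReductionType { Inner, Outer, InnerOuter, None };
enum class ScheduleHeuristic {
  None,
  InnerPersistent,
  OuterPersistent,
  InnerOuterPersistent
};

// Counts how many times the reduction type is classified.
static int g_classification_count = 0;

ReductionType getReductionType(const std::vector<MockTensorView>& tvs) {
  ++g_classification_count;
  bool has_inner = false;
  bool has_outer = false;
  for (const MockTensorView& tv : tvs) {
    if (tv.is_fastest_dim_reduction) {
      has_inner = true;
    } else {
      has_outer = true;
    }
  }
  if (has_inner && has_outer) return ReductionType::InnerOuter;
  if (has_inner) return ReductionType::Inner;
  if (has_outer) return ReductionType::Outer;
  return ReductionType::None;
}

// Looping: every candidate scheduler repeats the classification.
ScheduleHeuristic chooseByLooping(const std::vector<MockTensorView>& tvs) {
  if (getReductionType(tvs) == ReductionType::Inner) {
    return ScheduleHeuristic::InnerPersistent;
  }
  if (getReductionType(tvs) == ReductionType::Outer) {
    return ScheduleHeuristic::OuterPersistent;
  }
  if (getReductionType(tvs) == ReductionType::InnerOuter) {
    return ScheduleHeuristic::InnerOuterPersistent;
  }
  return ScheduleHeuristic::None;
}

// Direct dispatch: classify once, then map the result to a scheduler.
ScheduleHeuristic chooseDirectly(const std::vector<MockTensorView>& tvs) {
  switch (getReductionType(tvs)) {
    case ReductionType::Inner:
      return ScheduleHeuristic::InnerPersistent;
    case ReductionType::Outer:
      return ScheduleHeuristic::OuterPersistent;
    case ReductionType::InnerOuter:
      return ScheduleHeuristic::InnerOuterPersistent;
    default:
      return ScheduleHeuristic::None;
  }
}
```

For a fusion with both an inner and an outer reduction, chooseByLooping classifies the fusion three times before reaching the combined scheduler, while chooseDirectly classifies it once and selects the same scheduler.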