Open jacobhinkle opened 5 months ago
cc @naoyam @zasdfgbnm @kevinstephano
Is `Task` a subclass of `PolymorphicValue`, or `PolymorphicBase`? Can it be a subclass of `Expr`, or the other way (`Expr` is a subclass of `Task`?) Or, can `Fusion` be a subclass of `Expr` or `Task`?
> Is `Task` a subclass of `PolymorphicValue`, or `PolymorphicBase`?

Thanks for noticing the typo! It should be `PolymorphicBase`, so that we can easily check what type of `Task` we're dealing with, as we do with `Expr`.
> Can it be a subclass of `Expr`, or the other way (`Expr` is a subclass of `Task`?)

I think we could potentially make `Expr` a specific kind of `Task`. That is a big change though, as much depends on `val->definition()`.
> Or, can `Fusion` be a subclass of `Expr` or `Task`?

Yeah, I would think our current notion of "Fusion" would be a type of task, maybe renamed as "Program" since fusion of CUDA kernels only happens within `Segment`s. Anyway, aside from naming, I think it would be clean to unify the concepts. The only reason I didn't propose that straight away is that it's such a big change. Same for `Expr`, which resides at the other end of the spectrum.
Having nested containers of IR seems critical for so many applications! I used to think we should:
What this could achieve: With this set of behaviors, you could use all the tools as they are, but you could also look at the entire program through multiple hierarchical views, linked to each other because they share the same nodes for inputs/outputs of expressions.
I figured this way we could view segmented Fusions as a view of the original fusion you're segmenting. That way you could more easily traverse from one view to another view.
I also thought this could be useful to have "reversible" transformations, as the new transformation could avoid destroying the original.
Challenges: One interesting challenge would be garbage collection (we could start generating a stupid number of nodes and may want to clean up based on a set of fusions), or at least cleaning up nodes that aren't associated with any fusion that's still alive.
Another is modifying the IR consistently: if one fusion is changed but has references to another, are those connections updated automatically somehow? Is there any consistency enforced in the hierarchical scheme?
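As a rough illustration of the cleanup concern only, here is a minimal sketch in which views share nodes through reference counting, so a node is freed once no live view refers to it. `Node`, `View`, and `nodeFreedWhenLastViewDies` are hypothetical toy names; this is not how the IR owns statements today (the container owns them), and it does not address the consistency question.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Toy IR node; stands in for a Val or Expr (hypothetical).
struct Node {
  std::string name;
};

// A "view" is just a list of shared handles to nodes (hypothetical).
struct View {
  std::vector<std::shared_ptr<Node>> nodes;
};

// Returns true if the node survives while any view holds it and is
// freed once the last view referring to it is destroyed.
inline bool nodeFreedWhenLastViewDies() {
  std::weak_ptr<Node> probe;  // observes the node without owning it
  {
    auto shared = std::make_shared<Node>(Node{"t0"});
    probe = shared;
    View original{{shared}};
    {
      View segmented{{shared}};  // second view over the same node
    }                            // segmented dies; node must survive
    shared.reset();              // drop our local handle too
    if (probe.expired()) {
      return false;  // would mean the node died while `original` was alive
    }
  }  // original dies here; nothing refers to the node anymore
  return probe.expired();
}
```

Reference counting gives the "clean up nodes not associated with any live fusion" behavior for free, at the cost of changing ownership away from a single container.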
I'd like to facilitate a discussion on how to represent hierarchical programs in our IR. This topic has come up repeatedly in various contexts so there might be an opportunity to introduce a nice abstraction for it.
Current status
Our IR currently has `IrContainer`, which is the base class for `Fusion`. `IrContainer` owns all the `Statement`s in the "fusion", and the `Fusion` class mostly manages inputs and outputs; this includes updating `val->uses()` when expressions are registered, since this can change the reachability of each statement. Any other child classes of `IrContainer` are themselves children of `Fusion` (namely, `kir::Kernel`).

`Expr`s are the simplest subfusions

In some sense, the core function of a `Fusion` (managing inputs/outputs) overlaps that of `Expr`. Namely, both represent some computation that takes in a collection of `Val`s and outputs some other `Val`s. This is the only sense in which we currently support "sub-fusions"; i.e. since a `Fusion` contains multiple `Expr`s, we can think of it as a hierarchy of subfusions.

The purpose of this proposal is to represent subfusions at intermediate granularities between "entire fusion" and "single expression".
Use cases
Partitioning `Fusion`s into tasks

Partitioning the graph and executing segments separately from one another. This is important for `SegmentedFusion`, which tracks collections of edges and groups of expressions representing segments.

During segmentation, to analyze each segment we have a guard class that swaps the fusion inputs and outputs so that analysis can use our standard traversal utilities. Instead, it might be useful to represent segments during segmentation as subfusions that overlay the base fusion. This idea could be generalized to handle task parallelism.
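For readers unfamiliar with the guard pattern mentioned above, the following is a minimal sketch of the idea using toy types. `ToyFusion`, `SegmentIoGuard`, and `inputsSeenInsideSegment` are hypothetical names for illustration and are not the actual classes in the codebase.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a fusion: only the input/output lists matter here.
struct ToyFusion {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Hypothetical RAII guard: on construction, swap a segment's boundary
// values in as the fusion's inputs/outputs; on destruction, swap the
// originals back.
class SegmentIoGuard {
 public:
  SegmentIoGuard(ToyFusion& fusion,
                 std::vector<std::string> segment_inputs,
                 std::vector<std::string> segment_outputs)
      : fusion_(fusion),
        saved_inputs_(std::move(segment_inputs)),
        saved_outputs_(std::move(segment_outputs)) {
    // After these swaps, saved_* hold the fusion's original lists.
    fusion_.inputs.swap(saved_inputs_);
    fusion_.outputs.swap(saved_outputs_);
  }
  ~SegmentIoGuard() {
    fusion_.inputs.swap(saved_inputs_);
    fusion_.outputs.swap(saved_outputs_);
  }
  SegmentIoGuard(const SegmentIoGuard&) = delete;
  SegmentIoGuard& operator=(const SegmentIoGuard&) = delete;

 private:
  ToyFusion& fusion_;
  std::vector<std::string> saved_inputs_;
  std::vector<std::string> saved_outputs_;
};

// Number of inputs a standard traversal utility would see while the
// guard is active for a segment with one boundary input and one output.
inline std::size_t inputsSeenInsideSegment(ToyFusion& fusion) {
  SegmentIoGuard guard(fusion, {"t2"}, {"t3"});
  return fusion.inputs.size();
}
```

The guard temporarily mutates the base fusion, which is exactly what an overlaying subfusion would avoid.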
`ExpressionEvaluator` patterns (to replace composite IR nodes)

Recently, we introduced IR nodes for matmul patterns (#2175, #2240, #2294). These have greatly simplified our ATen fallback evaluation mode, replacing complex pattern matching code. However, fusion of these patterns still requires decomposition (#2236) into smaller primitives. If we had a working subfusion system, we might do both at once at program definition. For example, our `linear` function might create `BroadcastOp` nodes as well as an `MmaOp` node, add bias with `BinaryOp`, then cast to reduced precision using `UnaryOp`. All these ops would then be grouped into a subfusion that is identified as a "Linear" pattern so that evaluation of its output could easily be identified and computed using ATen.

Lambdas for generalizing reduction/scan
When discussing #622 recently, @naoyam pointed out that representing lambdas directly in the tensor IR tends to clutter and complicate the fusion (see #2307). With the right abstraction, we could potentially maintain a lambda function as a subfusion that is disconnected from the main Fusion inputs and outputs, so that we would no longer need special `IterType`s and we could use a single op like `FoldOp` to represent a reduction or scan. Note that we could still schedule the lambda as desired but we would not need to take as much care with inlining and allocation concerns if it is disconnected and represented in this way.

Possible approach
One approach would be to keep the current system intact, but add a new `Task` type like this:

`Fusion` could own these `Task`s just like it currently owns `Statement`s. `Task` could potentially be made a subclass of `Statement` to facilitate cloning.

We could generalize `Val::isFusionInput()` and `Val::isFusionOutput()` with the following:

Does this address the use cases above?
`Task` at that time. The `ExpressionEvaluator` could use this to check for supported fallback patterns by first checking for supported `val->taskDefinitions()` then checking `val->definition()`.
`class Segment : public Task {};` so that creating and modifying segments doesn't alter the fusion but just modifies `Task`s. Then checking whether a `Val` is a segmentation edge just means finding a task in `val->taskDefinitions()` or `val->taskUses()` that `isA<Segment>()`.
`Task` to use as a lambda, then just attach it to our `FoldOp` or `ScanOp` as an attribute.
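To make the pieces named above concrete (`Task`, `val->taskDefinitions()`, `val->taskUses()`, `class Segment : public Task {}`, `isA<Segment>()`), here is a minimal self-contained sketch with toy types. Everything beyond those names is an assumption: the member layout, the constructor signature, the `isSegmentationEdge` helper, and the use of `dynamic_cast` as a stand-in for the `PolymorphicBase` type check are illustrative only and are not the proposed implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

struct Task;

// Toy value node. The accessor names follow the proposal; the storage
// and signatures are assumptions made for this sketch.
struct Val {
  std::string name;
  std::vector<Task*> task_definitions;  // tasks that output this Val
  std::vector<Task*> task_uses;         // tasks that take this Val as input

  const std::vector<Task*>& taskDefinitions() const { return task_definitions; }
  const std::vector<Task*>& taskUses() const { return task_uses; }
};

// A Task describes a sub-computation purely by its boundary Vals.
struct Task {
  Task(std::vector<Val*> ins, std::vector<Val*> outs)
      : inputs(std::move(ins)), outputs(std::move(outs)) {
    // Register this task on its boundary values.
    for (Val* v : inputs) v->task_uses.push_back(this);
    for (Val* v : outputs) v->task_definitions.push_back(this);
  }
  virtual ~Task() = default;
  // Vals store raw pointers to the task, so forbid copying.
  Task(const Task&) = delete;
  Task& operator=(const Task&) = delete;

  // Stand-in for a PolymorphicBase-style type check.
  template <typename T>
  bool isA() const {
    return dynamic_cast<const T*>(this) != nullptr;
  }

  std::vector<Val*> inputs;
  std::vector<Val*> outputs;
};

// A segment is just a particular kind of Task.
struct Segment : public Task {
  using Task::Task;
};

// Hypothetical helper: a Val is a segmentation edge if any task that
// defines or uses it is a Segment.
inline bool isSegmentationEdge(const Val& v) {
  auto is_segment = [](const Task* t) { return t->isA<Segment>(); };
  return std::any_of(v.taskDefinitions().begin(), v.taskDefinitions().end(), is_segment) ||
         std::any_of(v.taskUses().begin(), v.taskUses().end(), is_segment);
}
```

In this sketch, creating or destroying a `Segment` only touches the task lists on its boundary values; the expression graph reachable through `val->definition()` would be left alone.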