tqchen opened 1 year ago
To be clear, this is proposing to replace the current use of `shape_`? I would definitely be in favor of having something that maps more clearly to the annotations in TVMScript, as this proposal discusses.
Yes, this proposed Pattern will replace the current `shape_`, and it would open the door to more useful but erasable information besides shape (termed "structural value information" in Tianqi's proposal). I also very much like that it directly maps to the annotations in TVMScript, and that it can remove the duplicated code when we register an Op -- currently we have `FInferType` and `FInferShape`, which have overlapping logic (ndim appears in both type and shape); with Pattern, we only need `FInferPattern`.
I'd like to clarify something about the examples in the discussion section: do we want to permit Relax variables to appear in shape annotations? I wrote the draft spec on the belief that we do not permit it, and that shape variables can be introduced only in `MatchShape` nodes.
NOTE: updated the terminology to StructInfo to avoid confusion with the dataflow pattern lang.
@slyubomirsky thank you for bringing it up. In the particular example, the relax var can indeed appear in the shape in cases where the shape is deduced through an opaque function. It still holds that shape variables are only defined through `MatchShape`.
Okay, we might need to define how that should work.
Agree; more broadly, we should clarify what the `match_cast` semantics implies here.
Since I won't be able to be at the community meeting tomorrow, I'll give some of my thoughts in advance in writing (would be happy to discuss further based on what is said at the meeting).
Drawing on my draft rules for shape inference, I think StructInfo could work like this:
- `match_cast` dynamically checks that the value at run time matches the specified information.
- Annotations on function parameters and the return value can be treated as `match_cast` at the beginning and end of the function to dynamically check these.

It is a little harder to decide what to do with annotated StructInfo elsewhere. In the draft specification, I said that if the compiler cannot statically prove a shape annotation matches the computed `shape_`, the compiler should raise an error and require a dynamic cast. This is an approach we could use for StructInfo, but it does not match the intent of being "best-effort." I think we could use the following policy:

In these cases, I think there should be no run-time semantics for the annotation (i.e., there will be a dynamic check only if there is an explicit `match_cast`). Alternatively, we could have a compiler flag that turns all instances of case 2 into implicit `match_cast`s (or make that the default).
My only worry is about error reporting if a shape mismatch is detected late in compilation, e.g., after several passes have transformed the AST. How would we convey that to the user? Would we expect users to keep track of which passes are applied? For example, it's possible that there is not enough information in the initial program to conclude that shapes mismatch, but after applying function inlining, the compiler is able to conclude that there is a mismatch. It's good to detect errors, but my concern is how to report them to the user.
I think the relationship of StructInfo to type should be clearly specified as well. I think all StructInfo should be associated with a Relax type.
In general, I like this idea and I would love to spend time whiteboarding out rules for the different kinds of expressions and how we should process the StructInfo. I think we should be careful about what sorts of expressions we permit to appear inside StructInfo and what the scoping rules will be.
Thank you @slyubomirsky for bringing up great points. I agree with your points on how annotation works on arguments and return, and with the policy, especially around 1/3 (where 3 is best effort).
One thing to consider on policy 2 (whether a warning should be issued) is the case where we use TVMScript both for storing intermediate output (where the struct_info has been deduced by the compiler) and as an IR. To enable round-trip capabilities in such cases, the best approach is to not run the deduction, to avoid possible additional bindings being created by general re-deduction (especially around the opaque shape example), and to directly take the "assume" semantics.
In the case of a user-provided program with only partial annotations, I agree that providing some form of implicit `match_cast` (or a warning) would make sense. Perhaps we could have some syntax to distinguish the two use cases. Alternatively, we could always recommend users to use `match_cast`, which is more explicit and clearly states that the annotation is an assumption.
One way to think about best-effort compile-time errors is that, if we do not have the rich information, they will likely turn into runtime errors, and in some sense an error at compile time could be better. Indeed, we can think through error reporting a bit more here. My guess is that the operator context might help.
Also agree about mapping StructInfo to a type; the `get_static_type()` function provides such a mapping.
Thanks everyone for the proposal and discussion!
At the Relax open dev meeting on Dec 7 (recording, passcode: j$qkF+D2), the community agreed on bringing in the proposed `StructInfo`, and we will proceed with the implementation.
One thing to note is that StructInfo deduction is something we can continue to refine further.
To help us quickly get onto the new infra, the first iteration of the implementation will likely only seek to match the original best-effort shape deduction results that we currently have (and not add smarter deductions), so we have a basis on the new infra for iteration.
Here is an overall sketch of how struct info deduction can work, using the following helper functions:
```python
def unify_to_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo:
    """Find the LCA of lhs and rhs."""

def erase_to_well_defined(
    info: StructInfo,
    shape_var_in_scope: List[tir.Var],
    var_in_scope: List[Var],
) -> StructInfo:
    """Erase info to exclude vars that are not in scope."""
```
`unify_to_lca` helps us find the LCA of two struct infos by erasing information.
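As a rough illustration of the erasure idea, here is a minimal runnable sketch. The classes below are simplified, hypothetical stand-ins for the real Relax StructInfo hierarchy (symbolic dims modeled as plain strings), not the actual API:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical, simplified stand-ins for the real StructInfo classes.
@dataclass(frozen=True)
class TensorStructInfo:
    shape: Optional[Tuple[str, ...]]  # symbolic dims as strings; None = unknown
    ndim: int                         # -1 means unknown rank

@dataclass(frozen=True)
class ObjectStructInfo:
    """Most general struct info: nothing is known."""

def unify_to_lca(lhs, rhs):
    """Find a common ancestor of lhs and rhs by erasing information."""
    if isinstance(lhs, TensorStructInfo) and isinstance(rhs, TensorStructInfo):
        if lhs.shape is not None and lhs.shape == rhs.shape:
            return lhs                               # identical shapes: keep them
        if lhs.ndim == rhs.ndim:
            return TensorStructInfo(None, lhs.ndim)  # erase shape, keep rank
        return TensorStructInfo(None, -1)            # erase rank as well
    return ObjectStructInfo()                        # fall back to the most general info

# Shapes differ but rank agrees, so only ndim survives:
print(unify_to_lca(TensorStructInfo(("n", "m"), 2),
                   TensorStructInfo(("n", "k"), 2)))
# TensorStructInfo(shape=None, ndim=2)
```

The real implementation works over the full StructInfo hierarchy and symbolic TIR expressions, but the lattice-style "erase until both sides agree" behavior is the same.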
`erase_to_well_defined` is another function introduced to ensure correctness. Consider the following code example:
```python
def f(x: R.Tensor[(n, m)]):
    k = tir.Var("k", "int64")
    v0 = opaque_fn(x)
    v1 = match_cast(v0, R.Tensor[(n, k)])
    v2: R.Tensor[(n+1, k+2)] = pad(v1)
    return v2
```
In the above code, the return value v2 has shape (n + 1, k + 2). However, at the level of the function signature, only n and m are defined by the parameters; k becomes undefined once we go outside the scope of the function body and only look at the parameters. In this case, `erase_to_well_defined(R.Tensor[(n+1, k+2)], defined=[n, m])` gives `R.Tensor(ndim=2)`, a more coarse-grained struct info that does not contain an undefined var.

Let us look at another example:
```python
def f(x: R.Tensor[(n, m)]):
    v2: R.Tensor[(n, m+2)] = pad(x)
    return v2
```
In this case, `erase_to_well_defined(R.Tensor[(n, m+2)], defined={n, m})` will give us `R.Tensor[(n, m+2)]`, because both n and m can be picked up from the function parameters.
`erase_to_well_defined` should be used in scenarios where we are returning values from a scope to the outside, to ensure the resulting struct_info is well-defined.
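The two examples above can be sketched concretely. As before, `TensorStructInfo` here is a simplified stand-in (dims as string expressions, free variables found by a regex), not the real Relax class:

```python
import re
from dataclasses import dataclass
from typing import Optional, Set, Tuple

# Hypothetical, simplified stand-in for the real TensorStructInfo.
@dataclass(frozen=True)
class TensorStructInfo:
    shape: Optional[Tuple[str, ...]]  # symbolic dim expressions, e.g. "n+1"
    ndim: int

def free_vars(dim: str) -> Set[str]:
    """Symbolic variables appearing in a dim expression (identifiers only)."""
    return set(re.findall(r"[a-zA-Z_]\w*", dim))

def erase_to_well_defined(info, shape_var_in_scope):
    """Erase shape details that refer to vars not visible in the outer scope."""
    if info.shape is not None:
        used = set().union(*(free_vars(d) for d in info.shape))
        if used <= set(shape_var_in_scope):
            return info                       # every var is defined outside
    return TensorStructInfo(None, info.ndim)  # keep only the rank

# k is local to the function body, so the shape must be erased at the boundary:
print(erase_to_well_defined(TensorStructInfo(("n+1", "k+1"), 2), ["n", "m"]))
# TensorStructInfo(shape=None, ndim=2)

# (n, m+2) only uses parameters, so it survives intact:
print(erase_to_well_defined(TensorStructInfo(("n", "m+2"), 2), ["n", "m"]))
```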
Here is a rough set of deduction rules (note this is a rough sketch to ensure consistency with shape):

- Call node: look up `call.op.struct_info`, which should be a FuncStructInfo, and apply the function struct deduction rule.
- If node: `if_node.struct_info = unify_to_lca(erase_to_well_defined(if_node.then_case, parent_scope_vars), erase_to_well_defined(if_node.else_case, parent_scope_vars))`
- Seq node: `seq_node.struct_info = erase_to_well_defined(seq_node.body, parent_scope_vars)`
- Function return: ensure the return value is well defined by looking at vars in the params, and directly use `ret_struct_info = erase_to_well_defined(func.body, param_scope_vars)`
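The if-node rule above can be sketched end-to-end under the same simplifying assumptions (hypothetical classes with string dims, not the actual Relax API):

```python
import re
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical, simplified stand-in for the real TensorStructInfo.
@dataclass(frozen=True)
class TensorStructInfo:
    shape: Optional[Tuple[str, ...]]  # symbolic dims as strings
    ndim: int

def erase_to_well_defined(info, shape_var_in_scope):
    """Erase shape details that use vars not visible in the parent scope."""
    if info.shape is not None:
        used = {v for d in info.shape for v in re.findall(r"[a-zA-Z_]\w*", d)}
        if used <= set(shape_var_in_scope):
            return info
    return TensorStructInfo(None, info.ndim)

def unify_to_lca(lhs, rhs):
    """Simplified: keep the shape only when both sides agree exactly."""
    if lhs.shape is not None and lhs.shape == rhs.shape:
        return lhs
    return TensorStructInfo(None, lhs.ndim if lhs.ndim == rhs.ndim else -1)

def derive_if(then_info, else_info, parent_scope_vars):
    """If-node rule: erase branch-local vars, then unify the two branches."""
    return unify_to_lca(
        erase_to_well_defined(then_info, parent_scope_vars),
        erase_to_well_defined(else_info, parent_scope_vars),
    )

# The then-branch shape uses a branch-local var k, so only the rank survives:
print(derive_if(TensorStructInfo(("n", "k"), 2),
                TensorStructInfo(("n", "m"), 2),
                parent_scope_vars=["n", "m"]))
# TensorStructInfo(shape=None, ndim=2)
```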
Note that the initial implementation will mainly aim to first reach parity with the original shape deduction using the simpler infra, while not realizing the full best effort.
When uncertain, we can call `erase_to_well_defined` with no provided vars, or only with vars defined in params. This gives us a good-enough behavior that matches the shape behavior and gets the initial infra in place, while leaving room for further refinement of the deduction.
We will do another round of iteration to further strengthen the best effort deduction rules.
One of Relax's design goals is to enable dynamic shapes and program analysis based on dynamic StructInfo. Shape propagation helps us build effective dynamic-shape-aware programs.

Shape can be viewed as one kind of "structural value" information — it tells us some information about the runtime value.

As we do more development, we have found a few useful lessons and observations that could help us further evolve the design.
O0: Tracking shapes in advanced structural compositions

In the above program, we are composing X, Y as a tuple. Under such conditions, we would need to define a `shape_` of the tuple to be able to trace the shape of z; however, Y does not have a shape.

O1: Desire for a single place to group information

The structural information about a value is spread between type and shape, so the printer and parser need to collect values from both sides. It would be simpler to have a clear grouping location.
O2: Extension

As part of our research effort, we are considering the extensibility of the system, and we will need to introduce other structural value information besides shape.
Taking these motivations into account, we propose to further formalize structural info deduction as part of relax.
Design
We will introduce a class called `StructInfo`, a composite data structure that contains all the necessary structural information deduced for the compiler. A StructInfo describes structural information about the corresponding compiled value. See the pseudo code below for the proposed classes.

The key design considerations include:

- `match_cast`
Example Programs
The program below shows how StructInfo flows throughout a program. Notable items include:
The program below shows a possible extension of StructInfo to support sparse computation.
Note that the behavior of sparse addition `z = x + y` depends on whether x and y share the same `indptr` and `indices`. Having such information available at compile time can help compilation optimizations.

Discussions
The extra information in `Expr.struct_info` does not come for free, because StructInfo can depend on other values. We should view it as being bundled together with the Expr, and consider it carefully when rewriting code.

Consider the above example: if we simply look at the input arguments of calls, we know that there is no dependency from y to z. One possible optimization might involve reordering y to the beginning of the function, or doing dead-code elimination to remove everything that is not referenced by y. To track these dependencies, use `all_vars(struct_info)`.
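A minimal sketch of such a dependency collector, assuming a toy StructInfo model (the real `all_vars` walks Relax's actual StructInfo and expression nodes):

```python
from dataclasses import dataclass
from typing import Set, Tuple, Union

# Hypothetical, simplified model of struct info for illustration only.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class TensorStructInfo:
    shape: Tuple[Union[Var, int], ...]  # dims are either symbolic vars or ints

@dataclass(frozen=True)
class TupleStructInfo:
    fields: Tuple["StructInfo", ...]

StructInfo = Union[TensorStructInfo, TupleStructInfo]

def all_vars(info: StructInfo) -> Set[Var]:
    """Collect every variable a struct info depends on."""
    if isinstance(info, TensorStructInfo):
        return {d for d in info.shape if isinstance(d, Var)}
    # Tuple case: union the dependencies of all fields.
    return set().union(*[all_vars(f) for f in info.fields])

n, m = Var("n"), Var("m")
info = TupleStructInfo((TensorStructInfo((n, 4)), TensorStructInfo((m, n))))
print(sorted(v.name for v in all_vars(info)))  # ['m', 'n']
```

A rewriting pass can use such a collector to avoid reordering or eliminating a binding that another binding's struct info still refers to.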
It is important to remember that the extra information of StructInfo takes assume semantics rather than static_assert semantics. This means that we will only do best-effort checking. To see why we need assume semantics, consider the following function.
Imagine that we want to "recheck" the relation `y: R.Tensor[s0] = opaque_fn(x)`. We will retrigger the deduction function of `opaque_fn` and obtain the following program. The second deduction will generate a fresh shape function call `s1`; imagine that shape_func is an arbitrary sequence of computations. Then it is impossible to always prove the equivalence.

So when we print out TVMScript with compiler-deduced information, we will parse this info back as-is, to ensure round-trip capabilities. To enable user-provided information and runtime checks, we can always rely on `match_cast`.

One can view StructInfo as equivalent to a "dependent type". However, a normal type system usually has the following properties:
The extra value information in StructInfo has the following properties:
As a result, the extra information is more akin to "extra optional analysis information that the compiler can take". We acknowledge the difficulty of doing full proofs on extra runtime information. Instead, because all the information is available in runtime (Tensor) values, we use the static type as a "safety net". The static type also remains important and stable across dialects such as TIR and Relax.
The relation between static type and StructInfo is:

Because of the above considerations, we still believe it is important to distinguish the (static) type from StructInfo, and to call them out separately.
Upgrading to StructInfo

The update to use StructInfo can be mechanical: we mainly need to change shape deduction to StructInfo deduction, and match_shape to match_cast. We can also create shape accessor functions that redirect to TensorStructInfo's shape field to obtain the corresponding symbolic shape. We can choose to always run StructInfo deduction and then use `get_static_type` to set the corresponding static types. To enable static type checking, keeping some of the existing type deduction may not be harmful, to avoid StructInfo deduction generating additional bindings. The additional structural information can help us simplify writing the parser, printer, and propagation across functions and mixed tuple compositions.
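As a sketch of that mapping, assuming simplified stand-in classes (in actual Relax, for example, a TensorStructInfo maps to a DynTensorType), `get_static_type` drops the erasable value information and keeps only the static part:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical, simplified stand-ins for the real classes.
@dataclass(frozen=True)
class TensorStructInfo:
    shape: Optional[Tuple[str, ...]]  # symbolic shape: erasable value info
    ndim: int
    dtype: str

@dataclass(frozen=True)
class TupleStructInfo:
    fields: Tuple["StructInfo", ...]

@dataclass(frozen=True)
class DynTensorType:
    ndim: int
    dtype: str

@dataclass(frozen=True)
class TupleType:
    fields: Tuple["Type", ...]

def get_static_type(info):
    """Drop erasable value information (symbolic shapes); keep the static type."""
    if isinstance(info, TensorStructInfo):
        return DynTensorType(info.ndim, info.dtype)
    return TupleType(tuple(get_static_type(f) for f in info.fields))

# The symbolic shape (n, m) is erased; rank and dtype remain.
print(get_static_type(TensorStructInfo(("n", "m"), 2, "float32")))
# DynTensorType(ndim=2, dtype='float32')
```

Because the mapping is total and deterministic, running StructInfo deduction first and deriving static types from the result keeps the two views consistent.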