tlc-pack / relax

[AST] CallNode with `sinfo_args` #356

Closed by MasterJH5574 1 year ago

MasterJH5574 commented 1 year ago

Background

At the moment, relax::CallNode has a field named type_args, which is an array of Type.

In our current design, this field indicates the return value type for certain ops. That type is necessary for compile-time type inference (or structure info inference) of those ops, which in turn enables better optimization. Specifically, these ops are


Proposal: switch to sinfo_args

Since we refactored our infrastructure around structure info, we no longer do shape deduction and type deduction separately for CallNodes at compile time. Instead, we deduce only structure info, which unifies the previous shape and type deduction. So we propose to evolve type_args into sinfo_args, which simplifies the structure info deduction of these ops.

The sinfo_args field is proposed as an Array<StructInfo> field of CallNode, and may be non-empty only for intrinsic ops in Relax (like the three example cases).
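The intrinsic-only rule can be illustrated with a minimal sketch in plain Python. It is a toy model, not the Relax API: the classes `TensorSInfo` and `Call`, the set `INTRINSIC_OPS`, and the helper `sinfo_args_allowed` are all hypothetical names introduced here for illustration.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical stand-in for a tensor structure info; not the real Relax class.
@dataclass(frozen=True)
class TensorSInfo:
    shape: Tuple[str, ...]  # symbolic dimension names, e.g. ("m", "n")
    dtype: str


# Hypothetical stand-in for a call node carrying the proposed field.
@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[str, ...]
    sinfo_args: Tuple[TensorSInfo, ...] = ()


# Assumption: the set of intrinsic ops that may carry sinfo_args.
INTRINSIC_OPS = frozenset({"call_tir", "call_packed", "call_builtin"})


def sinfo_args_allowed(call: Call) -> bool:
    """Return True if the call respects the intrinsic-only rule."""
    return not call.sinfo_args or call.op in INTRINSIC_OPS


out = TensorSInfo(("m", "n"), "float32")
intrinsic_call = Call("call_packed", ("x",), (out,))
ordinary_call = Call("add", ("x", "y"), (out,))
```

In this model, `sinfo_args_allowed(intrinsic_call)` holds, while `ordinary_call` is rejected because a non-intrinsic op carries a non-empty `sinfo_args`.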

Three facts support this proposal:

S1. Richer compile-time info from StructInfo

Structure info can carry more information than static types.
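As a rough illustration of S1, the toy model below erases a structure info down to a static type. The class names `StaticTensorType` and `TensorSInfo` and the helper `erase_to_type` are hypothetical, and the model assumes the static type keeps only rank and dtype while the structure info may also carry a symbolic shape.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


# Hypothetical static type: rank and dtype only (assumption of this sketch).
@dataclass(frozen=True)
class StaticTensorType:
    ndim: int
    dtype: str


# Hypothetical structure info: may additionally carry a symbolic shape.
@dataclass(frozen=True)
class TensorSInfo:
    dtype: str
    ndim: int
    shape: Optional[Tuple[str, ...]] = None


def erase_to_type(sinfo: TensorSInfo) -> StaticTensorType:
    """Drop the shape and keep only what the static type can express."""
    return StaticTensorType(sinfo.ndim, sinfo.dtype)


a = TensorSInfo("float32", 2, ("m", "n"))
b = TensorSInfo("float32", 2, ("n", "m"))
```

Here `a` and `b` differ as structure info, yet `erase_to_type` maps both to the same static type, so the shape relationship is lost once only types are kept.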

S2. Current indirection in structure info deduction

Currently, type_args participates in structure info deduction by generating a structure info from the given type_args. Specifically,

Therefore, type_args currently adds a level of indirection to the structure info deduction. Switching to ret_sinfo removes this indirection.

S3. History of type_args

Prior to PR #306, Relax directly reused Relay’s IR nodes, and type_args is a field of Relay’s CallNode dedicated to polymorphism (per its documentation). We had little opportunity to adjust the IR node design to our own needs.

PR #306 introduced a dedicated CallNode and other IR nodes for Relax. Though that PR doesn’t change the structure of any IR node, it gives us the opportunity to iterate on the IR node design. So nothing stands in the way of this change from that perspective.

Possible implications

If we switch to ret_sinfo, there are a few things we should be mindful of:

T1. Visitor / Mutator recursing into ret_sinfo

Since a StructInfo may contain Expr, whenever a pass might change a call_tir or ExternFunc call, it should also recurse into ret_sinfo and mutate the Expr inside, with the help of the VisitExprDepStructInfoField method in ExprVisitor and/or ExprMutator.
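The concern in T1 can be sketched with a toy mutator in plain Python. This does not use the real ExprMutator or VisitExprDepStructInfoField; `TensorSInfo`, `Call`, `rename_shape_vars`, and `mutate_call` are hypothetical names, and symbolic shape variables are modeled as plain strings.

```python
from dataclasses import dataclass, replace
from typing import Dict, Tuple


# Hypothetical structure info whose shape holds symbolic variable names.
@dataclass(frozen=True)
class TensorSInfo:
    shape: Tuple[str, ...]
    dtype: str


# Hypothetical call node carrying structure info arguments.
@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[str, ...]
    sinfo_args: Tuple[TensorSInfo, ...] = ()


def rename_shape_vars(sinfo: TensorSInfo, mapping: Dict[str, str]) -> TensorSInfo:
    """Rewrite the shape variables embedded in one structure info."""
    new_shape = tuple(mapping.get(dim, dim) for dim in sinfo.shape)
    return replace(sinfo, shape=new_shape)


def mutate_call(call: Call, mapping: Dict[str, str]) -> Call:
    """Rename shape variables, recursing into sinfo_args as well."""
    new_sinfo = tuple(rename_shape_vars(s, mapping) for s in call.sinfo_args)
    return replace(call, sinfo_args=new_sinfo)


call = Call("call_packed", ("x",), (TensorSInfo(("m", "n"), "float32"),))
renamed = mutate_call(call, {"m": "k"})
```

A mutator that skipped `sinfo_args` would leave a stale reference to `m` inside the call after the rename, which is the failure mode T1 warns about.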

T2. Validity of ret_sinfo

For a call_tir / ExternFunc call / call_builtin, not every structure info is valid as the ret_sinfo. Since the purpose of ret_sinfo is to indicate the structure info of the return value, we require ret_sinfo to be well-defined from the context so that its semantics are clear. In other words, it should not contain any construct that has not already been defined. For example,

@R.function
def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
    y: R.Tensor((m, n), "float32") = R.call_packed(
        "vm.builtin.copy", x, R.Tensor((m, n), "float32")
    )
    return y

👆 This is valid because m and n are defined in the function signature.

@R.function
def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
    m1 = tir.Var("m1", "int64")
    n1 = tir.Var("n1", "int64")
    y: R.Tensor((m1, n1), "float32") = R.call_packed(
        "my_packed_func", x, R.Tensor((m1, n1), "float32")
    )
    return y

👆 This is invalid since m1 and n1 are undefined. In this case, we require an explicit MatchCast:

@R.function
def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
    y: R.Tensor("float32", ndim=2) = R.call_packed(
        "my_packed_func", x, R.Tensor("float32", ndim=2)
    )
    m1 = tir.Var("m1", "int64")
    n1 = tir.Var("n1", "int64")
    y1: R.Tensor((m1, n1), "float32") = R.match_cast(
        y, R.Tensor((m1, n1), "float32")
    )
    return y1

This can be a part of our well-formed check.
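A minimal sketch of such a check, in plain Python rather than against the real well-formed checker: shapes are modeled as tuples whose entries are either strings (symbolic shape variables) or ints (constant extents), and `undefined_shape_vars` is a hypothetical helper name.

```python
from typing import Set, Tuple, Union

# str: symbolic shape variable, int: constant extent (modeling assumption).
Dim = Union[str, int]


def undefined_shape_vars(shape: Tuple[Dim, ...], defined: Set[str]) -> Set[str]:
    """Return the symbolic variables in `shape` that are not yet in scope."""
    return {d for d in shape if isinstance(d, str) and d not in defined}


# Variables introduced by the function signature in the examples above.
in_scope = {"m", "n"}

valid = undefined_shape_vars(("m", "n"), in_scope)
invalid = undefined_shape_vars(("m1", "n1"), in_scope)
```

The first example yields an empty set and would pass; the second reports `m1` and `n1`, which is the case that must be expressed through MatchCast instead.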

junrushao commented 1 year ago

This is an interesting proposal! Would you imagine all Call + MatchCast patterns could be replaced by the proposed call-with-ret-sinfo?

MasterJH5574 commented 1 year ago

@junrushao Yes, I imagine that is possible. That said, as a first step I think we should take a simple approach and keep the original semantics of CallNode, which is the main purpose of this proposal. That is why I listed implication T2 above, which suggests disallowing the “invalid” cases.

If, in future use and development, we find it helpful to incorporate the semantics of MatchCast into ret_sinfo, we can propose, discuss, and decide on the switch then.

slyubomirsky commented 1 year ago

I think it would be better to keep Call and MatchCast separate, since MatchCast is a binding while Call is not.

junrushao commented 1 year ago

Yep, very early on (last Nov) I was thinking about combining them, but was convinced that they need to be separate IR nodes so that we can express more fine-grained semantics. So, to be clear, I am not suggesting removing MatchCast as an independent node; I was pointing out that there could be some overlap in semantics if we use ret_sinfo.

slyubomirsky commented 1 year ago

The rule for extern calls has some overlap with casting, but I think binding new shape vars should be left to MatchCast.

tqchen commented 1 year ago

Another possible way to look at it is to allow ONLY intrinsics, such as call_tir and call_packed, to define special semantics for this extra field. Along that line, we can have something similar to type_args, name it struct_info_args, and define its semantics through the deduction function of each intrinsic.

slyubomirsky commented 1 year ago

This suggestion is more along the lines of what's in the draft specification. Since ops already use FInferStructInfo, this is a very easy rule to work with.
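A rough sketch of the idea, as a toy registry in plain Python rather than the real FInferStructInfo mechanism: `INFER_RULES`, `register_infer`, and `infer_struct_info` are hypothetical names, structure info is modeled as plain strings, and the rule shown for call_packed (the single sinfo argument is the result's structure info) is an assumption of this sketch.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry mapping an op name to its deduction function.
INFER_RULES: Dict[str, Callable[[Tuple[str, ...], Tuple[str, ...]], str]] = {}


def register_infer(op_name: str):
    """Decorator that records the deduction function for one op."""
    def decorator(fn):
        INFER_RULES[op_name] = fn
        return fn
    return decorator


@register_infer("call_packed")
def infer_call_packed(arg_sinfo, sinfo_args):
    # Assumed rule: exactly one sinfo argument, used as the result's sinfo.
    if len(sinfo_args) != 1:
        raise ValueError("call_packed expects exactly one sinfo argument")
    return sinfo_args[0]


@register_infer("add")
def infer_add(arg_sinfo, sinfo_args):
    # Non-intrinsic ops must not receive sinfo arguments.
    if sinfo_args:
        raise ValueError("non-intrinsic op takes no sinfo arguments")
    return arg_sinfo[0]  # toy rule: result mirrors the first operand


def infer_struct_info(op, arg_sinfo, sinfo_args=()):
    """Dispatch to the deduction function registered for `op`."""
    return INFER_RULES[op](arg_sinfo, sinfo_args)
```

With this layout, each intrinsic decides for itself what its sinfo arguments mean, and ordinary ops simply reject them.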

MasterJH5574 commented 1 year ago

Just did a round of updates: changed the name to sinfo_args and the type to Array<StructInfo>, and reserved the field for intrinsic operators only. This is mainly to keep a clear boundary between the functionality of sinfo_args and MatchCast. Since the struct info deduction of some intrinsic ops might rely on multiple struct infos, sinfo_args is an array of StructInfo instead of a single one.

MasterJH5574 commented 1 year ago

Set up the tracking issue at #377