tlc-pack / relax

Apache License 2.0

[Proposal][IR] Tracking Purity in Relax #402

Open slyubomirsky opened 1 year ago

slyubomirsky commented 1 year ago

This issue was raised earlier in #344. Given that we have recently made many changes to the IR, I think this might be a good time to pursue this relatively simple change, though some details remain to be worked out.

Dataflow blocks in Relax require that there not be any control flow or side effects, which allows for much simpler graph-like optimizations within those blocks. It is relatively simple to enforce the constraint that there not be control flow (simply prohibit recursive calls and If expressions in dataflow blocks), but presently the compiler makes no attempt to enforce the lack of side effects. I believe that it is feasible to enforce this constraint and that it would not require large changes to the language.

Proposal: Track Purity in FuncStructInfo

We could add a single flag (pure) to FuncStructInfo to indicate whether a given function is pure. Purity can be inferred automatically with a very simple heuristic: a function is pure if its body contains no calls to impure functions or operators, and impure otherwise.

We consider a function to be pure only if it is pure for all inputs, and impure if it is impure for any input.
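The heuristic can be sketched on a toy IR in which a function body is reduced to the list of callee names it contains. The `Func` class and the `OP_PURITY` table are hypothetical stand-ins, not Relax APIs; unknown callees are treated as impure, which keeps the check conservative.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Func:
    """Hypothetical toy function: only the names it calls matter here."""
    name: str
    callees: List[str] = field(default_factory=list)


# Hypothetical operator table: True means the operator is marked pure.
OP_PURITY: Dict[str, bool] = {
    "add": True,
    "matmul": True,
    "print": False,
    "copy_to": False,
}


def infer_pure(func: Func, known: Dict[str, bool]) -> bool:
    """Pure iff every callee is known to be pure; unknown callees count as impure."""
    return all(known.get(callee, False) for callee in func.callees)


print(infer_pure(Func("f", ["add", "matmul"]), OP_PURITY))  # True
print(infer_pure(Func("g", ["add", "print"]), OP_PURITY))   # False
```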

The Need for a Manual Override

The heuristic above is actually overly simplistic, as it is possible for a function to be pure (the definition being that it has no visible side effects and does not modify any value other than the one it returns) while calling an impure function. For example, it can call an impure function that mutates a value that is never used or returned, or it can allocate a value, mutate only that value, and return the final value. In both of these cases, the function does not have any visible side effects, so it satisfies the definition of purity but does not satisfy the criteria above.
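Both cases can be illustrated with ordinary Python as an analogy (this is plain Python, not Relax). Each function below performs mutating calls (`append`, `clear`), so a purely syntactic check would flag it, yet neither has an effect visible to its caller.

```python
from typing import List


def prefix_sums(xs: List[int]) -> List[int]:
    # Allocates a fresh list and mutates only that list; the caller's
    # input is never modified, so no side effect escapes.
    out: List[int] = []
    total = 0
    for x in xs:
        total += x
        out.append(total)  # mutation of a locally allocated value
    return out


def discard_mutation(xs: List[int]) -> int:
    # Mutates a scratch copy that is never used or returned.
    scratch = list(xs)
    scratch.clear()
    return len(xs)


data = [1, 2, 3]
print(prefix_sums(data), data)  # [1, 3, 6] [1, 2, 3]
```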

Automatically verifying that such cases are pure would likely require more sophisticated analysis of effects (e.g., tracking which values are mutated, which would also require an alias analysis). This would essentially entail creating an effect system, which would greatly increase the complexity of Relax.

Instead of an effect system, a much simpler approach would be to request a user annotation asserting that the function is pure. That could be accomplished using a function attribute like AssertPure (potentially with @R.pure as syntactic sugar for it in the parser).
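A minimal sketch of how an inference pass might honor such an annotation, using a hypothetical toy IR. Only the attribute name `AssertPure` comes from the proposal; `Func`, `OP_PURITY`, and the operator names `alloc_scratch` and `write_inplace` are illustrative assumptions. The assertion short-circuits the check, so the compiler simply trusts the user.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Func:
    name: str
    callees: List[str] = field(default_factory=list)
    attrs: Dict[str, bool] = field(default_factory=dict)


# Hypothetical operator purity table.
OP_PURITY: Dict[str, bool] = {
    "add": True,
    "alloc_scratch": False,
    "write_inplace": False,
}


def is_pure(func: Func, known: Dict[str, bool]) -> bool:
    # A user assertion overrides the heuristic and is not verified.
    if func.attrs.get("AssertPure", False):
        return True
    # Otherwise fall back to the simple call-based heuristic.
    return all(known.get(callee, False) for callee in func.callees)


body = ["alloc_scratch", "write_inplace"]
print(is_pure(Func("f", body), OP_PURITY))                        # False
print(is_pure(Func("f", body, {"AssertPure": True}), OP_PURITY))  # True
```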

Not Considered an Effect: Type/Shape Casts

In principle, a MatchCast can create a "visible effect": if the shape check fails, the program aborts. As discussed in the specification draft, it would be prudent not to consider divergence of this kind a "visible effect." Treating it as one would make it impossible to perform any dynamic shape checks inside dataflow blocks, which would be extremely impractical. It would also prevent the compiler from optimizing away unnecessary or unused shape checks, since removing them could change the "visible behavior" of the program. Hence, I believe the purity check should not consider MatchCasts (or the implicit shape checks that occur on function calls) to be visible effects.
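A minimal sketch of a block-level check that skips match-cast bindings, using hypothetical toy `Call` and `MatchCast` classes and a hypothetical `IMPURE_OPS` set (not the real Relax node types). Only call bindings can make the block impure.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class Call:
    op: str


@dataclass
class MatchCast:
    value: str
    shape: Tuple[object, ...]


Binding = Union[Call, MatchCast]

# Hypothetical set of impure operators.
IMPURE_OPS = {"print", "copy_inplace"}


def block_is_pure(bindings: List[Binding]) -> bool:
    for b in bindings:
        if isinstance(b, MatchCast):
            # A failed shape check aborts the program, but that divergence
            # is deliberately not treated as a visible effect.
            continue
        if b.op in IMPURE_OPS:
            return False
    return True


print(block_is_pure([MatchCast("x", ("n", 4)), Call("add")]))    # True
print(block_is_pure([MatchCast("x", ("n", 4)), Call("print")]))  # False
```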

Implementation

Function and dataflow block purity can be checked during normalization, since the heuristic is easy to apply. Checking a function only requires examining the function/operator calls in its body, and checking a dataflow block only requires examining the calls that appear in its bindings. It may be worth factoring the purity check out into a separate pass that is called from the normalizer, so that the normalizer does not become a giant monster pass.
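The factoring can be sketched as a stand-alone pass that the normalizer delegates to. Everything here (`Block`, `Function`, `PurityError`, `check_purity`, `normalize`) is a hypothetical toy, not the Relax normalizer: impure calls outside dataflow blocks merely make the function impure, while impure calls inside a dataflow block are rejected.

```python
from dataclasses import dataclass, field
from typing import Iterable, List


class PurityError(Exception):
    """Raised when a dataflow block contains an impure call."""


@dataclass
class Block:
    is_dataflow: bool
    callees: List[str] = field(default_factory=list)  # callee names, in order


@dataclass
class Function:
    name: str
    blocks: List[Block] = field(default_factory=list)


def check_purity(fn: Function, pure_ops: Iterable[str]) -> bool:
    """Return the inferred purity of fn; raise if a dataflow block is impure."""
    pure_set = set(pure_ops)
    fn_pure = True
    for blk in fn.blocks:
        for callee in blk.callees:
            if callee in pure_set:
                continue
            if blk.is_dataflow:
                raise PurityError(
                    f"impure call to {callee!r} in a dataflow block of {fn.name!r}"
                )
            fn_pure = False  # impure calls are allowed outside dataflow blocks
    return fn_pure


def normalize(fn: Function, pure_ops: Iterable[str]) -> bool:
    # ANF conversion and StructInfo filling would happen here; the purity
    # check is delegated to its own pass to keep the normalizer small.
    return check_purity(fn, pure_ops)
```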

Further Questions/Points of Discussion

  1. Should we include the ability to override a specific call site? E.g., assert that we permit one specific call to an impure function in a dataflow block. This could be accomplished with a Call attribute. Making exceptions can occasionally be useful, but we may want to consider what the specification will promise (e.g., no guarantee that the call will happen and if it happens, it may be called more than once).
  2. Relatedly, a single flag may be too simplistic for checking operator purity in some cases. An alternative might be using a macro (fInferPurity, say) to check it at each call site. This could be useful for very versatile operators like call_tir.
  3. Mutual recursion is potentially tricky to check automatically. One option would be to assume mutually recursive functions are pure by default and re-check only if one or both are found to contain impure calls. Another is to require purity to be annotated on mutually recursive functions.
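The first option in point 3 amounts to an optimistic fixed point: start by assuming every global function is pure, then repeatedly mark impure any function that calls an impure operator or an already-impure function, until nothing changes. A minimal sketch, assuming a hypothetical call graph given as a dict and treating any callee that is neither a global function nor in `impure_ops` as a pure operator. Since each function can only flip from pure to impure, the loop terminates.

```python
from typing import Dict, List, Set


def infer_mutual_purity(call_graph: Dict[str, List[str]],
                        impure_ops: Set[str]) -> Dict[str, bool]:
    # Optimistic start: every global function is assumed pure.
    pure = {name: True for name in call_graph}
    changed = True
    while changed:
        changed = False
        for name, callees in call_graph.items():
            if not pure[name]:
                continue
            for callee in callees:
                calls_impure_op = callee in impure_ops
                calls_impure_fn = callee in pure and not pure[callee]
                if calls_impure_op or calls_impure_fn:
                    pure[name] = False  # monotone: pure -> impure only
                    changed = True
                    break
    return pure


graph = {"even": ["odd"], "odd": ["even", "print"], "main": ["even"]}
print(infer_mutual_purity(graph, {"print"}))
```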
slyubomirsky commented 1 year ago

Results of discussion in the Feb. 14, 2023, community meeting: There was no opposition voiced to this proposal. I will implement this change and we can determine if any modifications will be necessary. The consensus was in favor of having a simple binary flag on operators (we can revisit and use a macro if we do need something more granular or sophisticated) and in having a per-call override via an attribute.

slyubomirsky commented 1 year ago

Implementation issues that have arisen that might be good to discuss:

Dealing with mutual recursion

Dealing with mutually recursive global functions can get tricky, especially given how normalization is currently implemented (transforming the AST to ANF and also filling in StructInfo). We can infer purity using a second pass (e.g., with a fixed point algorithm that checks all global functions, sees if any are impure, then rechecks until they stop changing, while also updating any types encountered along the way), but this adds much complexity to normalization. We can avoid this and continue doing normalization in a single pass if we require purity to be annotated on global functions that are mutually recursive (doing so for simple recursion as well would also be a simplification). Would that be reasonable? We require types to be annotated so it would not be unprecedented. However, it would require importers to reason about whether code they produce is pure or not, which may not be ideal. Possible syntax for annotations:

# true would be the default, since most functions are pure
@R.function(pure=True)
def f(...): ...

# force_pure => treat as pure even if an effect is found
@R.function(force_pure=True)
def g(...): ...

These options would compile to different attributes on the function.
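As a rough illustration of how the two keyword options might map to function attributes, here is a hypothetical stand-in decorator. It is not TVMScript's `R.function`, and the attribute names `IsPure` and `ForcePure` are assumptions made only for this sketch.

```python
from typing import Callable, Optional


def function(fn: Optional[Callable] = None, *, pure: bool = True,
             force_pure: bool = False):
    """Hypothetical stand-in for @R.function that records purity attributes."""
    def wrap(f: Callable) -> Callable:
        # "IsPure" and "ForcePure" are illustrative attribute names only.
        # force_pure implies the function is treated as pure without checking.
        f.attrs = {"IsPure": pure or force_pure, "ForcePure": force_pure}
        return f
    # Support both bare @function and @function(...) usage.
    return wrap(fn) if fn is not None else wrap


@function
def f(x):
    return x


@function(pure=False)
def g(x):
    print(x)
    return x


@function(force_pure=True)
def h(x):
    return x
```

A checker would then verify functions with `IsPure` set and skip verification when `ForcePure` is set.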

Other possible syntax: Use another decorator (not sure it's better)

@R.function
@R.pure
def f(...): ...

(Note: We probably should provide a mechanism to specify the purity regardless of whether we require it. Happy to take other suggestions)

Overriding purity checks on specific calls

The original proposal was to use a call attribute for this purpose, but this would be difficult to implement. In particular, operators expect a specific type of attribute and it's not possible to add an extra field onto them. Possible alternative ways to approach this feature:

  1. Use an operator to annotate these cases, e.g., R.force_pure(op(...)). This does not complicate the AST, but may be difficult to use with normalization, since it would be normalized to

    x = op(...)
    y = R.force_pure(x)

    This would complicate the implementation somewhat, since it would be necessary to check the entire function body to check whether an impure function is bound to a var and that var is later passed to force_pure. Not impossible, though.

  2. Add a field to Call nodes for the purpose of annotating purity. This has the disadvantage of changing the AST.
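The whole-body check that option 1 would require can be sketched over a flat list of ANF bindings, using a hypothetical `Binding` class and `IMPURE_OPS` set. The first scan collects every var that is later passed to `force_pure`. The second scan reports impure calls whose result is never forced.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Binding:
    var: str
    op: str
    args: Tuple[str, ...] = ()


# Hypothetical set of impure operators.
IMPURE_OPS = {"print", "write_inplace"}


def find_unforced_impure_calls(bindings: List[Binding]) -> List[str]:
    """Vars bound to impure calls that are never passed to force_pure.

    Models the normalized form:  x = op(...);  y = R.force_pure(x)
    """
    forced = {b.args[0] for b in bindings if b.op == "force_pure" and b.args}
    return [b.var for b in bindings
            if b.op in IMPURE_OPS and b.var not in forced]


body = [
    Binding("x", "write_inplace", ("a",)),
    Binding("y", "force_pure", ("x",)),
    Binding("z", "print", ("y",)),
]
print(find_unforced_impure_calls(body))  # ['z']
```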

I think the first option would be feasible to implement without too much additional complexity, so we can go with it.

  3. Edit: @tqchen suggests an intrinsic more like R.call_pure(op, args). This would avoid the normalization issue above, though we would have to make sure it's valid to pass an Op node as an argument to it (should not be a terrible complication).

The last suggestion seems like the best option, involving the least complexity in the check and the least disruption to the rest of the AST.
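The check for the `R.call_pure(op, args)` form can stay local to a single call, because the callee is an argument rather than a separately bound result. A minimal sketch using hypothetical toy `Op` and `Call` classes and a hypothetical `PURE_OPS` set (not the Relax node classes):

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Op:
    name: str


@dataclass
class Call:
    op: Op
    args: Tuple[object, ...] = ()


CALL_PURE = Op("call_pure")
# Hypothetical set of pure operators.
PURE_OPS = {"add", "matmul"}


def call_is_pure(call: Call) -> bool:
    if call.op == CALL_PURE:
        # The wrapped Op is an ordinary argument, so normalization does not
        # split it into its own binding; the intrinsic vouches for purity.
        if not call.args or not isinstance(call.args[0], Op):
            raise TypeError("call_pure expects an Op as its first argument")
        return True
    return call.op.name in PURE_OPS


print(call_is_pure(Call(Op("print"), ("x",))))             # False
print(call_is_pure(Call(CALL_PURE, (Op("print"), "x"))))   # True
```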