exo-lang / exo

Exocompilation for productive programming of hardware accelerators
https://exo-lang.dev
MIT License

LHS of Exo Statements #553

Open SamirDroubi opened 8 months ago

SamirDroubi commented 8 months ago

Problem:

There are currently two Exo statements that have a LHS: reduce and assign. However, the LHS data is fused within the statement node:

stmt = Assign( sym name, type type, string? cast, expr* idx, expr rhs )
         | Reduce( sym name, type type, string? cast, expr* idx, expr rhs )
         | ...

This carries over to the corresponding cursors in the API:

class AssignCursor(StmtCursor):
    def name(self) -> str: ...
    def idx(self) -> ExprListCursor: ...
    def rhs(self) -> ExprCursor: ...

This generally leads to less-ergonomic code. Consider the following example:

get_symbol_dependencies(proc, cursor) # Gets all the symbols that some cursor depends on

@proc
def foo(a: f32[2], b: f32[2]):
    for i in seq(0, 2):
        a[i] = b[i]

# I want the symbol dependencies of the assign statement
get_symbol_dependencies(proc, assign) # Okay great I gave a cursor that points to the statement and got {'a', 'i', 'b'} back!

# I want the symbol dependencies of the LHS of the assign statement
# but no way for me to point to the LHS!
# I could write the following
{assign.name()} | get_symbol_dependencies(proc, assign.idx()) # Should give me {'a'} | {'i'} = {'a', 'i'}

# But that's less-ergonomic than writing:
get_symbol_dependencies(proc, assign.lhs()) # I get the result {'a', 'i'}
# This final version is clean, direct, and self-explanatory
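
To make the comparison concrete, here is a minimal sketch on a toy IR rather than on Exo's actual LoopIR or cursor API. The classes `Var`, `Read`, and `Assign` and the helpers `expr_symbols`, `lhs_symbols`, and `stmt_symbols` are all hypothetical stand-ins; the only point is that a dedicated LHS query is a small wrapper over what `name` and `idx` already provide.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Read:
    name: str
    idx: Tuple["Expr", ...] = ()


Expr = Union[Var, Read]


@dataclass(frozen=True)
class Assign:
    # Mirrors the fused layout: the LHS is just `name` plus `idx`.
    name: str
    idx: Tuple[Expr, ...]
    rhs: Expr


def expr_symbols(e: Expr) -> set:
    """Collect every symbol name in an expression subtree."""
    if isinstance(e, Var):
        return {e.name}
    syms = {e.name}
    for i in e.idx:
        syms |= expr_symbols(i)
    return syms


def lhs_symbols(s: Assign) -> set:
    """Symbols of the LHS only: the buffer name plus its index expressions."""
    syms = {s.name}
    for i in s.idx:
        syms |= expr_symbols(i)
    return syms


def stmt_symbols(s: Assign) -> set:
    """Symbols of the whole statement: LHS plus RHS."""
    return lhs_symbols(s) | expr_symbols(s.rhs)


# Toy encoding of: a[i] = b[i]
assign = Assign("a", (Var("i"),), Read("b", (Var("i"),)))
print(sorted(lhs_symbols(assign)))   # ['a', 'i']
print(sorted(stmt_symbols(assign)))  # ['a', 'b', 'i']
```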

Consider this other example:

def get_buffer_accesses(proc, cursor, name):
    """
    Returns a list of cursors to all accesses to the buffer named `name` in the subtree of `cursor`
    """
    cursors = get_cursors_in_subtree(proc, cursor)
    check = lambda c: isinstance(c, (ReadCursor, ReduceCursor, AssignCursor)) and c.name() == name
    # ugh I had to specify three types
    return list(filter(check, cursors))  # materialize, since the docstring promises a list

# Not only did I have to specify three types, which seems excessive,
# I also had to make the mental effort of remembering all the types that can access a buffer.
# Alternatively, I would like to write something like the following:
    check = lambda c: isinstance(c, AccessCursor) and c.name() == name

Proposal: I have a few options below, but these are just ideas I came up with; I am not really a big fan of any of them. I am sure this problem isn't specific to Exo, and there is probably a more canonical way of dealing with it, if anyone can provide pointers.

Option 1: Extend the expression type

module LoopIR {
    proc = ...
    fnarg  = ...

    stmt = Assign( sym name, type type, string? cast, expr* idx, expr rhs ) # Current
         | Reduce( sym name, type type, string? cast, expr* idx, expr rhs ) # Current
         | Assign( expr lhs, string? cast, expr rhs ) # Proposal
         | Reduce( expr lhs, string? cast, expr rhs ) # Proposal
         | ...

    expr = Read( sym name, expr* idx ) # Current
         | Access( sym name, expr* idx, bool read, bool write ) # Proposal
         | ...

Advantages:

  1. Packages all types of accesses to buffer under one type

Disadvantages:

  1. This will require some dynamic checking to make sure that LHS is always an Access and not some other random expression.
  2. Expressions were side-effect free, but it might be confusing to have expressions that can be written to?
  3. It is unclear what the implications of having Access as a type of expression are for scheduling operations:
    • It might actually make it clearer in some cases: e.g. bind_expr could potentially now bind a LHS.
    • In some other cases, it might not make sense to operate on a LHS which will require the scheduling op to reject LHS nodes
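
The dynamic check from disadvantage 1 might look like the following. This is a sketch on hypothetical Python dataclasses (`Const`, `Access`, `Assign`), not Exo's ASDL-generated classes; it only illustrates that once the LHS is typed as a general expression, something has to reject a non-`Access` LHS at construction time.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Const:
    val: float


@dataclass(frozen=True)
class Access:
    name: str
    idx: Tuple[object, ...] = ()


@dataclass(frozen=True)
class Assign:
    lhs: object  # typed as a general expression, as in Option 1
    rhs: object

    def __post_init__(self):
        # Runtime validation: the grammar alone no longer guarantees this.
        if not isinstance(self.lhs, Access):
            raise TypeError(
                f"LHS must be an Access, got {type(self.lhs).__name__}"
            )


# Well-formed: a[0.0] = 1.0 (constants stand in for real index expressions)
ok = Assign(Access("a", (Const(0.0),)), Const(1.0))

# Ill-formed: a constant on the LHS is rejected when the node is built.
try:
    Assign(Const(0.0), Const(1.0))
    rejected = False
except TypeError:
    rejected = True
```

Under Option 2 the same mistake would be ruled out by the grammar itself, which is the tradeoff between the two options.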

Option 2: Add a LHS type (just decoupling the LHS from the statements)

module LoopIR {
    proc = ...
    fnarg  = ...

    stmt = Assign( sym name, type type, string? cast, expr* idx, expr rhs ) # Current
         | Reduce( sym name, type type, string? cast, expr* idx, expr rhs ) # Current
         | Assign( lhs lhs, string? cast, expr rhs ) # Proposal
         | Reduce( lhs lhs, string? cast, expr rhs ) # Proposal
         | ...

    lhs = LHS( sym name, type type, expr* idx ) # Proposal

Advantages:

  1. Leaves expressions untouched

Disadvantages:

  1. Feels hacky
  2. There are still two types that could access a buffer: LHS and Read

gilbo commented 7 months ago

It's common to have a slightly duplicated but restricted part of the grammar for lvalues. I'm somewhat in favor of option 2 because the IR structure prohibits ill-formed lvalues. But also, I think this gets into the cursor API design, which is different than the internal IR design. So, I'm a bit unsure/confused about what the tradeoffs are.

Look at the Python ASDL file (https://github.com/python/cpython/blob/main/Parser/Python.asdl); they use option 1.

If possible, I think it's better for us to reduce the complexity of the user-facing cursor API. Thus, yes, one could say "give me all accesses to a buffer", but if you take the entirely uniform view of returning a bunch of expression cursors, this is going to end up being very weird, since the lvalue expressions can't be operated on in the way that other expressions are operated on. So now you have to have a bunch of weird special casing anyway.

You can probably get very far with duck-typing. If assignment, reduction, and read cursors all have idx and name, then you can work pretty seamlessly.
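
As a rough illustration of the duck-typing approach, the sketch below uses toy cursor classes, not Exo's real ones; `accesses_buffer` is a hypothetical helper. It checks for both `name` and `idx`, because a cursor such as a loop cursor can expose `name()` without being a buffer access.

```python
class ReadCursor:
    def __init__(self, name, idx):
        self._name, self._idx = name, idx

    def name(self):
        return self._name

    def idx(self):
        return self._idx


class AssignCursor:
    def __init__(self, name, idx):
        self._name, self._idx = name, idx

    def name(self):
        return self._name

    def idx(self):
        return self._idx


class ForCursor:
    # Has a name (the loop variable) but no idx, so it is not an access.
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


def accesses_buffer(cursor, name):
    """Duck-typed test: anything with callable name() and idx() counts."""
    get_name = getattr(cursor, "name", None)
    get_idx = getattr(cursor, "idx", None)
    return callable(get_name) and callable(get_idx) and get_name() == name


cursors = [ForCursor("a"), AssignCursor("a", ["i"]), ReadCursor("b", ["i"])]
hits = [c for c in cursors if accesses_buffer(c, "a")]
```

Here `hits` contains only the `AssignCursor`: the loop cursor is skipped despite its matching name, and the read of `b` is skipped because the name differs.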

Here is a possible solution to your first problem that requires no fundamental change:

get_symbol_dependencies(proc, assign)

Maybe this seems strange, because assign should encompass both the left and right-hand sides, but for the purpose of this function, maybe it implicitly means the left-hand side? Or is get_symbol_dependencies supposed to be usable on arbitrary statements & expressions, not just reads and writes?

Regarding a solution to your second problem, consider adding useful queries to cursors:

check = lambda c: c.is_access() and c.name() == name
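
One way such a query could be provided is sketched below with toy classes; `Cursor`, `_AccessMixin`, and `PassCursor` here are hypothetical and do not reflect how Exo's cursor hierarchy is actually organized. The base class answers `False`, and the cursor kinds that touch a buffer override it, so user code never has to enumerate the concrete types.

```python
class Cursor:
    def is_access(self) -> bool:
        # Default: most cursors do not access a buffer.
        return False


class _AccessMixin:
    """Shared behavior for cursor kinds that read or write a buffer."""

    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name

    def is_access(self) -> bool:
        return True


class ReadCursor(_AccessMixin, Cursor):
    pass


class AssignCursor(_AccessMixin, Cursor):
    pass


class ReduceCursor(_AccessMixin, Cursor):
    pass


class PassCursor(Cursor):
    pass


def get_buffer_accesses(cursors, name):
    # is_access() is checked first, so name() is only called where it exists.
    return [c for c in cursors if c.is_access() and c.name() == name]


cursors = [ReadCursor("a"), AssignCursor("b"), PassCursor(), ReduceCursor("a")]
found = get_buffer_accesses(cursors, "a")
```

If a new kind of access were added to the grammar later, only that class would need to opt in; the filter itself would stay unchanged.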

SamirDroubi commented 7 months ago

For the second issue, I have been, unknowingly, doing duck-typing. However, I was worried about the forward-compatibility of code like that, because it is correct only under the current grammar and not necessarily after future changes. Maybe I am being too paranoid.

For the first one, yeah I would like an operation like get_symbol_dependencies to work on arbitrary statements and expressions.

For example, operating on a statement:

if loop.name() not in get_symbol_dependencies(proc, stmt):
    # hoist stmt

For example, operating on an lhs/expression:

if loop.name() not in get_symbol_dependencies(proc, assign.lhs()):
   # stage this lhs around the loop
if loop.name() not in get_symbol_dependencies(proc, read):
   # stage this read around the loop

There are probably ways to define such an operation so that I can express that I am looking for the LHS dependencies, but that would require users to learn more per-operation details (oh, there is this flag to enable this mode, or maybe there is this other operation that is scoped to the LHS).

gilbo commented 7 months ago

I think there is still a semantic issue here that’s being suppressed or conflated.

Why do you want all symbols that occur in an access, including symbols in the indexing expressions and the buffer symbol? This mixes up different kinds of variables. Why not also mix up any variables occurring in the right-hand side here too then?

If the point is to trace back through variable names to find every “symbol that this access depends on” then anything tracing back through this write or reduce access will also need to trace through the right-hand side, since the value written depends on the right-hand side.

If the point is “no, I just want to know which control data this access depends on” then that’s a different concept anyway.
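
One possible reading of this distinction, sketched on a deliberately simplified toy model (index expressions are reduced to bare variable names, and `Access`, `Assign`, `control_symbols`, and `value_symbols` are all hypothetical): keep "which control variables decide where this access lands" and "which symbols does the written value depend on" as two separate queries instead of one mixed set.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Access:
    name: str             # buffer (data) symbol
    idx: Tuple[str, ...]  # index variables (control symbols)


@dataclass(frozen=True)
class Assign:
    lhs: Access
    rhs: Access


def control_symbols(acc: Access) -> set:
    """Control variables that determine which element is accessed."""
    return set(acc.idx)


def value_symbols(stmt: Assign) -> set:
    """Symbols the written value depends on: everything on the RHS."""
    return {stmt.rhs.name} | set(stmt.rhs.idx)


# Toy encoding of: a[i] = b[j]
assign = Assign(Access("a", ("i",)), Access("b", ("j",)))

where = control_symbols(assign.lhs)  # {'i'}
what = value_symbols(assign)         # {'b', 'j'}
```

With the two sets kept apart, a question like "does the write location depend on loop `j`?" is answered by `where` alone, while tracing data flow through the write uses `what`.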

SamirDroubi commented 7 months ago

I think the semantics of the operation I was thinking of is "Get me all the symbols in the subtree of the cursor", but the name I chose might have been confusing. And maybe the semantics "Give me the symbol dependencies of the LHS of an assignment" should also include the rhs.

But I do agree with you that my example no longer feels great because of the issue of conflating two symbol types (control and data). I could instead write:

if loop.name() not in get_symbol_dependencies(proc, assign.idx()):
   # stage this lhs around the loop

which would achieve the same thing in my example.

I think maybe we can punt this issue until there is a great example that would really require such a wide intervention.