joelberkeley / spidr

Accelerated machine learning with dependent types
Apache License 2.0

Enable observable sharing in `Expr` via "let" #372

Closed: joelberkeley closed this 4 months ago

joelberkeley commented 8 months ago

Currently we capture the graph as a mapping of node positions to nodes, `SortedMap Nat Expr`. Tensors are then just `Nat` indices into this graph.
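For concreteness, here is a minimal sketch of what the index-based representation might look like. The constructors `Lit`, `Add` and `Mul` are hypothetical stand-ins rather than spidr's actual `Expr`; the only thing taken from the description above is that nodes live in a `SortedMap Nat Expr` and refer to each other by `Nat` position.

```idris
import Data.SortedMap

||| Hypothetical node type: operands are `Nat` positions of other nodes.
data Expr : Type where
  Lit : Integer -> Expr
  Add : Nat -> Nat -> Expr
  Mul : Nat -> Nat -> Expr

||| The graph for `let x = 7; y = x + x in y * y`.
graph : SortedMap Nat Expr
graph = fromList [(0, Lit 7), (1, Add 0 0), (2, Mul 1 1)]

||| Evaluate the node at position `i` by following operand indices.
||| Returns `Nothing` if an index is missing from the graph.
eval : SortedMap Nat Expr -> Nat -> Maybe Integer
eval g i = case lookup i g of
  Nothing => Nothing
  Just (Lit n) => Just n
  Just (Add x y) => [| eval g x + eval g y |]
  Just (Mul x y) => [| eval g x * eval g y |]

main : IO ()
main = printLn (eval graph 2)  -- Just 196
```

Sharing is visible in the data: node 1 is stored once and referenced twice by node 2. This naive evaluator still recomputes it on each reference, so a real interpreter would presumably cache results per index.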

The alternative approach is to include a "let" construct in `Expr` by adding `Let` and `Var` nodes:

data Expr : Type where
  Var : Nat -> Expr                  -- reference to a bound name
  Let : Nat -> Expr -> Expr -> Expr  -- name, bound value, body
  ...

Example for

let x = 7
    y = x + x
 in y * y

is

Let 0 (Lit 7) (Let 1 (Add (Var 0) (Var 0)) (Mul (Var 1) (Var 1)))
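As a sanity check on that encoding, here is a small sketch of an evaluator for it. `Lit`, `Add` and `Mul` are assumed constructors standing in for the ones elided by `...` above, and the environment-passing `eval` is only one plausible interpretation, not something spidr implements.

```idris
import Data.SortedMap

||| Toy expression type with explicit sharing via `Let` and `Var`.
data Expr : Type where
  Lit : Integer -> Expr
  Add : Expr -> Expr -> Expr
  Mul : Expr -> Expr -> Expr
  Var : Nat -> Expr
  Let : Nat -> Expr -> Expr -> Expr

||| Evaluate under an environment of already-computed bindings.
||| Returns `Nothing` if a `Var` refers to an unbound name.
eval : SortedMap Nat Integer -> Expr -> Maybe Integer
eval _ (Lit n) = Just n
eval env (Add x y) = [| eval env x + eval env y |]
eval env (Mul x y) = [| eval env x * eval env y |]
eval env (Var n) = lookup n env
eval env (Let n val body) = do
  v <- eval env val            -- the bound value is computed once
  eval (insert n v env) body   -- and looked up wherever `Var n` occurs

example : Expr
example = Let 0 (Lit 7) (Let 1 (Add (Var 0) (Var 0)) (Mul (Var 1) (Var 1)))

main : IO ()
main = printLn (eval SortedMap.empty example)  -- Just 196
```

Here each bound expression is evaluated exactly once, which is the observable sharing the `Let` node is meant to preserve.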

The second option may allow easier autodiff implementation since it's the syntax used in this paper.

I'm not clear on the mechanics of how to construct the graph: whether a tensor would be an `Expr` instead, and the only global aspect would be the counter variable (like we had before, but without the need to merge trees). It might be that this approach would allow us to use `ST Nat` instead of linear mutable arrays, which would remove a dependency. I imagine that would be as fast, but we'd have to think carefully about that.
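One possible shape for this, purely as a sketch: keep the counter in an `STRef` from `Control.Monad.ST` and allocate a fresh name each time a value is shared. The helper `share` and the constructors other than `Var` and `Let` are hypothetical, and how this would fit spidr's tensor API is left open.

```idris
import Control.Monad.ST

data Expr : Type where
  Lit : Integer -> Expr
  Add : Expr -> Expr -> Expr
  Mul : Expr -> Expr -> Expr
  Var : Nat -> Expr
  Let : Nat -> Expr -> Expr -> Expr

||| Bind `val` to a fresh name and hand the matching `Var` to `body`.
||| The counter is the only shared state.
share : STRef s Nat -> Expr -> (Expr -> ST s Expr) -> ST s Expr
share counter val body = do
  n <- readSTRef counter
  writeSTRef counter (S n)
  Let n val <$> body (Var n)

||| Intended to rebuild the `let x = 7; y = x + x in y * y` term.
example : Expr
example = runST (do
  counter <- newSTRef 0
  share counter (Lit 7) $ \x =>
    share counter (Add x x) $ \y =>
      pure (Mul y y))
```

In this sketch a tensor is just an `Expr` (usually a `Var`), and no trees need merging because every shared value is introduced by its own `Let`.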

joelberkeley commented 4 months ago

iiuc, stack machines (a List of Exprs), which is what we have now, scale better than Let-based ASTs