google-research / dex-lang

Auto-cancellation of `ordinal . unsafe_from_ordinal` #1283

Open axch opened 1 year ago

axch commented 1 year ago

Why

Notionally, an `Ix a` instance is supposed to define a bijection between `a` and the first `size a` natural numbers. In particular, on those natural numbers, `ordinal` and `unsafe_from_ordinal` are supposed to be inverses. Do we want to start using that in the compiler by replacing the composition `ordinal . unsafe_from_ordinal` with `id` when it is known to occur?

To wit, by the time we get to the hardware, `for` loops are counting integers, and arrays are indexed by integers. It seems nice to avoid round-tripping through the `a` value if the main thing we need is its ordinal.

This should speed up both the generated code and the compiler: the former because it avoids running useless compositions of `ordinal . unsafe_from_ordinal`, and the latter because it avoids inlining and repeatedly compiling them. It should also simplify the compiler by better controlling how far user-defined `Ix` instances can travel.

What

In detail, Dex `for` loops are implemented thus: when the compiler sees `for idx:a. body` (where `body` has type `b`), it eventually generates imperative code that basically looks like this (with some abuse of notation):

dest : Ref (a => b) = <allocate or reuse a destination>
for_ i:(Fin (size a)).
  idx : a = unsafe_from_ordinal i  -- Except the instance is resolved and the method is inlined
  ithdest : Ref b = indexDest dest (ordinal idx)
  place ithdest <compile body>

Also, when the compiler sees array indexing (i.e., `xs[idx]`), it generates

i' = ordinal idx
indexValue xs i'

Seems like maybe we should just remember that `i` is in scope and, instead of constructing `ordinal idx`, insert a reference to `i`.
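The effect of the proposed rewrite can be modelled outside the compiler. The following is a minimal Python sketch, not Dex and not the actual lowering: `Ix`, `pair_ix`, `lower_naive`, and `lower_cancelled` are hypothetical names invented for illustration, and the `calls` counter exists only to make the saved `ordinal` calls visible.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Ix:
    """Hypothetical stand-in for a resolved Dex `Ix a` instance."""
    size: int
    ordinal: Callable[[object], int]
    unsafe_from_ordinal: Callable[[int], object]

# Counts how often the user-defined `ordinal` method runs.
calls = {"ordinal": 0}

def pair_ordinal(p):
    calls["ordinal"] += 1
    return p[0] * 3 + p[1]

# Models the index set (Fin 2, Fin 3), flattened row-major.
pair_ix = Ix(size=6,
             ordinal=pair_ordinal,
             unsafe_from_ordinal=lambda i: divmod(i, 3))

def lower_naive(ix, xs, f):
    """Models `for idx:a. f xs[idx]` as currently lowered."""
    dest = [None] * ix.size
    for i in range(ix.size):
        idx = ix.unsafe_from_ordinal(i)
        x = xs[ix.ordinal(idx)]          # indexValue xs (ordinal idx)
        dest[ix.ordinal(idx)] = f(x)     # indexDest dest (ordinal idx)
    return dest

def lower_cancelled(ix, xs, f):
    """Same loop, with every `ordinal idx` replaced by the in-scope `i`."""
    dest = [None] * ix.size
    for i in range(ix.size):
        idx = ix.unsafe_from_ordinal(i)  # now unused; a candidate for DCE
        x = xs[i]
        dest[i] = f(x)
    return dest
```

For a lawful instance both functions produce the same table, but the second never calls `ordinal`, and `idx` becomes dead code in a pure map.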

Now, we do still need to generate `idx`, because the body may use it for something other than indexing as well (e.g., unpacking a tuple-typed index), but in the case of a pure map, dead code elimination should remove it.

(For what it's worth, something similar happens with table literals, except that we know the body doesn't read the index at all.)

Why not

The main problem I can imagine with doing this is that it assumes that `ordinal` really is a left inverse of `unsafe_from_ordinal` in all `Ix` instances, which is not something we have machinery to check. The optimization is also observable if `ordinal` has side effects (e.g., via `unsafe_io`).
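To make the assumed law concrete, here is a small Python sketch of an exhaustive runtime check that `ordinal (unsafe_from_ordinal i) == i` for every `i` below the size. It is only an illustration: the helper name `check_ordinal_roundtrip` is hypothetical, it is not the static machinery the compiler would need, and it cannot detect side effects in `ordinal`.

```python
def check_ordinal_roundtrip(size, ordinal, unsafe_from_ordinal):
    """Return every ordinal i in [0, size) that violates
    ordinal(unsafe_from_ordinal(i)) == i."""
    return [i for i in range(size)
            if ordinal(unsafe_from_ordinal(i)) != i]

# A lawful instance for pairs modelling (Fin 2, Fin 3), row-major.
good = check_ordinal_roundtrip(
    6,
    ordinal=lambda p: p[0] * 3 + p[1],
    unsafe_from_ordinal=lambda i: divmod(i, 3),
)

# An unlawful instance: `ordinal` uses column-major order while
# `unsafe_from_ordinal` still uses row-major order.
bad = check_ordinal_roundtrip(
    6,
    ordinal=lambda p: p[1] * 2 + p[0],
    unsafe_from_ordinal=lambda i: divmod(i, 3),
)
```

With the unlawful instance, cancelling the composition would change which element each loop iteration touches, which is exactly the observable difference described above.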

How

Probably the simplest way to do it would be to write a new pass that changes the index types of all arrays to `Fin <something>`.
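As a rough illustration of what such a pass might do at the type level, the Python sketch below rewrites table types so that every index set becomes a `Fin` of the same size. The tiny type representation (`Fin`, `PairIx`, `Tab`, `Float`) and the function `erase_index_sets` are hypothetical and assume statically known sizes; they do not reflect the Dex compiler's actual IR.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Fin:
    n: int                 # index set {0, ..., n-1}

@dataclass(frozen=True)
class PairIx:
    left: object           # index set of the first component
    right: object          # index set of the second component

@dataclass(frozen=True)
class Float:
    pass                   # a scalar element type

@dataclass(frozen=True)
class Tab:
    index: object          # index set, i.e. the `a` in `a => b`
    elem: object           # element type, i.e. the `b` in `a => b`

def size(ix):
    """Number of inhabitants of an index set (static sizes only)."""
    if isinstance(ix, Fin):
        return ix.n
    if isinstance(ix, PairIx):
        return size(ix.left) * size(ix.right)
    raise TypeError(f"not an index set: {ix!r}")

def erase_index_sets(ty):
    """Rewrite `a => b` to `Fin (size a) => b'`, recursing into b."""
    if isinstance(ty, Tab):
        return Tab(Fin(size(ty.index)), erase_index_sets(ty.elem))
    return ty

before = Tab(PairIx(Fin(2), Fin(3)), Tab(Fin(4), Float()))
after = erase_index_sets(before)
```

A real pass would also have to rewrite the terms, replacing each `ordinal idx` at an indexing site with the integer loop variable; this sketch covers only the types.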

We would presumably only want to run this pass after we are done using semantic index information for occurrence analysis and any future loop splitting or such; and it has to run before (or within) `Imp`, because the `Imp` representation no longer has interesting index sets.

Thoughts?

duvenaud commented 1 year ago

One idea for making the bijection restriction clearer to the user would be to split the `Ix a` interface into two parts: a `size` method, and a `Bijection` interface that has an instance for `Bijection a (Fin (size a))`.

Then maybe the compiler could rely on user-defined bijections to do this kind of optimization more generally. However, I guess this would require bringing back the interpreter, and also making types dependent enough for `size a` to appear in all the types.

On second thought, my proposal mostly just reinvents the existing `Ix` class with more complicated types.