gelisam / klister

an implementation of stuck macros
BSD 3-Clause "New" or "Revised" License

compile to recursion-schemes #162

Open: gelisam opened this issue 1 year ago

gelisam commented 1 year ago

As the maintainer of recursion-schemes, I believe that the main reason recursion-schemes aren't used more often is simply that people are more comfortable with the unbounded-recursion style. As a result, they have to spend more time and mental energy reading and writing code in the recursion-schemes style than code in the unbounded-recursion style. This cost has to be paid on every single line of every single function, so of course it doesn't seem worth it when the benefits only appear on some lines of some functions.

What if it was possible to mix the two styles within the same function, so that the less familiar syntax is only used when it brings benefits?

Let's look at a concrete example of what this could look like. Consider the following program written in the unbounded-recursion style.

{-# LANGUAGE ImportQualifiedPost #-}
import Data.Set (Set)
import Data.Set qualified as Set
import Test.DocTest (doctest)

data Term
  = Var String
  | App Term Term
  | Lam String Term
  | Const Term
  | Unit
  | Int Int
  | Add Term Term
  deriving Show

-- |
-- >>> usedVars1 (Lam "x" (Lam "y" Unit))
-- fromList []
-- >>> usedVars1 (Lam "x" (Lam "y" (Var "x")))
-- fromList []
-- >>> usedVars1 (Lam "x" (Lam "y" (Var "z")))
-- fromList ["z"]
-- >>> usedVars1 (App (Var "f") (Var "x"))
-- fromList ["f","x"]
usedVars1
  :: Term
  -> Set String
usedVars1 (Var s)
  = Set.singleton s
usedVars1 (App t1 t2)
  = usedVars1 t1 `Set.union` usedVars1 t2
usedVars1 (Lam s t)
  = Set.delete s (usedVars1 t)
usedVars1 (Const t)
  = usedVars1 t
usedVars1 Unit
  = Set.empty
usedVars1 (Int _)
  = Set.empty
usedVars1 (Add t1 t2)
  = usedVars1 t1 `Set.union` usedVars1 t2

-- |
-- >>> useConstWhenPossible1 (Lam "x" (Lam "y" Unit))
-- Const (Const Unit)
-- >>> useConstWhenPossible1 (Lam "x" (Lam "y" (Var "x")))
-- Lam "x" (Const (Var "x"))
-- >>> useConstWhenPossible1 (Lam "x" (Lam "y" (Var "y")))
-- Const (Lam "y" (Var "y"))
useConstWhenPossible1
  :: Term
  -> Term
useConstWhenPossible1 (Var s)
  = Var s
useConstWhenPossible1 (App t1 t2)
  = App (useConstWhenPossible1 t1)
        (useConstWhenPossible1 t2)
useConstWhenPossible1 (Lam s t)
  | s `Set.member` usedVars1 t
    = Lam s (useConstWhenPossible1 t)
  | otherwise
    = Const (useConstWhenPossible1 t)
useConstWhenPossible1 (Const t)
  = Const (useConstWhenPossible1 t)
useConstWhenPossible1 Unit
  = Unit
useConstWhenPossible1 (Int n)
  = Int n
useConstWhenPossible1 (Add t1 t2)
  = Add (useConstWhenPossible1 t1)
        (useConstWhenPossible1 t2)

The code is clear enough, but it has two issues:

  1. there are many repetitive lines which do nothing but make recursive calls and combine the results
  2. the runtime cost is quadratic in the number of nested Lam constructors, because usedVars1 takes linear time and is called once per Lam constructor.
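For comparison, the quadratic cost can also be removed by hand while staying in the unbounded-recursion style, by returning both results from a single traversal. The sketch below assumes the Int and Add constructors matched by the clauses above, and derives Eq only so that results can be compared; the names usedVarsAndRewrite and useConstWhenPossibleLinear are hypothetical.

```haskell
{-# LANGUAGE ImportQualifiedPost #-}
import Data.Set (Set)
import Data.Set qualified as Set

data Term
  = Var String
  | App Term Term
  | Lam String Term
  | Const Term
  | Unit
  | Int Int
  | Add Term Term
  deriving (Eq, Show)

-- Single traversal returning the variables used by the *original* term
-- together with the rewritten term, so that each Lam can consult the
-- variables used by its body without a separate linear-time call.
usedVarsAndRewrite
  :: Term
  -> (Set String, Term)
usedVarsAndRewrite (Var s)
  = (Set.singleton s, Var s)
usedVarsAndRewrite (App t1 t2)
  = let (vs1, t1') = usedVarsAndRewrite t1
        (vs2, t2') = usedVarsAndRewrite t2
    in (vs1 `Set.union` vs2, App t1' t2')
usedVarsAndRewrite (Lam s t)
  = let (vs, t') = usedVarsAndRewrite t
        rewritten
          | s `Set.member` vs = Lam s t'
          | otherwise         = Const t'
    in (Set.delete s vs, rewritten)
usedVarsAndRewrite (Const t)
  = let (vs, t') = usedVarsAndRewrite t
    in (vs, Const t')
usedVarsAndRewrite Unit
  = (Set.empty, Unit)
usedVarsAndRewrite (Int n)
  = (Set.empty, Int n)
usedVarsAndRewrite (Add t1 t2)
  = let (vs1, t1') = usedVarsAndRewrite t1
        (vs2, t2') = usedVarsAndRewrite t2
    in (vs1 `Set.union` vs2, Add t1' t2')

-- Keep only the rewritten term.
useConstWhenPossibleLinear
  :: Term
  -> Term
useConstWhenPossibleLinear
  = snd . usedVarsAndRewrite
```

This is the classic tupling transformation. It fixes the second issue but makes the first one worse, since every clause now has to thread a pair by hand; that threading is what zygo does generically.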

The following recursion-schemes version fixes both problems:

-- These pragmas and imports go at the top of the same module as the code above.
{-# LANGUAGE DeriveTraversable, TemplateHaskell, TypeFamilies #-}
import Data.Foldable (toList)
import Data.Functor.Foldable (cata, embed, zygo)
import Data.Functor.Foldable.TH (makeBaseFunctor)

makeBaseFunctor ''Term

-- |
-- >>> usedVars2 (Lam "x" (Lam "y" Unit))
-- fromList []
-- >>> usedVars2 (Lam "x" (Lam "y" (Var "x")))
-- fromList []
-- >>> usedVars2 (Lam "x" (Lam "y" (Var "z")))
-- fromList ["z"]
-- >>> usedVars2 (App (Var "f") (Var "x"))
-- fromList ["f","x"]
usedVars2
  :: Term
  -> Set String
usedVars2
  = cata usedVarsF

usedVarsF
  :: TermF (Set String)
  -> Set String
usedVarsF (VarF s)
  = Set.singleton s
usedVarsF (LamF s varsUsedInBody)
  = Set.delete s varsUsedInBody
usedVarsF varsUsedInSubTerms
  = Set.unions (toList varsUsedInSubTerms)

-- |
-- >>> useConstWhenPossible2 (Lam "x" (Lam "y" Unit))
-- Const (Const Unit)
-- >>> useConstWhenPossible2 (Lam "x" (Lam "y" (Var "x")))
-- Lam "x" (Const (Var "x"))
-- >>> useConstWhenPossible2 (Lam "x" (Lam "y" (Var "y")))
-- Const (Lam "y" (Var "y"))
useConstWhenPossible2
  :: Term
  -> Term
useConstWhenPossible2
  = zygo usedVarsF useConstWhenPossibleF

useConstWhenPossibleF
  :: TermF (Set String, Term)
  -> Term
useConstWhenPossibleF (LamF s (varsUsedInBody, rewrittenBody))
  | s `Set.member` varsUsedInBody
    = Lam s rewrittenBody
  | otherwise
    = Const rewrittenBody
useConstWhenPossibleF pairs
  = embed (fmap snd pairs)

recursion-schemes fixes the two issues as follows:

  1. The generated TermF datatype comes with generated Functor, Foldable, and Traversable instances, which manipulate the results of the recursive calls in a uniform way. This makes it possible to combine all the boring, repetitive lines into a single, more abstract line.
  2. zygo computes both usedVars and useConstWhenPossible at the same time, and makes the partial results of usedVars, that is, the variables used by each sub-term, available to useConstWhenPossibleF. Thus, there is no need to make a linear-time call to obtain that information; it is already available for each sub-term.
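To see what the library does for us in these two points, here is a minimal hand-rolled sketch that avoids Template Haskell and the recursion-schemes dependency: a hand-written TermF, project and embed, and simplified cata and zygo specialised to Term. These are illustrative stand-ins, not the library's definitions (recursion-schemes defines them generically through the Recursive and Corecursive classes and the Base type family), and the sketch assumes Term has the Int and Add constructors used above.

```haskell
{-# LANGUAGE DeriveTraversable, ImportQualifiedPost #-}
import Data.Foldable (toList)
import Data.Set (Set)
import Data.Set qualified as Set

data Term
  = Var String
  | App Term Term
  | Lam String Term
  | Const Term
  | Unit
  | Int Int
  | Add Term Term
  deriving (Eq, Show)

-- Hand-written stand-in for the TermF that makeBaseFunctor generates:
-- every recursive position becomes the type parameter r.
data TermF r
  = VarF String
  | AppF r r
  | LamF String r
  | ConstF r
  | UnitF
  | IntF Int
  | AddF r r
  deriving (Show, Functor, Foldable, Traversable)

-- Expose one layer of a Term.
project :: Term -> TermF Term
project (Var s)     = VarF s
project (App t1 t2) = AppF t1 t2
project (Lam s t)   = LamF s t
project (Const t)   = ConstF t
project Unit        = UnitF
project (Int n)     = IntF n
project (Add t1 t2) = AddF t1 t2

-- Inverse of project: rebuild a Term from one layer.
embed :: TermF Term -> Term
embed (VarF s)     = Var s
embed (AppF t1 t2) = App t1 t2
embed (LamF s t)   = Lam s t
embed (ConstF t)   = Const t
embed UnitF        = Unit
embed (IntF n)     = Int n
embed (AddF t1 t2) = Add t1 t2

-- Bottom-up fold: the algebra sees the results of the recursive calls.
cata :: (TermF a -> a) -> Term -> a
cata alg = alg . fmap (cata alg) . project

-- Runs a helper algebra alongside the main one, so that the main algebra
-- sees, for every sub-term, the pair (helper result, main result).
zygo :: (TermF b -> b) -> (TermF (b, a) -> a) -> Term -> a
zygo helper alg = snd . go
  where
    go t = let pairs = fmap go (project t)
           in (helper (fmap fst pairs), alg pairs)

usedVarsF :: TermF (Set String) -> Set String
usedVarsF (VarF s) = Set.singleton s
usedVarsF (LamF s varsUsedInBody) = Set.delete s varsUsedInBody
usedVarsF varsUsedInSubTerms = Set.unions (toList varsUsedInSubTerms)

useConstWhenPossibleF :: TermF (Set String, Term) -> Term
useConstWhenPossibleF (LamF s (varsUsedInBody, rewrittenBody))
  | s `Set.member` varsUsedInBody = Lam s rewrittenBody
  | otherwise                     = Const rewrittenBody
useConstWhenPossibleF pairs = embed (fmap snd pairs)

usedVars :: Term -> Set String
usedVars = cata usedVarsF

-- A single traversal: each sub-term's used variables are computed once.
useConstWhenPossible :: Term -> Term
useConstWhenPossible = zygo usedVarsF useConstWhenPossibleF
```

Each go call returns the pair for its sub-term, so the variables used by each sub-term are computed once and shared, rather than recomputed by a separate usedVars call at every Lam.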

However, many parts of the recursion-schemes-style code look needlessly complicated, especially to those who are unfamiliar with the details:

  1. the zygo, cata, and embed calls
  2. the TermF (Set String, Term) -> Term and TermF (Set String) -> Set String types

I thus propose a #lang which achieves something like this (but using a parenthesized syntax of course):

usedVars3
  :: Term
  -> Set String
usedVars3 (Var s)
  = Set.singleton s
usedVars3 (LamF s varsUsedInBody)
  = Set.delete s varsUsedInBody
usedVars3 varsUsedInSubTerms
  = Set.unions (toList varsUsedInSubTerms)

useConstWhenPossible3
  :: Term
  -> Term
useConstWhenPossible3 (Lam s body)
  | s `Set.member` usedVars3 body
    = Lam s (useConstWhenPossible3 body)
  | otherwise
    = Const (useConstWhenPossible3 body)
useConstWhenPossible3 rewrittenSubTerms
  = rewrittenSubTerms

Note how we sometimes pattern-match on Lam, and other times on LamF. The former indicates that we want to use the usual unbounded-recursion syntax, with explicit recursive calls, while the latter indicates that we want the recursive positions to already contain the results of the recursive calls. The combined-but-more-abstract clauses, which don't match on a constructor, are assumed to expect a TermF r, not a Term, because this form is more commonly used in the recursion-schemes style.

When a clause uses the Lam form, the #lang implementation will look for recursive calls and rewrite them away, by converting the clause, right-hand side included, from the Lam form to the LamF form. If the recursive calls are in an unusual position in the code, the rewrite fails, just like a termination checker that doesn't recognize that the recursive calls are made on smaller inputs. As a result of the rewrite, usedVarsF and useConstWhenPossibleF are generated in addition to usedVars3 and useConstWhenPossible3.
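The rewrite described above can be sketched on a toy representation of right-hand sides. Everything here is hypothetical: the Expr type, the toAlgebraRhs name, and the choice to return Nothing when the rewrite fails. A real #lang would operate on Klister syntax and would also rewrite the pattern itself (Lam to LamF), which this sketch leaves out.

```haskell
-- A toy model of a clause's right-hand side (hypothetical; a real #lang
-- would work on Klister syntax objects instead).
data Expr
  = EVar String         -- a variable reference
  | EApp String [Expr]  -- a named function or constructor applied to arguments
  deriving (Eq, Show)

-- Converts the right-hand side of a Lam-form clause of the function `self`
-- into its LamF form. `recVars` are the pattern variables bound to
-- recursive positions, e.g. `body` in `Lam s body`. In the LamF form those
-- variables hold the results of the recursive calls, so:
--   * `self v` becomes `v`;
--   * `self` applied to anything else makes the rewrite fail, like a
--     termination checker rejecting a call on an input it can't see shrink;
--   * any other use of `v` also fails, since the original sub-term is no
--     longer available.
toAlgebraRhs :: String -> [String] -> Expr -> Maybe Expr
toAlgebraRhs self recVars = go
  where
    go (EApp f [EVar v])
      | f == self, v `elem` recVars = Just (EVar v)
    go (EApp f args)
      | f == self = Nothing
      | otherwise = EApp f <$> traverse go args
    go (EVar v)
      | v `elem` recVars = Nothing
      | otherwise        = Just (EVar v)
```

For instance, the right-hand side of usedVars1 (Lam s t), which is Set.delete s (usedVars1 t), becomes Set.delete s t, the LamF clause of usedVarsF. By contrast, the Lam clause of useConstWhenPossible3 is rejected, because usedVars3 body uses body for something other than a recursive call; handling that call is the goal of the zygo rewrite discussed next.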

More ambitiously, the usedVars3 call which leads to quadratic performance should also be rewritten away. To do this, the #lang must remember that usedVars3 has a usedVarsF form which can be used in a zygo. The #lang would then silently rewrite all the other clauses, which don't use usedVars3, to expect a TermF r rather than a TermF (Set String, r), thus relieving the user from having to explicitly pattern-match on (_, r) or to use fmap snd in the combined-but-more-abstract clause.
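The zygo rewrite can be sketched on a toy representation of right-hand sides in which each recursive position holds a pair (helper result, main result). The Expr type and the toZygoRhs name are hypothetical, as is the encoding of the two projections as calls to fst and snd. A call to the function being defined on a recursive-position variable becomes snd of that variable, and a call to the helper whose algebra form is already known (usedVars3 here) becomes fst.

```haskell
-- A toy model of a clause's right-hand side (hypothetical).
data Expr
  = EVar String         -- a variable reference
  | EApp String [Expr]  -- a named function or constructor applied to arguments
  deriving (Eq, Show)

-- Converts the right-hand side of a Lam-form clause of `self` into the
-- body of the algebra that zygo expects. `helper` is a function whose
-- algebra form is already known; `recVars` are the pattern variables bound
-- to recursive positions, each now holding (helper result, self result).
toZygoRhs :: String -> String -> [String] -> Expr -> Maybe Expr
toZygoRhs self helper recVars = go
  where
    go (EApp f [EVar v])
      | v `elem` recVars, f == self   = Just (EApp "snd" [EVar v])
      | v `elem` recVars, f == helper = Just (EApp "fst" [EVar v])
    go (EApp f args)
      | f == self = Nothing            -- recursive call on a non-sub-term
      | otherwise = EApp f <$> traverse go args
    go (EVar v)
      | v `elem` recVars = Nothing     -- the original sub-term is gone
      | otherwise        = Just (EVar v)
```

Applied to the Lam clause of useConstWhenPossible3 (with its guard modelled as an "if" application), this turns usedVars3 body into fst body and each useConstWhenPossible3 body into snd body, matching the LamF clause of useConstWhenPossibleF. The sketch only covers clauses that match on a constructor; for the combined-but-more-abstract clause, the #lang would instead apply fmap snd to its input.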

Finally, something like #156 could be used to automatically insert missing calls to embed. In particular, the useConstWhenPossible3 rewrittenSubTerms = rewrittenSubTerms clause is producing a TermF Term instead of a Term, and is thus missing a call to embed, but in the recursion-schemes style, embed is such a common and information-less transformation that it is basically line noise, and so it seems clearer to allow it to be omitted than to require it to be written explicitly.

gelisam commented 1 year ago

@david-christiansen mentioned the following relevant literature:

  1. "A Cosmology of Datatypes, Reusability and Dependent Types", by Pierre-Evariste Dagand
  2. "The Gentle Art of Levitation", by James Chapman, Pierre-Evariste Dagand, Conor McBride, and Peter Morris

We later determined that those resources are more relevant to #164 than to recursion-schemes. Or are they? We were discussing recursion-schemes when those two resources came up, so the two topics must be linked somehow; I just don't see how.

david-christiansen commented 1 year ago

I think that levitation is not relevant to Klister. Dagand's thesis is relevant in the sense that it treats the elaboration of pattern matching to encodings of datatypes, which seems to be what you're proposing to do here, and there may be some notations or technical devices to take from it.

gelisam commented 1 year ago

Right! And the other reference you mentioned, which compiles pattern-matching on multiple arguments into nested pattern-matches on a single argument, is Wadler's section on case-tree compilation, chapter 5 of SPJ's book "The Implementation of Functional Programming Languages". The key, I believe, is to keep both a success continuation and a failure continuation, so that we know what to do both when a pattern matches and when it doesn't: nested patterns and multiple arguments make this more complicated than just trying the next pattern in a list.
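The success/failure-continuation idea can be sketched as a small pattern-matching interpreter. It interprets patterns at run time instead of compiling them into a case tree as Wadler's chapter does, and all names (Value, Pat, matchOne, matchMany, matchClauses) are hypothetical, but it shows the control structure: each sub-match's success continuation performs the next sub-match, all of them share one failure continuation, and a clause's failure continuation is "try the remaining clauses".

```haskell
-- Runtime values: constructor applications only, e.g. Cons One Nil.
data Value = VCon String [Value]
  deriving (Eq, Show)

data Pat
  = PVar String        -- matches anything and binds it
  | PWild              -- matches anything, binds nothing
  | PCon String [Pat]  -- matches a constructor, then its fields
  deriving (Eq, Show)

type Env = [(String, Value)]

-- Match one pattern. On success, call `success` with the extended
-- bindings; on failure, return `failure`.
matchOne :: Pat -> Value -> Env -> (Env -> r) -> r -> r
matchOne (PVar x) v env success _ = success ((x, v) : env)
matchOne PWild _ env success _ = success env
matchOne (PCon c ps) (VCon c' vs) env success failure
  | c == c'   = matchMany ps vs env success failure
  | otherwise = failure

-- Match several patterns left to right (nested fields or multiple
-- arguments). The success continuation of each match performs the
-- next one; all of them share the same failure continuation.
matchMany :: [Pat] -> [Value] -> Env -> (Env -> r) -> r -> r
matchMany [] [] env success _ = success env
matchMany (p : ps) (v : vs) env success failure =
  matchOne p v env (\env' -> matchMany ps vs env' success failure) failure
matchMany _ _ _ _ failure = failure  -- arity mismatch

-- Try clauses in order: the failure continuation of one clause is
-- "try the remaining clauses", and the last one falls back to noMatch.
matchClauses :: [([Pat], Env -> r)] -> [Value] -> r -> r
matchClauses [] _ noMatch = noMatch
matchClauses ((pats, rhs) : rest) vals noMatch =
  matchMany pats vals [] rhs (matchClauses rest vals noMatch)
```

Thanks to laziness, the failure argument is only evaluated when a match actually fails, so passing "the remaining clauses" as a plain value is cheap.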

gelisam commented 1 year ago

We discussed the following implementation plan.

Define a my-datatype macro which defines both a datatype Foo and its base functor FooT, and uses #168 to associate the two. Similarly, associate each constructor MkFoo of Foo with the corresponding MkFooT constructor of FooT.

Define a my-defun macro which checks for a few well-known recursion patterns and rewrites the clauses, using the association to convert between Foo and FooT as appropriate.
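The my-datatype step of the plan above can be sketched at the level of datatype descriptions. The DataDecl and Field types and the function names are hypothetical, and field types are assumed to be single type names; a real my-datatype macro would operate on Klister syntax and use #168 to record the association, which this sketch simply returns as a list of pairs.

```haskell
import Data.List (intercalate)

-- Toy description of a datatype declaration (hypothetical stand-in for
-- the syntax a my-datatype macro would receive).
data Field
  = SelfRef        -- a field whose type is the datatype being defined
  | Other String   -- any other field type, written as a single type name
  deriving (Eq, Show)

data DataDecl = DataDecl String [(String, [Field])]
  deriving (Eq, Show)

-- The base functor: same constructors with a "T" suffix (following the
-- Foo/FooT and MkFoo/MkFooT naming in the plan), with every
-- self-referential field replaced by the type parameter r.
baseFunctor :: DataDecl -> (String, [(String, [String])])
baseFunctor (DataDecl name ctors) =
  (name ++ "T", [ (ctor ++ "T", map fieldType fields) | (ctor, fields) <- ctors ])
  where
    fieldType SelfRef   = "r"
    fieldType (Other t) = t

-- Association between each constructor and its base-functor counterpart.
ctorAssociation :: DataDecl -> [(String, String)]
ctorAssociation (DataDecl _ ctors) = [ (ctor, ctor ++ "T") | (ctor, _) <- ctors ]

-- Render the base functor as Haskell-like source text.
renderBaseFunctor :: DataDecl -> String
renderBaseFunctor decl =
  "data " ++ baseName ++ " r = " ++ intercalate " | " (map renderCtor ctors)
  where
    (baseName, ctors) = baseFunctor decl
    renderCtor (ctor, fieldTypes) = unwords (ctor : fieldTypes)
```

A my-defun macro would then consult this association to translate a Lam pattern into the corresponding LamT pattern (or back) when rewriting clauses.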