microsoft / knossos-ksc

Compiler with automatic differentiation

Lift "always-evaluated" calls out of if #1035

Open acl33 opened 2 years ago

This is a follow-up after #972.

Consider a program such as:

(if fred 
  (let (x (index i arr)) e[x]) 
  (let (y (index i arr)) e2[y]))
;==lift-let-ideal==>
(let (x (index i arr)) 
  (if fred 
    e[x] 
    (let (y (index i arr)) e2[y])))
;==lift-let-ideal==cse-bind==>
(let (x (index i arr)) 
  (if fred 
    e[x]
    e2[x]))

The "lift-let-ideal" rewrite is so named because it breaks the usual contract for lift-let: it lifts a potentially-exception-throwing index out of the if. However, we know this is safe (it cannot cause the program to throw new exceptions), because (index i arr) was always going to be evaluated, whichever arm is taken.
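To make "always going to be evaluated" concrete, here is a minimal sketch in Python. The nested-tuple encoding and the names free_vars and always_evaluated are assumptions made for illustration, not part of ksc. It collects the calls that are evaluated on every path through an expression, which are the candidates for "lift-let-ideal"; it does not consider the order in which different exceptions might be raised.

```python
# Expressions are nested tuples; variables and function names are strings.
#   ("if", cond, then_arm, else_arm)
#   ("let", (name, rhs), body)
#   (f, arg1, ...)    any other head is treated as an ordinary strict call

def free_vars(expr):
    """Variables occurring free in expr (function names in head position are skipped)."""
    if isinstance(expr, str):
        return {expr}
    if expr[0] == "let":
        _, (name, rhs), body = expr
        return free_vars(rhs) | (free_vars(body) - {name})
    return set().union(*(free_vars(arg) for arg in expr[1:]))

def always_evaluated(expr):
    """Calls that are evaluated on every path whenever expr itself is evaluated."""
    if isinstance(expr, str):
        return set()
    if expr[0] == "if":
        _, cond, then_arm, else_arm = expr
        # Only calls common to both arms are certain to run.
        return always_evaluated(cond) | (
            always_evaluated(then_arm) & always_evaluated(else_arm))
    if expr[0] == "let":
        _, (name, rhs), body = expr
        # Calls that mention the bound name cannot be lifted above this let.
        from_body = {c for c in always_evaluated(body) if name not in free_vars(c)}
        return always_evaluated(rhs) | from_body
    result = {expr}
    for arg in expr[1:]:
        result |= always_evaluated(arg)
    return result

# The program from the start of this issue, with e[x] encoded as ("e", "x").
PROGRAM = ("if", "fred",
           ("let", ("x", ("index", "i", "arr")), ("e", "x")),
           ("let", ("y", ("index", "i", "arr")), ("e2", "y")))

# A variant where only one arm performs the index.
ONE_ARM_ONLY = ("if", "fred", ("index", "i", "arr"), "z")
```

For PROGRAM the only call reported is (index i arr); for ONE_ARM_ONLY nothing is reported, so lifting the index there would not be safe.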

Rephrasing with proofs, the first program above is:

(if_with_proof fred
    (lam (fred#t: Proof)
         (let (x (if_with_proof (inRange i arr)
                                (lam (eye_ok1#t: Proof) (index_with_proof eye_ok1#t i arr))
                                (lam (eye_ok1#f: Proof) (raise OutOfBounds))))
              e[x]))
    (lam (fred#f: Proof)
         (let (y (if_with_proof (inRange i arr)
                                (lam (eye_ok2#t: Proof) (index_with_proof eye_ok2#t i arr))
                                (lam (eye_ok2#f: Proof) (raise OutOfBounds))))
              e2[y])))

The two lets (x and y) can each be rewritten with a straightforward "lift_if_from_let_rhs" (which duplicates e[x] and e2[y] respectively), followed by the simplification (let (v (raise OutOfBounds)) ....) ==> (raise OutOfBounds), giving:

(if_with_proof fred
    (lam (fred#t: Proof)
         (if_with_proof (inRange i arr)
                        (lam (eye_ok1#t: Proof) (let (x (index_with_proof eye_ok1#t i arr)) e[x]))
                        (lam (eye_ok1#f: Proof) (raise OutOfBounds))))
    (lam (fred#f: Proof)
         (if_with_proof (inRange i arr)
                        (lam (eye_ok2#t: Proof) (let (y (index_with_proof eye_ok2#t i arr)) e2[y]))
                        (lam (eye_ok2#f: Proof) (raise OutOfBounds)))))
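The two steps just used ("lift_if_from_let_rhs", then dropping a let whose right-hand side is a raise) can be sketched in Python as follows. The tuple encoding and the function names are assumptions for illustration only; in particular, the sketch does not re-establish unique binders in the duplicated body.

```python
# Encoding (hypothetical):
#   ("if_with_proof", cond, ("lam", proof_t, then_body), ("lam", proof_f, else_body))
#   ("let", (name, rhs), body)
#   ("raise", exception_name)

def is_node(expr, head):
    return isinstance(expr, tuple) and expr[0] == head

def lift_if_from_let_rhs(expr):
    """(let (x (if_with_proof c (lam pt a) (lam pf b))) body)
       ==> (if_with_proof c (lam pt (let (x a) body)) (lam pf (let (x b) body)))
    Returns None when the pattern does not match. Note that body is duplicated."""
    if not is_node(expr, "let"):
        return None
    _, (x, rhs), body = expr
    if not is_node(rhs, "if_with_proof"):
        return None
    _, cond, (_, pt, a), (_, pf, b) = rhs
    return ("if_with_proof", cond,
            ("lam", pt, ("let", (x, a), body)),
            ("lam", pf, ("let", (x, b), body)))

def simplify_let_of_raise(expr):
    """Bottom-up: (let (v (raise E)) body) ==> (raise E)."""
    if isinstance(expr, str):
        return expr
    expr = tuple(simplify_let_of_raise(e) for e in expr)
    if is_node(expr, "let") and is_node(expr[1][1], "raise"):
        return expr[1][1]
    return expr

IN_RANGE = ("inRange", "i", "arr")
RAISE = ("raise", "OutOfBounds")
INDEX = ("index_with_proof", "eye_ok1#t", "i", "arr")

# The let for x from the proof-carrying program above.
LET_X = ("let",
         ("x", ("if_with_proof", IN_RANGE,
                ("lam", "eye_ok1#t", INDEX),
                ("lam", "eye_ok1#f", RAISE))),
         ("e", "x"))

RESULT = simplify_let_of_raise(lift_if_from_let_rhs(LET_X))
```

RESULT has the if_with_proof on the outside, the let for x only in the in-range arm, and a bare raise in the other arm, matching the shape shown above for the fred#t branch.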

To combine the common if-conditions (inRange i arr) we now need a rule something like:

(rule "if_interchange"
    (if_with_proof p   ; pattern - note the two "q"s must be identical
        (lam (p#t: Proof) (if_with_proof q (lam (q#t: Proof) e1[p#t, q#t]) (lam (q#f: Proof) e2[p#t,q#f])))
        (lam (p#f: Proof) (if_with_proof q (lam (q2#t: Proof) e3[p#f, q2#t]) (lam (q2#f: Proof) e4[p#f, q2#f]))))
    (if_with_proof q   ; replacement - now only one "q", but two copies of "p"
        (lam (q#t: Proof) (if_with_proof p (lam (p#t: Proof) e1[p#t, q#t]) (lam (p#f: Proof) e3[p#f, q#t])))
        (lam (q#f: Proof) (if_with_proof p (lam (p2#t: Proof) e2[p2#t, q#f]) (lam (p2#f: Proof) e4[p2#f, q#f])))))

Note the renames of proof variables free in e2/e3/e4 to maintain All Binders Unique (#807); the usual side-conditions on variables not escaping their binders apply.
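As a sanity check on the renames, here is a hedged Python sketch of "if_interchange" over a nested-tuple encoding. The encoding, the helper names, and the way the second copies of p's proof variables are named (second_copy) are all assumptions for illustration; a real implementation would draw fresh names from whatever mechanism enforces All Binders Unique.

```python
# Encoding (hypothetical):
#   ("if_with_proof", cond, ("lam", proof_t, then_body), ("lam", proof_f, else_body))

def is_node(expr, head):
    return isinstance(expr, tuple) and expr[0] == head

def mentions(expr, name):
    if isinstance(expr, str):
        return expr == name
    return any(mentions(e, name) for e in expr)

def rename(expr, old, new):
    """Rename variable old to new everywhere (binders are unique, so no capture)."""
    if isinstance(expr, str):
        return new if expr == old else expr
    return tuple(rename(e, old, new) for e in expr)

def second_copy(proof_var):
    """fred#t -> fred2#t. Assumed not to clash with any existing binder."""
    base, _, suffix = proof_var.partition("#")
    return f"{base}2#{suffix}"

def if_interchange(expr):
    """Swap an outer test p with an inner test q that appears in both arms."""
    if not is_node(expr, "if_with_proof"):
        return None
    _, p, (_, pt, arm_t), (_, pf, arm_f) = expr
    if not (is_node(arm_t, "if_with_proof") and is_node(arm_f, "if_with_proof")):
        return None
    _, q, (_, qt, e1), (_, qf, e2) = arm_t
    _, q_again, (_, q2t, e3), (_, q2f, e4) = arm_f
    # Side conditions: the two q's are identical and q does not use p's proofs.
    if q != q_again or mentions(q, pt) or mentions(q, pf):
        return None
    p2t, p2f = second_copy(pt), second_copy(pf)
    return ("if_with_proof", q,
            ("lam", qt, ("if_with_proof", p,
                         ("lam", pt, e1),
                         ("lam", pf, rename(e3, q2t, qt)))),
            ("lam", qf, ("if_with_proof", p,
                         ("lam", p2t, rename(e2, pt, p2t)),
                         ("lam", p2f, rename(rename(e4, pf, p2f), q2f, qf)))))

IN_RANGE = ("inRange", "i", "arr")
RAISE = ("raise", "OutOfBounds")
LET_X = ("let", ("x", ("index_with_proof", "eye_ok1#t", "i", "arr")), ("e", "x"))
LET_Y = ("let", ("y", ("index_with_proof", "eye_ok2#t", "i", "arr")), ("e2", "y"))

BEFORE = ("if_with_proof", "fred",
          ("lam", "fred#t", ("if_with_proof", IN_RANGE,
                             ("lam", "eye_ok1#t", LET_X), ("lam", "eye_ok1#f", RAISE))),
          ("lam", "fred#f", ("if_with_proof", IN_RANGE,
                             ("lam", "eye_ok2#t", LET_Y), ("lam", "eye_ok2#f", RAISE))))

AFTER = if_interchange(BEFORE)
```

In AFTER the outer test is (inRange i arr), both index_with_proof calls use the same proof variable eye_ok1#t, and the out-of-range arm is an if on fred whose two arms both raise. The concrete proof-variable names differ from those in the hand-written result, since the naming scheme here is only a placeholder.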

The "if_interchange" gets us to:

(if_with_proof (inRange i arr)
    (lam (eye_ok#t: Proof)
         (if_with_proof fred
                        (lam (fred1#t: Proof) (let (x (index_with_proof eye_ok#t i arr)) e[x]))
                        (lam (fred1#f: Proof) (let (y (index_with_proof eye_ok#t i arr)) e2[y]))))
    (lam (eye_ok#f: Proof)
         (if_with_proof fred
                        (lam (fred2#t: Proof) (raise OutOfBounds))
                        (lam (fred2#f: Proof) (raise OutOfBounds)))))

Finally, to common up the two identical index_with_proof calls we need another rule. It parallels "if_interchange", which could equally have been called "lift_if_from_both_arms_of_if":

(rule "lift_let_from_both_arms_of_if"
    (if_with_proof p
                   (lam (p#t: Proof) (let (x common) e[x]))
                   (lam (p#f: Proof) (let (y common) e2[y])))
    (let (x common) (if_with_proof p
                                   (lam (p#t: Proof) e[x])
                                   (lam (p#f: Proof) (let (y x) e2[y])))))
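A minimal Python sketch of this rule, again over an assumed nested-tuple encoding (the encoding and helper names are illustrative, not ksc's). It adds the side condition that the common right-hand side must not mention either of p's proof variables, since it is moved outside their scope.

```python
# Encoding (hypothetical):
#   ("if_with_proof", cond, ("lam", proof_t, then_body), ("lam", proof_f, else_body))
#   ("let", (name, rhs), body)

def is_node(expr, head):
    return isinstance(expr, tuple) and expr[0] == head

def mentions(expr, name):
    if isinstance(expr, str):
        return expr == name
    return any(mentions(e, name) for e in expr)

def lift_let_from_both_arms_of_if(expr):
    """Hoist a let whose right-hand side is identical in both arms."""
    if not is_node(expr, "if_with_proof"):
        return None
    _, p, (_, pt, arm_t), (_, pf, arm_f) = expr
    if not (is_node(arm_t, "let") and is_node(arm_f, "let")):
        return None
    _, (x, common), body_t = arm_t
    _, (y, rhs_f), body_f = arm_f
    if rhs_f != common:
        return None
    # The lifted expression leaves the scope of p's proofs, so it must not use them.
    if mentions(common, pt) or mentions(common, pf):
        return None
    return ("let", (x, common),
            ("if_with_proof", p,
             ("lam", pt, body_t),
             ("lam", pf, ("let", (y, x), body_f))))

INDEX = ("index_with_proof", "eye_ok#t", "i", "arr")

# The in-range arm of the program obtained after "if_interchange".
BEFORE = ("if_with_proof", "fred",
          ("lam", "fred1#t", ("let", ("x", INDEX), ("e", "x"))),
          ("lam", "fred1#f", ("let", ("y", INDEX), ("e2", "y"))))

AFTER = lift_let_from_both_arms_of_if(BEFORE)
```

AFTER binds x once outside the if on fred and leaves (let (y x) e2[y]) in the false arm, which an ordinary inlining pass could then turn into e2[x].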

Note the parallel between "if_interchange"/"lift_if_from_both_arms_of_if", "lift_let_from_both_arms_of_if", and the conventional "if_both_same":

(rule "if_both_same" (if p e e) e) ; might be called "lift_everything_from_both_arms_of_if"

The "lift_(if/let)_from_both_arms_of_if" are partial versions of "if_both_same".

Proof-sinking aka assert-pushing

Here's another case (2) where we would like to CSE the two index calls:

(let (x (index i arr))
     (add x
          (if fred
              (let (y (index i arr)) e[y])
              e2)))

Rather than the complexities of lifting y up out of if fred (which is not locally safe - it's only safe within the scope of x), another way is to allow x to be sunk/pushed-down:

(let (x (index i arr))
 (add x
     (if fred (known-binding (x (index i arr)) (let (y (index i arr)) e[y]))
            e2)))

What is this known-binding? One can imagine a rewrite:

(known-binding (x e) (let (y e) body)) ==> (known-binding (x e) (let (y x) body))

which is much like the cse-bind rule (or one variant thereof):

(let (x e) (let (y e) body)) ==> (let (x e) (let (y x) body))

However, known-binding differs from let in that it is always safe to delete a known-binding, whereas a let can only be deleted when the bound variable is unused in the body.
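The two properties of known-binding described here (it enables a cse-bind-like rewrite, and it can always be deleted) can be sketched in Python. The tuple encoding and both function names are assumptions for illustration; known-binding is only a proposal in this issue, not an existing ksc construct.

```python
# Encoding (hypothetical):
#   ("known-binding", (name, expr), body)   records that name == expr holds in body
#   ("let", (name, rhs), body)

def is_node(expr, head):
    return isinstance(expr, tuple) and expr[0] == head

def cse_known_binding(expr):
    """(known-binding (x e) (let (y e) body)) ==> (known-binding (x e) (let (y x) body))"""
    if not is_node(expr, "known-binding"):
        return None
    _, (x, e), inner = expr
    if not is_node(inner, "let"):
        return None
    _, (y, rhs), body = inner
    if rhs != e:
        return None
    return ("known-binding", (x, e), ("let", (y, x), body))

def delete_known_binding(expr):
    """(known-binding (x e) body) ==> body. Always safe: it binds nothing new."""
    if not is_node(expr, "known-binding"):
        return None
    return expr[2]

INDEX = ("index", "i", "arr")

# The true arm of "if fred" from case (2), after x has been pushed down.
PUSHED = ("known-binding", ("x", INDEX),
          ("let", ("y", INDEX), ("e", "y")))

REWRITTEN = cse_known_binding(PUSHED)
CLEANED = delete_known_binding(REWRITTEN)
```

CLEANED is (let (y x) e[y]): the second index has gone and the known-binding wrapper has been dropped without needing to check whether x is used.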

Note that (known-binding (x e) body) can be just (assert (eq x e) body) if we are happy with a general rule (assert cond body) ==> body. That is, we can use assert and "assert-pushing", but we will only be able to do the optimization in "release mode", where asserts can be removed. (In "debug mode" we cannot remove asserts.)

If we rephrase case (2) with if_with_proof:

(let (x (if_with_proof (inRange i arr)
                       (lam (i_ok#t: Proof) (index_with_proof i_ok#t i arr))
                       (lam (i_ok#f: Proof) (raise OutOfBounds))))
     (add x (if fred
                (let (y (if_with_proof (inRange i arr)
                                       (lam (i_ok2#t: Proof) (index_with_proof i_ok2#t i arr))
                                       (lam (i_ok2#f: Proof) (raise OutOfBounds))))
                     e[y])
                e2)))

"lift_if_from_let_rhs" on the outer let, and simplifying (let (x (raise OutOfBounds)) (add x ....)) ==> (raise OutOfBounds), gets us to:

(if_with_proof (inRange i arr)
               (lam (i_ok#t: Proof)
                    (let (x (index_with_proof i_ok#t i arr))
                         (add x (if fred
                                    (let (y (if_with_proof (inRange i arr)
                                                           (lam (i_ok2#t: Proof) (index_with_proof i_ok2#t i arr))
                                                           (lam (i_ok2#f: Proof) (raise OutOfBounds))))
                                         e[y])
                                    e2))))
               (lam (i_ok#f: Proof) (raise OutOfBounds)))

We need some way to connect the two if_with_proofs (of the same predicate); while the following rule has been suggested, it only works when the two if_with_proofs are adjacent:

(rule "proof_known_from_if"   ; safe, but not general enough
    (if_with_proof p
                    (lam (p#t: Proof) (if_with_proof p (lam (p2#t: Proof) e) (lam (p2#f: Proof) e_impossible)))
                    (lam (p#f: Proof) e2))
    (if_with_proof p
                  (lam (p#t: Proof) (let (p2#t p#t) e))
                  (lam (p#f: Proof) e2)))

In the example case (2) with fred, there is intervening stuff between the two if_with_proofs so we need some way to "assert-push" information downwards. A possibility is:

(rule "introduce_proof_knowledge"
    (if_with_proof p (lam (p#t: Proof) e) (lam (p#f: Proof) e2))
    (if_with_proof p
            (lam (p#t: Proof) (assert (is_proof p#t p) e))
            (lam (p#f: Proof) (assert (is_proof p#f (not p)) e2)))) 

plus normal assert-pushing and then

(rule "known_proof"
     (assert (is_proof p#_ of_what)
             (if_with_proof of_what
                            (lam (p2#t: Proof) e)
                            (lam (p2#f: Proof) e2)))
     (assert (is_proof p#_ of_what)   ; assert can be deleted by another rule if no longer needed
             (let (p2#t p#_) e)))
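One way to see what "introduce_proof_knowledge", assert-pushing and "known_proof" achieve together is the following Python sketch. It is not those three rules: as a simplification it fuses them into a single traversal that carries an environment of known proofs downwards instead of materialising asserts. The tuple encoding and the function name are assumptions for illustration, and only proofs that a condition is true are tracked.

```python
# Encoding (hypothetical):
#   ("if_with_proof", cond, ("lam", proof_t, then_body), ("lam", proof_f, else_body))
#   ("let", (name, rhs), body)
#   ("if", cond, then_arm, else_arm) and other calls are plain tuples

def resolve_known_proofs(expr, known=None):
    """Replace an inner if_with_proof on an already-proved condition by a let
    that reuses the outer proof. `known` maps conditions to proof variables."""
    known = {} if known is None else known
    if isinstance(expr, str):
        return expr
    if expr[0] == "if_with_proof":
        _, cond, (_, pt, then_body), (_, pf, else_body) = expr
        cond = resolve_known_proofs(cond, known)
        if cond in known:
            # Same effect as "known_proof": keep the true arm, drop the other.
            return ("let", (pt, known[cond]),
                    resolve_known_proofs(then_body, known))
        # Same effect as introducing the knowledge and pushing it into the arm.
        return ("if_with_proof", cond,
                ("lam", pt, resolve_known_proofs(then_body, {**known, cond: pt})),
                ("lam", pf, resolve_known_proofs(else_body, known)))
    return tuple(resolve_known_proofs(e, known) for e in expr)

IN_RANGE = ("inRange", "i", "arr")
RAISE = ("raise", "OutOfBounds")

# Case (2) after "lift_if_from_let_rhs" on the outer let.
BEFORE = ("if_with_proof", IN_RANGE,
          ("lam", "i_ok#t",
           ("let", ("x", ("index_with_proof", "i_ok#t", "i", "arr")),
            ("add", "x",
             ("if", "fred",
              ("let", ("y", ("if_with_proof", IN_RANGE,
                             ("lam", "i_ok2#t", ("index_with_proof", "i_ok2#t", "i", "arr")),
                             ("lam", "i_ok2#f", RAISE))),
               ("e", "y")),
              "e2")))),
          ("lam", "i_ok#f", RAISE))

AFTER = resolve_known_proofs(BEFORE)
```

The sketch relies on conditions being pure and on all binders being unique, so that a condition proved in an outer scope still means the same thing wherever it reappears. In AFTER the right-hand side of y is (let (i_ok2#t i_ok#t) (index_with_proof i_ok2#t i arr)), with no second range check.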

This allows us to get to:

(if_with_proof (inRange i arr)
               (lam (i_ok#t: Proof) (let (x (index_with_proof i_ok#t i arr))
                                         (add x (if fred
                                                    (let (y (let (i_ok2#t i_ok#t) (index_with_proof i_ok2#t i arr)))
                                                         e[y])
                                                    e2))))
               (lam (i_ok#f: Proof) (raise OutOfBounds)))

as desired.

@simonpj also suggested that we might be able to handle the sinking case via if_with_proof by allowing (lam (x#t: Proof) ...) to return the Proof x#t, which could then be passed into another index_with_proof or if_with_proof. I'm not sure how we'd rewrite into such a situation (?).