microsoft / knossos-ksc

Compiler with automatic differentiation

Use dataflow to control which functions are safe to lift out of "if" #972

Open acl33 opened 2 years ago

acl33 commented 2 years ago

Introduction

#953 introduces lifting rules, for example "add-of-if" to "if-of-add":

(add (if p 5 3) 2)   ==lift-if==>   (if p (add 5 2) (add 3 2))

or more generally "call-of-if to if-of-call" aka "lift_if_from_call_arg".
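A minimal Python sketch of the call-of-if rule, assuming expressions are modelled as nested tuples `(head, arg, ...)` with strings for variables and ints for literals. This encoding and the function name are illustrative, not ksc's actual implementation:

```python
# Sketch only: expressions are nested tuples (head, arg1, ...); atoms are str (variables) or int.

def lift_if_from_call_arg(e):
    """(f a1 .. (if p t u) .. an)  ==>  (if p (f a1 .. t .. an) (f a1 .. u .. an)).

    Lifts the first `if` found among the call's arguments; returns None when
    the rule does not apply. No side conditions are checked here.
    """
    if not isinstance(e, tuple) or e[0] in ("if", "let"):
        return None  # only ordinary calls are rewritten
    head, args = e[0], e[1:]
    for i, a in enumerate(args):
        if isinstance(a, tuple) and a[0] == "if":
            _, p, t, u = a
            before, after = args[:i], args[i + 1:]
            return ("if", p, (head, *before, t, *after), (head, *before, u, *after))
    return None


print(lift_if_from_call_arg(("add", ("if", "p", 5, 3), 2)))
# ('if', 'p', ('add', 5, 2), ('add', 3, 2))
```

The call is duplicated into both branches, which is what later exposes opportunities for constant folding or CSE.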

And "if-of-let to let-of-if" (aka "lift_let_from_if_true"/"false"):

(if p (let (x a) expr1) expr2)  
;==lift-if==>  
(let (x a)  (if p expr1 expr2))  ; checking x not in freeVars of expr2 or p
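The same tuple encoding can express the if-of-let rule. This hypothetical sketch checks only the freeVars side condition stated in the comment above, so it will also perform the unsafe lift discussed under "Problem":

```python
# Sketch only: tuple-encoded expressions; ("let", (x, rhs), body) binds x in body.

def free_vars(e):
    """Free variables of an expression; a call's head is a function name, not a variable."""
    if isinstance(e, str):
        return {e}
    if not isinstance(e, tuple):
        return set()  # literals
    if e[0] == "let":
        _, (x, rhs), body = e
        return free_vars(rhs) | (free_vars(body) - {x})
    return set().union(*(free_vars(a) for a in e[1:]))


def lift_let_from_if_true(e):
    """(if p (let (x a) t) f)  ==>  (let (x a) (if p t f)), provided x is not free in p or f.

    Only the freeVars side condition is checked; nothing stops `a` from now
    being evaluated on a path where it previously was not.
    """
    if isinstance(e, tuple) and e[0] == "if":
        _, p, t, f = e
        if isinstance(t, tuple) and t[0] == "let":
            _, (x, a), body = t
            if x not in free_vars(p) | free_vars(f):
                return ("let", (x, a), ("if", p, body, f))
    return None
```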

These rules are crucial to enable CSE, or to move ifs around in order to eliminate redundant comparisons such as (if p (if p t1 f1) f2) ==> (if p t1 f2).
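The redundant-comparison rewrite can be sketched the same way. This is a hypothetical helper on the tuple encoding; conditions are compared structurally, which assumes they have no side effects:

```python
# Sketch only: tuple-encoded expressions, ("if", p, t, f).

def elim_repeated_if(e):
    """(if p (if p t1 f1) f2)  ==>  (if p t1 f2).

    Inside the true branch of p, the inner test of the same p must also be
    true, so f1 is dead code.
    """
    if isinstance(e, tuple) and e[0] == "if":
        _, p, t, f2 = e
        if isinstance(t, tuple) and t[0] == "if" and t[1] == p:
            _, _, t1, _f1 = t
            return ("if", p, t1, f2)
    return None
```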

Problem

However, without side conditions these rules can cause code to be executed earlier than it otherwise would be (and on paths where it would not have been executed at all), with disastrous results. A program which never crashed can be rewritten into one that does:

(if (inRange i arr)
    (let (x (index i arr)) ; index is protected by "if inRange"
        expr1) 
    expr2)
;==lift-let==>
(let (x (index i arr)) ; index no longer protected, may well crash
    (if (inRange i arr) 
        expr1
        expr2))

Indeed, it's likely that it will crash, because the programmer was protecting the call in the first place. (Ignore theorem-proving our way out of this; that's a hope, not a guarantee.)
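To make the crash concrete, here is a toy strict evaluator (an illustrative assumption, not the ksc interpreter) in which `index` raises on an out-of-range index. `expr1` is replaced by `x` and `expr2` by the literal `-1`. With `i = 5` and a three-element `arr`, the guarded program returns its else-branch value, while the lifted one raises:

```python
# Toy evaluator for a small fragment; expressions are tuples, variables are str, literals int.

def evaluate(e, env):
    """Strict evaluation of literals, variables, if, let, inRange and index."""
    if isinstance(e, int):
        return e
    if isinstance(e, str):
        return env[e]
    head = e[0]
    if head == "if":
        _, p, t, f = e
        return evaluate(t, env) if evaluate(p, env) else evaluate(f, env)
    if head == "let":
        _, (x, rhs), body = e
        return evaluate(body, {**env, x: evaluate(rhs, env)})
    args = [evaluate(a, env) for a in e[1:]]
    if head == "inRange":
        i, arr = args
        return 0 <= i < len(arr)
    if head == "index":
        i, arr = args
        if not 0 <= i < len(arr):
            raise IndexError(f"index {i} out of range")
        return arr[i]
    raise ValueError(f"unknown function {head!r}")


guarded = ("if", ("inRange", "i", "arr"),
           ("let", ("x", ("index", "i", "arr")), "x"),
           -1)
lifted = ("let", ("x", ("index", "i", "arr")),
          ("if", ("inRange", "i", "arr"), "x", -1))

env = {"i": 5, "arr": [10, 20, 30]}
print(evaluate(guarded, env))  # -1
try:
    evaluate(lifted, env)
except IndexError as err:
    print("lifted version crashed:", err)
```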

We can never lift a throwing function

As another example, consider

(if fred 
  (index i arr)
  e)

where fred is not protecting the call to index. This cannot be rewritten to

(let (x (index i arr)) 
  (if fred
    x
    e))

because the semantics change.

The former has the following semantics; we must assume all four boxes are possible unless proven otherwise:

                      fred is true   fred is false
    i in range        value          e
    i out of range    exception      e

The putative rewrite would have different semantics in the case where fred is false and i is out of range:

                      fred is true   fred is false
    i in range        value          e
    i out of range    exception      exception
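The two tables can be checked mechanically by modelling each outcome as a string and enumerating all four boxes. This is a sketch; `original` and `rewritten` are hypothetical stand-ins for the two programs:

```python
from itertools import product


def original(fred: bool, in_range: bool) -> str:
    # (if fred (index i arr) e): index runs only when fred is true
    if fred:
        return "value" if in_range else "exception"
    return "e"


def rewritten(fred: bool, in_range: bool) -> str:
    # (let (x (index i arr)) (if fred x e)): index runs first, unconditionally
    if not in_range:
        return "exception"
    return "value" if fred else "e"


differences = [
    (fred, in_range)
    for fred, in_range in product([True, False], repeat=2)
    if original(fred, in_range) != rewritten(fred, in_range)
]
print(differences)  # [(False, False)]
```

Exactly one box differs: fred false, i out of range.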

Objective

However, almost all of our code uses "index", and proving nothrow is hard. So while we disallow lifting the "if"s above, we would like to be able to deal with:

(if (inRange i arr)
    (if fred
        (index i arr)
        expr1)
    expr2)

and lift the inner "index" outside of "if fred". (The precise formulation of the end result is left until after the introduction of new syntax.)

Solution: ifs with proofs

Putting aside theorem proving, we would like a way for the programmer to tell us which functions depend on which conditions. This is done using "proofs" as outlined in #953. First, we provide a new function index-with-proof, defined like index, but with a new argument of type Proof.

(edef index T ((i : Integer) (x : Tensor T)))
(edef index-with-proof T ((p : Proof) (i : Integer) (x : Tensor T)))

We won't define Proof yet, but it will turn out always to be erased at runtime, so it might as well be the empty tuple. It will also turn out that index-with-proof is exactly a call to index, with exactly the same behaviour. The only purpose of the edef is to ensure that it can't be inlined away.

So why pass the Proof at all? Well, let's see where we get Proofs from. A "Proof" is a variable that exists in only one branch of an if. The programmer uses it to indicate computations that depend for their correctness on being on a certain branch. In the example above, we used a comment to show which check was protecting the call to index:

(if (inRange i arr)
    (let (x (index i arr)) ; index is protected by "if inRange"
        expr1) 
    expr2)

Now we use a "proof", generated by the construct if-with-proof, to tell the compiler what we meant by our comment. What does that look like?

(if-with-proof i_ok (inRange i arr) ; 'i_ok' is just a variable name, chosen by the user
    (let (x (index-with-proof i_ok#t i arr)) ; index is protected by "i_ok on its true branch"
        expr1) 
    expr2)

Let's break that down. The new construct is (if-with-proof IDENT EXPR EXPR EXPR), which can be considered to rewrite to the following, where let! is an "immovable" synonym for let, which can never be lifted over an if.

(if-with-proof var cond t_body f_body)
; ==> 
(if cond
   (let! (var#t (dummy Proof))
       t_body)
   (let! (var#f (dummy Proof))
       f_body))
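A sketch of this desugaring in Python, again assuming the tuple encoding. `let!` is kept as a distinct head so that lifting rules can recognise it and refuse to move it; `is_movable` is a hypothetical helper name:

```python
# Sketch only: tuple-encoded expressions.

def desugar_if_with_proof(e):
    """(if-with-proof var cond t f)
    ==> (if cond (let! (var#t (dummy Proof)) t) (let! (var#f (dummy Proof)) f))"""
    tag, var, cond, t_body, f_body = e
    if tag != "if-with-proof":
        raise ValueError(f"not an if-with-proof: {tag!r}")
    return ("if", cond,
            ("let!", (f"{var}#t", ("dummy", "Proof")), t_body),
            ("let!", (f"{var}#f", ("dummy", "Proof")), f_body))


def is_movable(e):
    # let! is the immovable let: no lifting rule may move it above an enclosing if.
    return not (isinstance(e, tuple) and e[0] == "let!")
```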

Now, any part of t_body that doesn't refer to var#t can happily be lifted out of the if. But a part that does use var#t can't, because it cannot move above the let! that binds var#t.
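Under this scheme, "may this subexpression move above the let! that binds var#t?" reduces to a free-variable test. A minimal sketch with hypothetical helper names and the tuple encoding assumed; it says nothing about throwing, which still has to be checked conservatively:

```python
# Sketch only: tuple-encoded expressions; let and let! both bind a name.

def free_vars(e):
    """Free variables of an expression; a call's head is a function name, not a variable."""
    if isinstance(e, str):
        return {e}
    if not isinstance(e, tuple):
        return set()
    if e[0] in ("let", "let!"):
        _, (x, rhs), body = e
        return free_vars(rhs) | (free_vars(body) - {x})
    return set().union(*(free_vars(a) for a in e[1:]))


def can_lift_past_proof(sub, proof_var):
    """sub may move above the let! binding proof_var only if it never mentions proof_var."""
    return proof_var not in free_vars(sub)


rhs = ("index-with-proof", "eye_ok#t", "i", "arr")
print(can_lift_past_proof(rhs, "jay_ok#t"))  # True: depends on a different proof
print(can_lift_past_proof(rhs, "eye_ok#t"))  # False: depends on this branch's proof
```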

So let's see it in action. Consider a variant in which the condition tests inRange j while the index call uses i. (We're assuming inRange j was not a typo, so the programmer is using that test for a purpose that needs its own proof, jay_ok.) It is written as follows:

(if-with-proof jay_ok (inRange j arr)
    (let (x (index-with-proof eye_ok#t i arr)) ; eye_ok#t must be coming from some proof above.
        expr1) 
    expr2)

and will rewrite happily to

(let (x (index-with-proof eye_ok#t i arr)) ; eye_ok#t not in freeVars of (inRange j arr).
    (if-with-proof jay_ok (inRange j arr)
        expr1
        expr2))

Updated syntax

We can use lams for the proof-taking parts, make "index_with_proof" a (type-parameterized) primitive function, and even remove If as a type of ASTNode/Expr, by treating if as a syntactic shorthand, as follows:

(if p x y) <==> (if_with_proof p (lam (fresh : Proof) x) (lam (fresh2 : Proof) y))

where fresh and fresh2 are not used within x or y. This means we only need one copy of each lift_if rule, dealing with if_with_proof.
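A sketch of both directions of this equivalence, assuming `(lam (v : T) body)` is encoded as `("lam", (v, T), body)`. Fresh names are chosen to avoid every name already in the expression; the helper names and the `fresh0`/`fresh1` naming scheme are illustrative:

```python
# Sketch only: tuple-encoded expressions; lam binds its parameter in its body.

def free_vars(e):
    """Free variables; (lam (v : T) body) binds v, and a call's head is not a variable."""
    if isinstance(e, str):
        return {e}
    if not isinstance(e, tuple):
        return set()
    if e[0] == "lam":
        _, (v, _ty), body = e
        return free_vars(body) - {v}
    return set().union(*(free_vars(a) for a in e[1:]))


def all_names(e):
    """Every string atom in e, including binders, so fresh names can avoid them."""
    if isinstance(e, tuple):
        return set().union(*(all_names(a) for a in e))
    return {e} if isinstance(e, str) else set()


def fresh(avoid, base="fresh"):
    n = 0
    while f"{base}{n}" in avoid:
        n += 1
    avoid.add(f"{base}{n}")
    return f"{base}{n}"


def if_to_if_with_proof(e):
    # (if p x y) ==> (if_with_proof p (lam (fresh0 : Proof) x) (lam (fresh1 : Proof) y))
    _, p, x, y = e
    avoid = all_names(e)
    return ("if_with_proof", p,
            ("lam", (fresh(avoid), "Proof"), x),
            ("lam", (fresh(avoid), "Proof"), y))


def if_with_proof_to_if(e):
    # Inverse direction: only valid when neither branch uses its proof.
    _, p, (_, (v1, _), x), (_, (v2, _), y) = e
    if v1 in free_vars(x) or v2 in free_vars(y):
        return None
    return ("if", p, x, y)
```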

Throw / nothrow functions

As above, we must be conservative and avoid lifting any potentially-throwing function out of an if.

Worked Example

See comment below

acl33 commented 2 years ago

I think there are at least two issues here, but not sure how cleanly they separate:

awf commented 2 years ago

Alan, you also had an idea involving lambdas, maybe you can sketch that out too? I thought it would amount to CPS, but perhaps it allows us to avoid new syntax?

awf commented 2 years ago

> Alan, you also had an idea involving lambdas, maybe you can sketch that out too? I thought it would amount to CPS, but perhaps it allows us to avoid new syntax?

Ah, I see there was a parenthetical remark in the middle of the original explication. Moved to bottom.

awf commented 2 years ago

See edit at 'we can't allow lifting here'

acl33 commented 2 years ago

> See edit at 'we can't allow lifting here'

Hmmm. Yes. What is an example where we should allow lifting, indeed.

acl33 commented 2 years ago

Does this work?

(if (inRange i arr)
    (if fred
        (let (x (index i arr)) x)
        expr1)
    expr2)

Which we can lift to

(if (inRange i arr)
    (let (x (index i arr))
         (if fred x expr1))
    expr2)

Really, the whole proofs system is a thing for converting maybe-throw calls into never-throw calls....

awf commented 2 years ago

> Really, the whole proofs system is a thing for converting maybe-throw calls into never-throw calls....

Indeed, so perhaps the rule is very simple: "we can only lift never-throwing expressions, or always-executed expressions."

A consequence of this rule is that an important case, namely index, cannot be lifted. So we introduce a mechanism (proofs) which permits us to replace index with a never-throw variant.
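The rule can be sketched as a conservative predicate. The `NEVER_THROW` set below is a hypothetical classification (with `index_with_proof` treated as never-throw, per the comment above); any function not listed is assumed to be able to throw:

```python
# Hypothetical classification of primitives known never to throw.
NEVER_THROW = {"add", "mul", "inRange", "index_with_proof", "dummy"}


def may_throw(e):
    """Conservative: any call to a function not known to be never-throw may throw."""
    if not isinstance(e, tuple):
        return False  # variables and literals
    head, *args = e
    if head == "let":
        (_x, rhs), body = args
        return may_throw(rhs) or may_throw(body)
    if head != "if" and head not in NEVER_THROW:
        return True
    return any(may_throw(a) for a in args)


def can_lift(e, always_executed=False):
    """Lift only never-throwing expressions, or expressions that are always executed."""
    return always_executed or not may_throw(e)
```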

acl33 commented 2 years ago

Note we will still need rewrites like:

(if_with_proof true
               (lam (p : Proof) e_t)
               e_f)
; ===>
(let (p (dummy Proof)) e_t)
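A sketch of this rewrite, assuming both branches are written as lams as in the updated syntax (the rewrite above leaves e_f bare) and that `true`/`false` are encoded as Python `True`/`False`. The false case mirrors the true case; that part is an assumption:

```python
# Sketch only: tuple-encoded expressions; ("lam", (p, "Proof"), body).

def fold_known_if_with_proof(e):
    """(if_with_proof true (lam (p : Proof) e_t) e_f) ==> (let (p (dummy Proof)) e_t)"""
    if not (isinstance(e, tuple) and e[0] == "if_with_proof"):
        return None
    _, cond, t_lam, f_lam = e
    if cond is True:
        chosen = t_lam
    elif cond is False:
        chosen = f_lam  # assumed symmetric case
    else:
        return None  # condition not statically known
    _, (p, _ty), body = chosen
    # The chosen branch keeps its proof variable, now bound to a dummy Proof.
    return ("let", (p, ("dummy", "Proof")), body)
```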

acl33 commented 2 years ago

; Start
(if (inRange i arr)
    (if fred
        (index i arr)
        expr1)
    expr2)
; Rephrasing Index -> IWP
; (index i a) --> (if_with_proof (inRange i a) (p#t -> (index_with_proof p#t i a)) (p#f -> throw))
(if (inRange i arr)
    (if fred
        (if_with_proof (inRange i arr)    ; Would be better to generate the proof at the outer (inRange i arr) test
            (p#t -> (index_with_proof p#t i arr))
            (p#f -> throw))
        expr1)
    expr2)

; Destination - removing the redundant test of inRange, and the ostensible possibility of "throw"
(if_with_proof (inRange i arr)
    (p#t -> (let (x (index_with_proof p#t i arr))
                 (if fred x expr1)))
    (p#f -> expr2))

; lift the if_with_proof above the "if fred"
; (if p (if q x y) f) ==> (if q (if p x f) (if p y f))   ; duplicates f
; side condition: *q* does not throw
; (if p (if_with_proof q (q#t -> x) (q#f -> y)) f) ==> (if_with_proof q (q#t -> (if p x f)) (q#f -> (if p y f)))
; side conditions: q does not throw; q#t and q#f not free in f (e.g. ABU)
(if (inRange i arr)
    (if_with_proof (inRange i arr)
        (p#t -> (if fred
                    (index_with_proof p#t i arr)
                    expr1))
        (p#f -> (if fred
                    throw
                    expr1)))
    expr2)

; Repeated "if"
; (if p (if p x y) e2) ==> (if p x e2)
; (if p (if_with_proof p (p#t -> x) (p#f -> y)) e2) ==> (if_with_proof p (p#t -> x) (p#f -> e2))
; include converting outer if ==> if_with_proof
;                via (if p x y) ==> (if_with_proof p (_ -> x) (_ -> y))
; Equivalently, but with outer "if" already converted to if_with_proof:
; (if_with_proof p
;                (p1#t -> (if_with_proof p (p2#t -> x) (p2#f -> y)))
;                (p1#f -> e2))
; ===>
; (if_with_proof p
;                (p1#t -> (let (p2#t p1#t) x))
;                (p1#f -> e2))
(if_with_proof (inRange i arr)
               (p#t -> (if fred
                           (index_with_proof p#t i arr)
                           expr1))
               (p#f -> expr2))

; Now new_bind (e) ==> (let (x e) x)
(if_with_proof (inRange i arr)
               (p#t -> (if fred
                           (let (x (index_with_proof p#t i arr)) x)
                           expr1))
               (p#f -> expr2))

; Now lift_let_over_if_true
; (if p (let (x e) body) f) ==> (let (x e) (if p body f)) ; side condition: e does not throw
(if_with_proof (inRange i arr)
               (p#t -> (let (x (index_with_proof p#t i arr))
                            (if fred x expr1)))
               (p#f -> expr2))
; This is better than the supposed destination
(if_with_proof (inRange i arr)
               (p#t -> (let (x (if_with_proof (inRange i arr)    ; redundant if_with_proof eliminated
                                             (p2#t -> (index_with_proof p2#t i arr))
                                             (p2#f -> throw)))   ; ostensible possibility of "throw" eliminated
                            (if fred x expr1)))
               (p#f -> expr2))

; Possible rules?
; (if_with_proof p (p#t -> x) (p#f -> y)) ==> (if p x y) ; side condition: x, y do not refer to p#t, p#f   ; not applicable here
; (if_with_proof (inRange i arr) (p#t -> (index_with_proof p#t i arr)) (p#f -> throw))  ==> (index i arr)    ; fails to match here
;                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ we have stuff wrapped around this

; Not equivalently - more general case where predicates are different
; (if_with_proof p1 (p1#t -> (if_with_proof p2
;                                           (p2#t -> x)     ; x refers to p1#t, p2#t
;                                           (p2#f -> y)))   ; y refers to p1#t, p2#f
;                   (p1#f -> e2))                           ; e2 refers to p1#f only
; === side condition: p2 does not refer to p1#t, p2 does not throw ===>
; (if_with_proof p2 (p2#t -> (if_with_proof p1
;                                           (p1#t -> x)
;                                           (p1#f -> e2)))
;                   (p2#f -> (if_with_proof p1
;                                           (p1#t -> y)
;                                           (p1#f -> e2))))

; A more complicated example, using build
; Start
(if (inRange i arr)
    (build 4 (j ->
        (if fred[j]
            (index i arr)
            expr1)))
    expr2)
; Rephrasing Index -> IWP
; (index i a) --> (if_with_proof (inRange i a) (p#t -> (index_with_proof p#t i a)) (p#f -> throw))
(if (inRange i arr)
    (build 4 (j ->
          (if fred[j]
               (if_with_proof (inRange i arr)    ; Would be better to generate the proof at the outer (inRange i arr) test
                    (p#t -> (index_with_proof p#t i arr))
                    (p#f -> throw))
               expr1)))
    expr2)

; After above, but yet to lift outside of build:
(if_with_proof (inRange i arr)
    (p#t -> (build 4 (j ->
               (let (x (index_with_proof p#t i arr))
                            (if fred[j] x expr1)))))
    (p#f -> expr2))

; Destination - removing the redundant test of inRange, and the ostensible possibility of "throw"
(if_with_proof (inRange i arr)
    (p#t -> (let (x (index_with_proof p#t i arr))
                 (build 4 (j ->
                     (if fred[j] x expr1)))))
    (p#f -> expr2))
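The last step lifts the let out of build. A sketch of that rule, assuming `(j -> body)` is encoded as `("lam", ("j", "Integer"), body)` and `fred[j]` as `("index", "j", "fred")`; the function name is hypothetical. It checks that rhs doesn't mention j and that no variable is captured. It does not check throwing: if `n` could be 0, rhs now runs when it previously did not, so rhs must also be never-throw, as index_with_proof is under its proof:

```python
# Sketch only: tuple-encoded expressions; let and lam both bind names.

def free_vars(e):
    """Free variables; let and lam bind names, and a call's head is not a variable."""
    if isinstance(e, str):
        return {e}
    if not isinstance(e, tuple):
        return set()
    if e[0] == "let":
        _, (x, rhs), body = e
        return free_vars(rhs) | (free_vars(body) - {x})
    if e[0] == "lam":
        _, (v, _ty), body = e
        return free_vars(body) - {v}
    return set().union(*(free_vars(a) for a in e[1:]))


def lift_let_out_of_build(e):
    """(build n (lam (j : T) (let (x rhs) body))) ==> (let (x rhs) (build n (lam (j : T) body)))"""
    if not (isinstance(e, tuple) and e[0] == "build"):
        return None
    _, n, (_, (j, ty), inner) = e
    if not (isinstance(inner, tuple) and inner[0] == "let"):
        return None
    _, (x, rhs), body = inner
    # rhs must not depend on the loop index; x must not capture j or anything in n.
    if j in free_vars(rhs) or x == j or x in free_vars(n):
        return None
    return ("let", (x, rhs), ("build", n, ("lam", (j, ty), body)))
```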