microsoft / knossos-ksc

Compiler with automatic differentiation
Other
45 stars 10 forks source link

Inline suffwdpass and sufrevpass of numerical functions used in activations #871

Closed toelli-msft closed 3 years ago

toelli-msft commented 3 years ago

Implements AB#19483

awf commented 3 years ago

Super! Post a before and after of a function where this makes a difference?

E.g. suffwd$relu3

toelli-msft commented 3 years ago

Before

; [sufrev [gelu Float]] as emitted BEFORE the inlining change.
; Takes sarg and dtarg (both Float) and returns a Float.
;
; Reading the let-chain top to bottom:
;   (r_29 b_f_30) <- [suffwdpass [sqrt Float]] applied to 2.0
;   (r_25 b_f_26) <- [suffwdpass [div ...]]   applied to sarg and r_29
;   (r_21 b_f_22) <- [suffwdpass [erf Float]]  applied to r_25
;   d$t_9         <- 0.5 * dtarg
;   (d$t_34 d$t_36) <- [sufrevpass [div ...]] applied to
;                        ([sufrevpass [erf Float]] (sarg * d$t_9) b_f_22)
;                      and b_f_26
; The result is ts_add of ((1.0 + r_21) * d$t_9) and d$t_34.
;
; b_f_30 and d$t_36 are bound but never referenced in this body.
; NOTE(review): the b_f_* values are presumably what each forward pass saves
; for its matching reverse pass -- confirm against the suffwdpass definitions.
(def
 [sufrev [gelu Float]] Float
 ((sarg : Float) (dtarg : Float))
 (let ((r_29 b_f_30) ([suffwdpass [sqrt Float]] 2.0))
  (let ((r_25 b_f_26) ([suffwdpass [div (Tuple Float Float)]] sarg
                                                              r_29))
   (let ((r_21 b_f_22) ([suffwdpass [erf Float]] r_25))
    (let (d$t_9 ([mul (Tuple Float Float)] 0.5 dtarg))
     (let ((d$t_34 d$t_36) ([sufrevpass [div (Tuple Float
                                                    Float)]] ([sufrevpass [erf Float]] ([mul (Tuple Float
                                                                                                    Float)] sarg
                                                                                                            d$t_9)
                                                                                       b_f_22)
                                                             b_f_26))
      (ts_add ([mul (Tuple Float Float)] ([add (Tuple Float Float)] 1.0
                                                                    r_21)
                                         d$t_9)
              d$t_34)))))))

After

; [sufrev [gelu Float]] as emitted AFTER the inlining change.
; Takes sarg and dtarg (both Float) and returns a Float.
; No suffwdpass / sufrevpass calls appear in this body; only direct calls to
; sqrt, div, mul, add, erf, exp, neg, pi and ts_add.
;
; Bindings:
;   sqrt_x <- sqrt 2.0
;   r_25   <- sarg / sqrt_x
;   d$t_9  <- 0.5 * dtarg
;
; Result, written infix:
;   ts_add( (1.0 + erf(r_25)) * d$t_9,
;           (((2.0 * (sarg * d$t_9)) * exp(neg(r_25 * r_25))) / sqrt(pi)) / sqrt_x )
(def
 [sufrev [gelu Float]] Float
 ((sarg : Float) (dtarg : Float))
 (let (sqrt_x ([sqrt Float] 2.0))
  (let (r_25 ([div (Tuple Float Float)] sarg sqrt_x))
   (let (d$t_9 ([mul (Tuple Float Float)] 0.5 dtarg))
    (ts_add ([mul (Tuple Float Float)] ([add (Tuple Float Float)] 1.0
                                                                  ([erf Float] r_25))
                                       d$t_9)
            ([div (Tuple Float Float)] ([div (Tuple Float
                                                    Float)] ([mul (Tuple Float
                                                                         Float)] ([mul (Tuple Float
                                                                                              Float)] 2.0
                                                                                                      ([mul (Tuple Float
                                                                                                                   Float)] sarg
                                                                                                                           d$t_9))
                                                                                 ([exp Float] ([neg Float] ([mul (Tuple Float
                                                                                                                        Float)] r_25
                                                                                                                                r_25))))
                                                            ([sqrt Float] ([pi (Tuple)])))
                                       sqrt_x))))))