mentat-collective / emmy

The Emmy Computer Algebra System.
https://emmy.mentat.org
GNU General Public License v3.0

Port, optimize AD benchmarks from Siskind + Pearlmutter's work #87

Open sritchie opened 3 years ago

sritchie commented 3 years ago

@barak and @qobi presented on a series of AD benchmarks that they performed as part of their work on Stalingrad:

scmutils didn't do terribly well for a few reasons, one of which is no compilation of the derivatives... but no matter!

The task here is to get the SICMUtils code for these benchmarks tidied up, measured and going faster.

I took a pass at mechanically porting the code for the two straightforward benchmarks that worked in scmutils, particle-FF and saddle-FF. These are described in the first paper linked above.

NOTE: probabilistic-lambda-calculus, probabilistic-prolog and backprop all have forward-mode examples too that would work with some porting. Also once I get sicmutils/sicmutils#226 up and running we can get the rest of the benchmarks running here too. I know they'd be faster with a more intelligent compile implementation that could handle multiple arguments more gracefully. This would need to generate multiple nested functions in forward mode.

;; Run these from sicmutils.env... I have some defensive g/ prefixes that we probably don't need.

(defn multivariate-argmin [f x]
  (let [g (D f)]
    (loop [x   x
           fx  (f x)
           gx  (g/transpose (g x))
           eta 1e-5
           i   0]
      (cond (<= (compare (g/abs gx) 1e-5) 0)
            x

            (= i 10)
            (recur x fx gx (g/* 2.0 eta) 0)

            :else
            (let [x-prime (g/- x (g/* eta gx))]
              (if (<= (compare (g/abs (- x x-prime)) 1e-5) 0)
                x
                (let [fx-prime (f x-prime)]
                  (if (< (compare fx-prime fx) 0)
                    (recur x-prime fx-prime (g/transpose (g x-prime)) eta (+ i 1))
                    (recur x fx gx (/ eta 2.0) 0)))))))))

(defn naive-euler [w]
  (let [charges      [[10.0 (- 10.0 w)] [10.0 0.0]]
        x-initial    [0.0 8.0]
        xdot-initial [0.75 0.0]
        delta-t      1e-1
        p            (fn [x]
                       (transduce (map (fn [c] (/ 1.0 (g/abs (- x c)))))
                                  g/+
                                  0.0
                                  charges))]
    (loop [x    x-initial
           xdot xdot-initial]
      (let [[x' y' :as x-new] (g/+ x (g/* delta-t xdot))]
        (if (g/negative? y')
          (let [delta-t-f (g// (g/- 0.0 (second x))
                               (second xdot))
                x-t-f     (g/+ x (g/* delta-t-f xdot))]
            (g/square (first x-t-f)))
          (let [xddot (g/* -1.0 (g/transpose ((D p) x)))]
            (recur x-new (g/+ xdot (g/* delta-t xddot)))))))))

(defn multivariate-argmax [f x]
  (multivariate-argmin (fn [x] (g/- (f x))) x))

(defn multivariate-max [f x]
  (f (multivariate-argmax f x)))

(defn particle-FF []
  (multivariate-argmin naive-euler 0.0))

(defn saddle-FF []
  (let [start [1.0 1.0]
        f     (fn [[x1 y1] [x2 y2]]
                (- (+ (g/square x1) (g/square y1))
                   (+ (g/square x2) (g/square y2))))
        x1*   (multivariate-argmin
               (fn [x1]
                 (multivariate-max
                  (fn [x2] (f x1 x2))
                  start))
               start)
        x2*   (multivariate-argmax
               (fn [x2] (f x1* x2)) start)]
    [x1* x2*]))
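
For reference, here is the step rule that multivariate-argmin follows, as read off the code above. The symbols k (iteration index) and x' (trial point) are notation introduced only for this note; eta and i are the loop variables of the same names.

```latex
\begin{aligned}
g_k &= \nabla f(x_k), \qquad x' = x_k - \eta\, g_k, \\
x_{k+1} &=
\begin{cases}
x'  & \text{if } f(x') < f(x_k) \quad (\text{keep } \eta,\ i \leftarrow i + 1), \\
x_k & \text{otherwise} \quad (\eta \leftarrow \eta / 2,\ i \leftarrow 0).
\end{cases}
\end{aligned}
```

The loop starts with eta = 1e-5. After 10 consecutive accepted steps (i = 10) it doubles eta and resets i without moving. It returns x_k as soon as either |g_k| <= 1e-5 or |x_k - x'| <= 1e-5.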

And the results (with nothing to compare them to, since this is a new M1 Mac and I haven't run any other benchmarks here):

sicmutils.env> (time (saddle-FF))
"Elapsed time: 9141.064291 msecs"
[(up 8.246324826140356E-6 8.246324826140356E-6) (up 8.246324826140356E-6 8.246324826140356E-6)]

sicmutils.env> (time (particle-FF))
"Elapsed time: 28148.853958 msecs"
0.20719187464861197
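
A single `time` call on the JVM also measures JIT warm-up, so repeated runs would probably give more comparable numbers for the "measured" part of this task. Below is a minimal sketch of such a harness in plain Clojure. The names `bench` and `summarize` are hypothetical helpers made up for this note, not part of SICMUtils, and the warm-up and run counts are arbitrary.

```clojure
;; Hypothetical helpers, not part of SICMUtils; plain Clojure on the JVM.

(defn bench
  "Calls `thunk` `warmups` times untimed, then `runs` more times, returning a
  vector with the elapsed wall-clock milliseconds of each timed call."
  [thunk warmups runs]
  (dotimes [_ warmups] (thunk))
  (vec (repeatedly runs
                   (fn []
                     (let [start (System/nanoTime)]
                       (thunk)
                       (/ (- (System/nanoTime) start) 1e6))))))

(defn summarize
  "Returns the fastest timing and the median (upper middle for even counts)
  of a non-empty collection of timings."
  [timings]
  (let [sorted (vec (sort timings))]
    {:min    (first sorted)
     :median (nth sorted (quot (count sorted) 2))}))
```

Something like `(summarize (bench saddle-FF 1 5))` would then report the fastest and median of five timed runs after one warm-up run. The minimum tends to be the least noisy figure when comparing before and after an optimization.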

qobi commented 3 years ago

J.M. Siskind and B.A. Pearlmutter, `Efficient Implementation of a Higher-Order Language with Built-In AD,' Extended abstract, presented at the International Conference on Algorithmic Differentiation (AD), Oxford, UK, 12-15 September 2016. http://engineering.purdue.edu/~qobi/papers/ad2016b.pdf

also

http://arxiv.org/abs/1611.03416

Jeff (http://engineering.purdue.edu/~qobi)

qobi commented 3 years ago

Another benchmark (a better variant of saddle) is in:

A. Radul, B.A. Pearlmutter, and J.M. Siskind, `AD in Fortran: Implementation via Prepreprocessor,' Proceedings of the 6th International Conference on Automatic Differentiation (AD2012), pp. 273-284, Fort Collins, CO, 23-27 July 2012. http://engineering.purdue.edu/~qobi/papers/ad2012.pdf

A. Radul, B.A. Pearlmutter, and J.M. Siskind, `AD in Fortran, Part 1: Design,' arXiv:1203.1448, 7 March 2012. http://arxiv.org/abs/1203.1448

A. Radul, B.A. Pearlmutter, and J.M. Siskind, `AD in Fortran, Part 2: Implementation via Prepreprocessor,' arXiv:1203.1450, 7 March 2012. http://arxiv.org/abs/1203.1450

Jeff (http://engineering.purdue.edu/~qobi)

sritchie commented 3 years ago

Thanks @qobi !