mentat-collective / emmy

The Emmy Computer Algebra System.
https://emmy.mentat.org
GNU General Public License v3.0

Implement "Lazy Multivariate Higher-Order Forward-Mode AD" for cheap higher-order derivatives #59

Open sritchie opened 2 years ago

sritchie commented 2 years ago

Thanks @qobi for the reference.

This is the solution to the problem of an exponential number of derivatives being calculated for (((expt D n) f) 'x). The number of terms equals the size of the power set of a set of size n, i.e. 2^n (see the comment below).

The task is to implement the method described in this paper: https://engineering.purdue.edu/~qobi/papers/popl2007a.pdf

A method is presented for computing all higher-order partial derivatives of a multivariate function R^n → R. This method works by evaluating the function under a nonstandard interpretation, lifting reals to multivariate power series. Multivariate power series, with potentially an infinite number of terms with nonzero coefficients, are represented using a lazy data structure constructed out of linear terms. A complete implementation of this method in SCHEME is presented, along with a straightforward exposition, based on Taylor expansions, of the method’s correctness.

And THEN to figure out how to expose this behavior inside the D operator as it exists currently.
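To get a feel for the core idea in the abstract (evaluate the function on power series instead of reals, and let laziness decide how many terms get computed), here is a minimal univariate sketch in portable R7RS Scheme. It is a deliberate simplification, not the paper's multivariate algorithm. Every name in it (series-mul, derivatives and so on) is hypothetical and not part of Emmy or scmutils. The k-th derivative at a is recovered as k! times the k-th Taylor coefficient, so each order is computed once instead of once per subset of tags.

```scheme
(import (scheme base) (scheme lazy) (scheme write))

;; A lazy power series a0 + a1*h + a2*h^2 + ... is a pair whose car is a0
;; and whose cdr is a promise of the series a1 + a2*h + ...
(define (series-head s) (car s))
(define (series-tail s) (force (cdr s)))

;; The constant c, i.e. c + 0h + 0h^2 + ...
(define (series-const c)
  (cons c (delay (series-const 0))))

;; The input variable expanded around the point a, i.e. a + 1h.
(define (series-var a)
  (cons a (delay (series-const 1))))

(define (series-add f g)
  (cons (+ (series-head f) (series-head g))
        (delay (series-add (series-tail f) (series-tail g)))))

(define (series-scale c f)
  (cons (* c (series-head f))
        (delay (series-scale c (series-tail f)))))

;; With F = f0 + h*F' and G = g0 + h*G':
;; F*G = f0*g0 + h*(f0*G' + F'*G)
(define (series-mul f g)
  (cons (* (series-head f) (series-head g))
        (delay (series-add (series-scale (series-head f) (series-tail g))
                           (series-mul (series-tail f) g)))))

;; Force only the first n coefficients.
(define (series-take s n)
  (if (= n 0)
      '()
      (cons (series-head s) (series-take (series-tail s) (- n 1)))))

;; Derivatives of orders 0 .. n-1 of f at a; the k-th is k! * a_k.
(define (derivatives f a n)
  (let loop ((coeffs (series-take (f (series-var a)) n))
             (k 0)
             (fact 1)            ; holds k! on each iteration
             (acc '()))
    (if (null? coeffs)
        (reverse acc)
        (loop (cdr coeffs)
              (+ k 1)
              (* fact (+ k 1))
              (cons (* fact (car coeffs)) acc)))))

;; f(x) = x^3 at x = 2, so f, f', f'', f''' should be 8, 12, 12, 6.
(define (cube x) (series-mul x (series-mul x x)))
(display (derivatives cube 2 4))
(newline)
```

Since every tail is a memoized promise, forcing more coefficients of a series that already exists reuses the ones computed so far. Handling functions R^n → R is the part the paper works out, using a lazy structure built from linear terms.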


sritchie commented 2 years ago

Here was the exploration that led to Siskind pointing us to the paper above:

A friend and I were fiddling around with nth-order symbolic derivatives, trying to figure out how to speed things up by interleaving a simplification step between each derivative.

It didn't help much, and the reason became clear when I removed the EXTRACT-DX-PART call from the derivative routine. For an nth-order derivative (n = 3 here):

(let ((f (literal-function 'f)))
  (((expt D 3) f) 'x))

The differential object's term list has an entry for every element of the power set of (IOTA N), and there is a lot of repeated work! Every k-th derivative, for k <= n, is calculated n-choose-k times, because every tag refers to the same function:

[[[]            (f x)]
 [[123]         ((D f) x)]
 [[124]         ((D f) x)]
 [[125]         ((D f) x)]
 [[123 124]     (((expt D 2) f) x)]
 [[123 125]     (((expt D 2) f) x)]
 [[124 125]     (((expt D 2) f) x)]
 [[123 124 125] (((expt D 3) f) x)]]
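The counting can be checked directly. The sketch below (plain R7RS Scheme; power-set and counts-by-size are hypothetical helpers, not scmutils functions) enumerates the subsets of the three tags from the listing and tallies them by size: 2^n entries in total, but only n + 1 distinct derivatives, with the k-th one repeated n-choose-k times.

```scheme
(import (scheme base) (scheme write))

;; All subsets of a list of tags.
(define (power-set tags)
  (if (null? tags)
      '(())
      (let ((rest (power-set (cdr tags))))
        (append rest
                (map (lambda (s) (cons (car tags) s)) rest)))))

;; Tally subsets by size. A subset of size k stands for one copy of the
;; k-th derivative in the term list.
(define (counts-by-size tags)
  (let ((counts (make-vector (+ (length tags) 1) 0)))
    (for-each (lambda (s)
                (let ((k (length s)))
                  (vector-set! counts k (+ 1 (vector-ref counts k)))))
              (power-set tags))
    (vector->list counts)))

(display (length (power-set '(123 124 125))))   ; 8 terms in total
(newline)
(display (counts-by-size '(123 124 125)))       ; (1 3 3 1)
(newline)
```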

The puzzle is how to skip all of the duplicates. I have some thoughts, and some ideas that failed, that we can talk about next time we chat, if this is interesting. It would be nice to get an exponential speedup here.

How slow is it in practice?

Given some function like G:

(define (g x)
  (* (expt x 5)
     (cos (* (sqrt x) (+ x 3)))))

The 7th derivative + a SIMPLIFY step takes 11 seconds, while the 8th derivative + SIMPLIFY takes 6 minutes:

(show-time (lambda () (simplify
                        (((expt D 7) g)
                         'x))))
;process time: 11370 (10740 RUN + 630 GC); real time: 12323
(show-time (lambda () (simplify
                        (((expt D 8) g)
                         'x))))
;process time: 332870 (101980 RUN + 230890 GC); real time: 349152

A New Hope

If you bring in EVAL and interleave a simplification step between derivatives, you CAN skip all of the repeated computation and get massive savings. Here are timings for the 7th, 8th, 20th and 50th order derivatives of G:

;; n=7:
;process time: 160 (160 RUN + 0 GC); real time: 286
;; n=8:
;process time: 180 (180 RUN + 0 GC); real time: 191
;; n=20:
;process time: 1420 (1380 RUN + 40 GC); real time: 1666
;; n=50:
;process time: 17750 (17390 RUN + 360 GC); real time: 18333

How to do it? First, write a function that compiles a procedure down into a new procedure that inlines a simplification step:

(define (simplify-compile f)
  (let* ((sym  (generate-uninterned-symbol "x"))
         (body (simplify (f sym))))
    (eval `(lambda (,sym) ,body)
          generic-environment)))

Then write a new version of the D operator that internally compiles and simplifies (D f):

(define Ds
  (make-operator
   (lambda (f)
     (simplify-compile (D f)))
   'Ds))

For symbolic derivatives, this is much faster, as it doesn't repeat any of the work:

(show-time (lambda () (simplify
                        (((expt Ds 50) g)
                         'x))))
;process time: 17750 (17390 RUN + 360 GC); real time: 18333

Here's a page with the result if you want to see it: https://gist.github.com/sritchie/cda11a1c0fe6415e71eaf65c3df506ff
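The same differentiate, simplify, repeat loop can be shown in miniature without scmutils. In this sketch (portable R7RS Scheme; the four-form expression language, deriv, simp, nth-deriv-simplified and compile-expr are all hypothetical stand-ins), each step differentiates an already-simplified expression, so no earlier work is redone. EVAL is only used at the very end to turn the final expression back into a procedure, much as simplify-compile does.

```scheme
(import (scheme base) (scheme cxr) (scheme eval) (scheme write))

;; Toy expressions: numbers, the symbol x, (+ a b) and (* a b).
(define (deriv e)
  (cond ((number? e) 0)
        ((symbol? e) 1)                       ; assumes the only symbol is x
        ((eq? (car e) '+)
         (list '+ (deriv (cadr e)) (deriv (caddr e))))
        ((eq? (car e) '*)                     ; product rule
         (list '+
               (list '* (deriv (cadr e)) (caddr e))
               (list '* (cadr e) (deriv (caddr e)))))
        (else (error "unsupported expression" e))))

;; Toy stand-in for SIMPLIFY: fold constants and apply the identities
;; e+0 = e, e*1 = e and e*0 = 0.
(define (simp e)
  (if (pair? e)
      (let ((op (car e))
            (a (simp (cadr e)))
            (b (simp (caddr e))))
        (cond ((and (number? a) (number? b))
               (if (eq? op '+) (+ a b) (* a b)))
              ((eq? op '+)
               (cond ((eqv? a 0) b)
                     ((eqv? b 0) a)
                     (else (list '+ a b))))
              ((or (eqv? a 0) (eqv? b 0)) 0)
              ((eqv? a 1) b)
              ((eqv? b 1) a)
              (else (list '* a b))))
      e))

;; Differentiate n times, simplifying after every step (the analogue of Ds).
(define (nth-deriv-simplified e n)
  (if (= n 0)
      e
      (nth-deriv-simplified (simp (deriv e)) (- n 1))))

;; Turn an expression in x back into a procedure, as simplify-compile does.
(define (compile-expr body)
  (eval (list 'lambda '(x) body) (environment '(scheme base))))

(define cube '(* x (* x x)))
(display (nth-deriv-simplified cube 3))                      ; 6
(newline)
(display ((compile-expr (nth-deriv-simplified cube 1)) 2))    ; 3 * 2^2 = 12
(newline)
```

Because the loop works on expressions rather than on nested closures, there are no differential tags to multiply out; the trade-off is that f must be evaluable on a symbol in the first place.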

Can we do it without eval?

sritchie commented 2 years ago

@alexgian this is the solution to the problem you've discovered! The potential for a fix is even bigger than we thought. We were looking at the special case of higher-order derivatives on the same argument, but actually (* (partial 0) (partial 1)) is identical to (* (partial 1) (partial 0)); any permutation of the same indices gives the same derivative.
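A quick count makes the size of that saving concrete. For k-th order partials of a function of n variables there are n^k ordered index sequences but only C(n+k-1, k) distinct multisets of indices. The hypothetical helpers below (plain R7RS Scheme) enumerate the sequences, sort each one into a canonical order, and count what remains.

```scheme
(import (scheme base) (scheme write))

;; Every ordered sequence of k partial-derivative indices from 0 .. n-1.
(define (index-sequences n k)
  (if (= k 0)
      '(())
      (let ((shorter (index-sequences n (- k 1))))
        (let loop ((i 0) (acc '()))
          (if (= i n)
              acc
              (loop (+ i 1)
                    (append (map (lambda (s) (cons i s)) shorter) acc)))))))

(define (insert-sorted x lst)
  (cond ((null? lst) (list x))
        ((<= x (car lst)) (cons x lst))
        (else (cons (car lst) (insert-sorted x (cdr lst))))))

;; Sorting the indices identifies (partial 0)(partial 1) with (partial 1)(partial 0).
(define (canonical s)
  (if (null? s) '() (insert-sorted (car s) (canonical (cdr s)))))

;; Number of distinct k-th order partials of a function of n variables.
(define (distinct-partials n k)
  (let loop ((seqs (index-sequences n k)) (seen '()))
    (cond ((null? seqs) (length seen))
          ((member (canonical (car seqs)) seen) (loop (cdr seqs) seen))
          (else (loop (cdr seqs) (cons (canonical (car seqs)) seen))))))

(display (length (index-sequences 2 3)))   ; 8 orderings
(newline)
(display (distinct-partials 2 3))          ; 4 distinct derivatives
(newline)
```

For n = 2 and k = 3 that is 8 orderings against 4 distinct derivatives; for n = 4 and k = 3 it is 64 against 20.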

@qobi notes:

There is symmetry in the higher-order derivatives (d/dxdy=d/dydx). There are still an exponential number even after removing symmetry. That said, there are systems that avoid recomputing and storing redundant symmetric derivatives. See, for example, Mu Wang's papers and PhD thesis.

Mu (together with his advisor Alex Pothen) did this for reverse mode. There were people who did this before for forward mode. I forget who. Barak might remember and be able to give you the reference. Mu's thesis and Mu and Pothen's papers probably also have the reference.

Pearlmutter adds:

Griewank had a list of people who'd worked out how to exploit the symmetries for the forward case. I think he sent up a brain dump for that NIPS workshop...