sritchie opened this issue 2 years ago
Here is the exploration that led Siskind to point us to the paper cited in this issue:
A friend and I were fiddling around with nth-order symbolic derivatives, trying to figure out how to speed things up by interleaving a simplification step between each derivative.
It didn't help much, and the reason became clear when I removed the EXTRACT-DX-PART call from the derivative routine. For an nth-order derivative (n = 3 here):
(let ((f (literal-function 'f)))
  (((expt D 3) f) 'x))
The differential object's term list has an entry for every element of the power set of (IOTA N)... and there is a lot of repeated work! Every k-th derivative, for k <= n, is calculated n-choose-k times, because of course every tag refers to the same function:
[[[] (f x)]
[[123] ((D f) x)]
[[124] ((D f) x)]
[[125] ((D f) x)]
[[123 124] (((expt D 2) f) x)]
[[123 125] (((expt D 2) f) x)]
[[124 125] (((expt D 2) f) x)]
[[123 124 125] (((expt D 3) f) x)]]
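To make the blowup concrete, here is a small sketch in plain Scheme, independent of scmutils. The helpers power-set and terms-per-order are hypothetical names used only for illustration: power-set enumerates every subset of the perturbation tags (one subset per entry in the term list above), and terms-per-order counts how many subsets have each size k. Since a subset of size k carries the k-th derivative, the counts are the binomial coefficients C(n, k), which sum to 2^n.

```scheme
;; Hypothetical helper: every subset of a list of perturbation tags.
;; Each subset corresponds to one entry in the differential's term list.
(define (power-set tags)
  (if (null? tags)
      '(())
      (let ((rest (power-set (cdr tags))))
        (append rest
                (map (lambda (s) (cons (car tags) s)) rest)))))

;; Hypothetical helper: for each order k = 0..n, count the subsets of
;; size k, i.e. how many times the k-th derivative is built.
(define (terms-per-order tags)
  (let* ((n (length tags))
         (counts (make-vector (+ n 1) 0)))
    (for-each (lambda (s)
                (let ((k (length s)))
                  (vector-set! counts k (+ 1 (vector-ref counts k)))))
              (power-set tags))
    (vector->list counts)))
```

With the three tags from the example, (terms-per-order '(123 124 125)) returns (1 3 3 1): the first and second derivatives are each computed three times, for eight terms in total.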
The puzzle is: how do we skip all of the duplicates? I have some thoughts, and some ideas that failed, that we can talk about next time we chat if this is interesting. It would be nice to get an exponential speedup here.
How slow is it in practice?
Given some function like G:
(define (g x)
  (* (expt x 5)
     (cos (* (sqrt x) (+ x 3)))))
The 7th derivative plus a SIMPLIFY step takes about 11 seconds, while the 8th derivative plus SIMPLIFY takes almost 6 minutes, most of it spent in GC:
(show-time (lambda ()
             (simplify (((expt D 7) g) 'x))))
;process time: 11370 (10740 RUN + 630 GC); real time: 12323
(show-time (lambda ()
             (simplify (((expt D 8) g) 'x))))
;process time: 332870 (101980 RUN + 230890 GC); real time: 349152
A New Hope
If you bring in EVAL and interleave a simplification step between each derivative, you CAN skip all of the repeated computation and get some massive savings. Here are timings for the 7th, 8th, 20th and 50th order derivatives of G:
;; n = 7:
;process time: 160 (160 RUN + 0 GC); real time: 286
;; n = 8:
;process time: 180 (180 RUN + 0 GC); real time: 191
;; n = 20:
;process time: 1420 (1380 RUN + 40 GC); real time: 1666
;; n = 50:
;process time: 17750 (17390 RUN + 360 GC); real time: 18333
How to do it? First, write a function that compiles a procedure into a new procedure whose body has already been simplified:
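(define (simplify-compile f)
  ;; Apply f to a fresh uninterned symbol to get a symbolic body,
  ;; then simplify that expression once, up front.
  (let* ((sym (generate-uninterned-symbol "x"))
         (body (simplify (f sym))))
    ;; Build (lambda (sym) body) and evaluate it, producing an ordinary
    ;; procedure whose body is the simplified expression.
    (eval `(lambda (,sym) ,body)
          generic-environment)))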
Then write a new version of the D operator that internally compiles and simplifies (D f):
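(define Ds
  (make-operator
   (lambda (f)
     ;; Differentiate once, then compile the simplified result, so the
     ;; next application of Ds starts from a compact procedure.
     (simplify-compile (D f)))
   'Ds))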
For symbolic derivatives, this is much faster, as it doesn't repeat any of the work:
(show-time (lambda ()
             (simplify (((expt Ds 50) g) 'x))))
;process time: 17750 (17390 RUN + 360 GC); real time: 18333
Here's a page with the result if you want to see it: https://gist.github.com/sritchie/cda11a1c0fe6415e71eaf65c3df506ff
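The idea behind Ds is that each derivative works on the simplified output of the previous one rather than on an ever-growing term list. A toy version of that loop is sketched below in plain Scheme, with no scmutils and no eval. Note that it swaps forward-mode AD for textbook symbolic differentiation over a hypothetical expression format (numbers, the symbol x, binary + and *, sin, cos); the names make-sum, make-product, deriv, and nth-deriv are illustrative only, not scmutils APIs. The smart constructors simplify as each node is built, so every pass starts from an already-simplified expression.

```scheme
;; Toy symbolic differentiation with simplification at every step.
;; Hypothetical expression format: numbers, the symbol x,
;; (+ a b), (* a b), (sin a), (cos a).

;; Smart constructor for sums: fold constants and drop zeros.
(define (make-sum a b)
  (cond ((and (number? a) (number? b)) (+ a b))
        ((eqv? a 0) b)
        ((eqv? b 0) a)
        (else (list '+ a b))))

;; Smart constructor for products: fold constants, drop ones,
;; collapse to zero, and keep numeric coefficients in front.
(define (make-product a b)
  (cond ((and (number? a) (number? b)) (* a b))
        ((or (eqv? a 0) (eqv? b 0)) 0)
        ((eqv? a 1) b)
        ((eqv? b 1) a)
        ;; Move a numeric factor to the front: (* u k) => (* k u).
        ((number? b) (make-product b a))
        ;; Merge nested coefficients: (* k1 (* k2 u)) => (* k1*k2 u).
        ((and (number? a) (pair? b) (eq? (car b) '*) (number? (cadr b)))
         (make-product (* a (cadr b)) (list-ref b 2)))
        (else (list '* a b))))

;; d/dx, building every node with the smart constructors above,
;; so the result comes back already simplified.
(define (deriv e)
  (cond ((number? e) 0)
        ((eq? e 'x) 1)
        ((symbol? e) 0)
        (else
         (let ((a (cadr e)))
           (case (car e)
             ((+) (make-sum (deriv a) (deriv (list-ref e 2))))
             ((*) (let ((b (list-ref e 2)))
                    (make-sum (make-product (deriv a) b)
                              (make-product a (deriv b)))))
             ((sin) (make-product (list 'cos a) (deriv a)))
             ((cos) (make-product (make-product -1 (list 'sin a))
                                  (deriv a)))
             (else (error "unknown operator" (car e))))))))

;; nth derivative: each step differentiates the simplified output
;; of the previous step.
(define (nth-deriv n e)
  (if (= n 0)
      e
      (nth-deriv (- n 1) (deriv e))))
```

For example, (nth-deriv 4 '(sin x)) comes back as (sin x), and (nth-deriv 2 '(* x x)) as 2. The smart constructors here only catch a few identities; a real simplifier (like SIMPLIFY in scmutils) does far more.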
Can we do it without eval?
@alexgian this is the solution to the problem you've discovered! The potential for a fix is even bigger than we thought. We were looking at the special case of higher-order derivatives on the same argument... but actually (* (partial 0) (partial 1)) is identical to (* (partial 1) (partial 0)); any permutation of the same indices gives the same derivative.
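To see how much this symmetry saves, here is a counting sketch in plain Scheme. The helpers sequences, insert-sorted, sort-indices, distinct, and count-partials are hypothetical. It enumerates every ordered sequence of k partial-derivative indices over n variables, then counts what remains once each sequence is sorted, i.e. once permutations of the same indices are treated as equal. That is n^k orderings against C(n+k-1, k) distinct derivatives; the reduced count still grows with n and k.

```scheme
;; All ordered index sequences of length k over variables 0..n-1,
;; e.g. (0 1) and (1 0) for the two orderings of two partials.
(define (sequences n k)
  (if (= k 0)
      '(())
      (let ((shorter (sequences n (- k 1))))
        (let loop ((i 0) (acc '()))
          (if (= i n)
              acc
              (loop (+ i 1)
                    (append (map (lambda (s) (cons i s)) shorter)
                            acc)))))))

;; Insertion sort, so every permutation maps to one canonical key.
(define (insert-sorted x sorted)
  (cond ((null? sorted) (list x))
        ((<= x (car sorted)) (cons x sorted))
        (else (cons (car sorted) (insert-sorted x (cdr sorted))))))

(define (sort-indices s)
  (if (null? s)
      '()
      (insert-sorted (car s) (sort-indices (cdr s)))))

;; Remove duplicate lists (quadratic, fine for a demo).
(define (distinct lst)
  (let loop ((lst lst) (seen '()))
    (cond ((null? lst) (reverse seen))
          ((member (car lst) seen) (loop (cdr lst) seen))
          (else (loop (cdr lst) (cons (car lst) seen))))))

;; Returns (number-of-orderings number-of-distinct-derivatives).
(define (count-partials n k)
  (let ((all (sequences n k)))
    (list (length all)
          (length (distinct (map sort-indices all))))))
```

For instance, (count-partials 2 3) returns (8 4): of the eight orderings of three partials in two variables, only the ones with index multisets 000, 001, 011, and 111 are distinct.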
@qobi notes:
There is symmetry in the higher-order derivatives (d/dxdy=d/dydx). There are still an exponential number even after removing symmetry. That said, there are systems that avoid recomputing and storing redundant symmetric derivatives. See, for example, Mu Wang's papers and PhD thesis.
Mu (together with his advisor Alex Pothen) did this for reverse mode. There were people who did this before for forward mode. I forget who. Barak might remember and be able to give you the reference. Mu's thesis and Mu and Pothen's papers probably also have the reference.
Pearlmutter adds:
Griewank had a list of people who'd worked out how to exploit the symmetries for the forward case. I think he sent up a brain dump for that NIPS workshop...
Thanks @qobi for the reference.
This is the solution to the problem of an exponential number of derivatives being calculated for (((expt D n) f) 'x): 2^n of them, the size of the power set of a set of size n (see the exploration in this thread). The task is to implement the method described in this paper: https://engineering.purdue.edu/~qobi/papers/popl2007a.pdf
And THEN to figure out how to expose this behavior inside the D operator as it exists currently.