shamazmazum / cl-forward-diff

Automatic differentiation for Common Lisp (forward mode)
BSD 2-Clause "Simplified" License

cl-forward-diff

cl-forward-diff is a Common Lisp system which provides automatic differentiation in forward mode.

Since Common Lisp does not provide a way to define new parametric types, all calculations are performed on the type dual, which is an alias for (sb-ext:simd-pack double-float). The following constructor and destructors are provided for convenience:

The package cl-forward-diff provides the following mathematical functions which you must use in the functions you want to differentiate:

These functions operate on the type ext-number, an abbreviation for (or dual real). This means that you cannot differentiate functions of complex numbers.
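The idea behind forward mode can be illustrated with a toy, portable model of dual numbers. This is only a sketch: make-toy-dual, toy+, toy* and toy-ext-number are hypothetical names, and cl-forward-diff itself stores a dual as an SBCL SIMD pack rather than a structure. It shows why operations accept (or dual real): a real constant is lifted to a dual with a zero derivative part, each operation propagates derivatives by the usual rules, and a complex argument is rejected by the type check.

```lisp
;; Hypothetical toy model: NOT cl-forward-diff's representation, which is
;; an SBCL SIMD pack of two double-floats.
(defstruct (toy-dual (:constructor make-toy-dual (value deriv)))
  (value 0d0 :type double-float)   ; f(x)
  (deriv 0d0 :type double-float))  ; f'(x)

;; Counterpart of ext-number: an argument may be a dual or a plain real.
(deftype toy-ext-number () '(or toy-dual real))

(defun lift (x)
  "Return X as a dual; a real constant gets a zero derivative part."
  (check-type x toy-ext-number)
  (if (toy-dual-p x)
      x
      (make-toy-dual (coerce x 'double-float) 0d0)))

(defun toy+ (a b)
  ;; Sum rule: (u + v)' = u' + v'
  (let ((a (lift a)) (b (lift b)))
    (make-toy-dual (+ (toy-dual-value a) (toy-dual-value b))
                   (+ (toy-dual-deriv a) (toy-dual-deriv b)))))

(defun toy* (a b)
  ;; Product rule: (uv)' = u'v + uv'
  (let ((a (lift a)) (b (lift b)))
    (make-toy-dual (* (toy-dual-value a) (toy-dual-value b))
                   (+ (* (toy-dual-deriv a) (toy-dual-value b))
                      (* (toy-dual-value a) (toy-dual-deriv b))))))
```

For f(x) = 3x^2 + 1, evaluating (toy+ (toy* 3 (toy* x x)) 1) with x bound to (make-toy-dual 2d0 1d0) gives a dual with value 13 and derivative 12, that is f(2) and f'(2).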

Inverse trigonometric functions (asin, acos, etc.) are not yet implemented.

These functions are defined differently or behave differently from their counterparts in the cl package:

Function  Difference
sqrt      Signals a run-time error when called with a negative argument.
log       Does not accept the optional base argument. See also sqrt.
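Both differences concern behaviour of the standard functions that does not fit ext-number. The forms below use only the cl package (cl-forward-diff is not involved) and show results that fall outside (or dual real), together with the optional base argument that the cl-forward-diff version of log does not take.

```lisp
;; Standard CL behaviour only; cl-forward-diff is not loaded or used here.
(print (sqrt -4d0))  ; a complex number (about 2i), not a real
(print (log -1d0))   ; a complex number (about pi*i), not a real
(print (log 8 2))    ; logarithm to base 2 via the optional argument, about 3
```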

Functions to be differentiated are best defined in a package that shadows the math functions of the cl package with those from the cl-forward-diff package. For example:

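(defpackage test
  (:use #:cl)
  ;; #. evaluates at read time, so cl-forward-diff must already be loaded.
  ;; It inserts the option that shadows CL's math functions with the ones
  ;; from cl-forward-diff.
  #.(cl-forward-diff:shadowing-import-math)
  (:export #:fn #:fn2 #:fn3))
(in-package :test)

;; fn(x) = 1 + 2(x - 1)^2
(defun fn (x)
  (1+ (* (expt (1- x) 2) 2)))

;; fn2(x, y) = (x - 1)(y + 1); ARGS is a vector holding x and y.
(defun fn2 (args)
  (let ((x (aref args 0))
        (y (aref args 1)))
    (* (1- x) (1+ y))))

;; Sum of c_n * x^n for COEFFS = (c_0 c_1 ...), written with
;; generators from the snakes library.
(defun fn3 (coeffs x)
  (reduce #'+
          (snakes:generator->list
           (snakes:imap
            (lambda (c n)
              (* c (expt x n)))
            (snakes:list->generator coeffs)
            (snakes:icount 0)))))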

You can now calculate the first derivative of fn. Suppose you are in the cl-user package. Type this in the REPL:

CL-USER> (test:fn #d(4d0 1d0))
#<SIMD-PACK 1.9000000000000d+1 1.2000000000000d+1>

The returned pair contains the value of fn and its first derivative at the point 4.0. Remember that all calculations are performed with double-float values and every constant in the definition of fn is coerced to double-float. There is a helper function to calculate the first derivative, ad-univariate. Check this:

CL-USER> (cl-forward-diff:ad-univariate #'test:fn 4)
12.0d0
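Conceptually, a helper like ad-univariate can be built by seeding the derivative part of the argument with 1 (since dx/dx = 1), calling the function, and reading the derivative part back. The sketch below reproduces the 12.0d0 result with a toy, portable model of dual numbers; toy-dual, toy+, toy*, toy-fn and toy-ad-univariate are hypothetical names, not the library's API or implementation.

```lisp
;; Hypothetical toy model of forward mode, NOT cl-forward-diff's code.
(defstruct (toy-dual (:constructor make-toy-dual (value deriv)))
  (value 0d0 :type double-float)
  (deriv 0d0 :type double-float))

(defun lift (x)
  ;; Real constants become duals with a zero derivative part.
  (if (toy-dual-p x) x (make-toy-dual (coerce x 'double-float) 0d0)))

(defun toy+ (a b)
  (let ((a (lift a)) (b (lift b)))
    (make-toy-dual (+ (toy-dual-value a) (toy-dual-value b))
                   (+ (toy-dual-deriv a) (toy-dual-deriv b)))))

(defun toy* (a b)
  (let ((a (lift a)) (b (lift b)))
    (make-toy-dual (* (toy-dual-value a) (toy-dual-value b))
                   (+ (* (toy-dual-deriv a) (toy-dual-value b))
                      (* (toy-dual-value a) (toy-dual-deriv b))))))

(defun toy-ad-univariate (f x)
  "Seed the derivative part with 1 and return the derivative part of (F x)."
  (toy-dual-deriv (funcall f (make-toy-dual (coerce x 'double-float) 1d0))))

;; fn from the example above, 1 + 2(x - 1)^2, spelled with toy operations.
(defun toy-fn (x)
  (let ((d (toy+ x -1)))
    (toy+ 1 (toy* (toy* d d) 2))))
```

Here (toy-ad-univariate #'toy-fn 4) returns 12.0d0, matching the library's result.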

You can calculate the gradient of a function of two or more variables, like test:fn2, using ad-multivariate:

CL-USER> (cl-forward-diff:ad-multivariate
           #'test:fn2 (cl-forward-diff:to-doubles '(2 4)))
#(5.0d0 1.0d0)
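With scalar dual numbers, a gradient can be obtained by one forward pass per variable, each time seeding a different coordinate with derivative 1 and the others with 0. The sketch below does that with a toy, portable dual type and reproduces #(5.0d0 1.0d0) for fn2 at (2, 4). All names in it are hypothetical, and whether ad-multivariate works exactly this way is an assumption.

```lisp
;; Hypothetical toy model of forward mode, NOT cl-forward-diff's code.
(defstruct (toy-dual (:constructor make-toy-dual (value deriv)))
  (value 0d0 :type double-float)
  (deriv 0d0 :type double-float))

(defun lift (x)
  (if (toy-dual-p x) x (make-toy-dual (coerce x 'double-float) 0d0)))

(defun toy+ (a b)
  (let ((a (lift a)) (b (lift b)))
    (make-toy-dual (+ (toy-dual-value a) (toy-dual-value b))
                   (+ (toy-dual-deriv a) (toy-dual-deriv b)))))

(defun toy* (a b)
  (let ((a (lift a)) (b (lift b)))
    (make-toy-dual (* (toy-dual-value a) (toy-dual-value b))
                   (+ (* (toy-dual-deriv a) (toy-dual-value b))
                      (* (toy-dual-value a) (toy-dual-deriv b))))))

(defun toy-ad-multivariate (f point)
  "Gradient of F at POINT, a vector of reals. One forward pass per
coordinate: coordinate I is seeded with derivative 1, the others with 0."
  (let* ((n (length point))
         (grad (make-array n :element-type 'double-float)))
    (dotimes (i n grad)
      (let ((args (make-array n)))
        (dotimes (j n)
          (setf (aref args j)
                (make-toy-dual (coerce (aref point j) 'double-float)
                               (if (= i j) 1d0 0d0))))
        (setf (aref grad i) (toy-dual-deriv (funcall f args)))))))

;; fn2 from the example above, (x - 1)(y + 1), written with toy operations.
(defun toy-fn2 (args)
  (toy* (toy+ (aref args 0) -1) (toy+ (aref args 1) 1)))
```

Calling (toy-ad-multivariate #'toy-fn2 #(2 4)) returns #(5.0d0 1.0d0). The cost grows with the number of variables, which is the usual trade-off of forward mode.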

More complicated functions can also be differentiated:

CL-USER> (cl-forward-diff:ad-univariate (alexandria:curry #'test:fn3 '(1 2 1)) 5)
12.0d0
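As a check of this result: with coefficients (1 2 1), fn3 evaluates the polynomial below (coefficient c_n multiplies x^n, since the counter starts at 0), and its derivative at 5 is indeed 12.

```latex
p(x) = \sum_{n=0}^{2} c_n x^n = 1 + 2x + x^2, \qquad
p'(x) = 2 + 2x, \qquad
p'(5) = 2 + 2 \cdot 5 = 12
```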

How to define piecewise functions?

Since differentiable functions must operate on dual numbers, and dual numbers are not ordered, you may ask how to define piecewise functions like the following one:

(defun foo (x)
  (if (> x 2) x (* 3 x)))

It's easy: just compare the real part of x in the conditional form.

(mapcar
 (alexandria:curry
  #'cl-forward-diff:ad-univariate
  (lambda (x)
    (if (> (cl-forward-diff:dual-realpart x) 2)
        x (cl-forward-diff:* 3 x))))
 '(1 4))

This evaluates to (3.0d0 1.0d0), the derivatives at 1 and at 4.
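The same branching rule can be tried with a toy, portable dual type (toy-dual, toy*, toy-foo and toy-ad-univariate are hypothetical names, not part of cl-forward-diff). Only the value part takes part in the comparison; the derivative part just follows whichever branch is taken. A consequence worth knowing: at the breakpoint x = 2, where the mathematical derivative does not exist, the result is the derivative of the branch that is evaluated (here 3).

```lisp
;; Hypothetical toy model of forward mode, NOT cl-forward-diff's code.
(defstruct (toy-dual (:constructor make-toy-dual (value deriv)))
  (value 0d0 :type double-float)
  (deriv 0d0 :type double-float))

(defun lift (x)
  (if (toy-dual-p x) x (make-toy-dual (coerce x 'double-float) 0d0)))

(defun toy* (a b)
  (let ((a (lift a)) (b (lift b)))
    (make-toy-dual (* (toy-dual-value a) (toy-dual-value b))
                   (+ (* (toy-dual-deriv a) (toy-dual-value b))
                      (* (toy-dual-value a) (toy-dual-deriv b))))))

(defun toy-ad-univariate (f x)
  (toy-dual-deriv (funcall f (make-toy-dual (coerce x 'double-float) 1d0))))

(defun toy-foo (x)
  ;; Branch on the value part only; the derivative part is carried along.
  (if (> (toy-dual-value x) 2)
      x
      (toy* 3 x)))
```

Mapping toy-ad-univariate with #'toy-foo over '(1 4) gives (3.0d0 1.0d0), as in the library example.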

Optimization

If you want to write performant code, follow these rules:

  1. Use (declare (optimize (speed 3))) in functions you want to optimize.
  2. Numerical arguments to differentiable functions must be of type dual. When using ad-multivariate, the function must accept a simple array of duals.
  3. Differentiable functions must return a single value of type dual. You may also want to use top-level declarations as hints to the compiler, like (serapeum:-> fn (dual) (values dual &optional)).
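The rules above are ordinary Common Lisp declarations. The sketch shows the standard declaim ftype form, which is what serapeum:-> is shorthand for, together with an optimize declaration. So that it runs without cl-forward-diff or serapeum, it uses a hypothetical vdual type (a two-element double-float vector holding value and derivative) in place of the library's dual; vdual and vdual-square are illustrative names only.

```lisp
;; Hypothetical stand-in for the library's dual type: #(value derivative).
(deftype vdual () '(simple-array double-float (2)))

;; Standard equivalent of (serapeum:-> vdual-square (vdual) (values vdual &optional)):
;; one argument of type vdual, exactly one vdual returned.
(declaim (ftype (function (vdual) (values vdual &optional)) vdual-square))

(defun vdual-square (x)
  (declare (optimize (speed 3)))
  (let ((v (aref x 0))
        (d (aref x 1))
        (r (make-array 2 :element-type 'double-float)))
    ;; (x^2)' = 2 x x'
    (setf (aref r 0) (* v v)
          (aref r 1) (* 2d0 v d))
    r))
```

With cl-forward-diff loaded, the same pattern would name dual instead of vdual, as in rule 3.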

Global side effects

The reader macro #D is added to the readtable. It lets you write dual number literals like #D(3d0 1d0), which is a shortcut for (make-dual 3d0 1d0).
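A dispatch macro like #D can be written with the standard set-dispatch-macro-character. The sketch below expands #D(v d) into a constructor call, matching the description above, and installs it only in a private copy of the standard readtable, so the global readtable is left untouched. toy-make-dual, read-toy-dual and read-with-toy-dual-syntax are hypothetical names, and this is not the library's own definition of #D.

```lisp
;; Hypothetical constructor standing in for cl-forward-diff's make-dual.
(defun toy-make-dual (value deriv)
  (make-array 2 :element-type 'double-float
                :initial-contents (list (coerce value 'double-float)
                                        (coerce deriv 'double-float))))

(defun read-toy-dual (stream subchar arg)
  "Reader function for #D(v d): read the list and expand to a constructor call."
  (declare (ignore subchar arg))
  (destructuring-bind (value deriv) (read stream t nil t)
    `(toy-make-dual ,value ,deriv)))

(defun read-with-toy-dual-syntax (string)
  "Read STRING with #D installed in a private copy of the standard readtable."
  (let ((*readtable* (copy-readtable nil)))
    (set-dispatch-macro-character #\# #\D #'read-toy-dual)
    (read-from-string string)))
```

Reading "#D(3d0 1d0)" this way yields the form (toy-make-dual 3.0d0 1.0d0), which evaluates to a two-element vector.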

Discussion (in the form of FAQ)