jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Inverse Accumulation Mode #12494

Open NeilGirdhar opened 2 years ago

NeilGirdhar commented 2 years ago

Inverted Jacobian products are useful in a variety of algorithms, such as the efficient implementation of Newton's method with regularization. However, JAX currently provides only non-inverted Jacobian products (jvp and vjp).

A recent paper suggests that inverted Jacobian products can be implemented efficiently in an automatic differentiation library like JAX:

Siskind, Jeffrey Mark. "Automatic Differentiation: Inverse Accumulation Mode." (2019).

The interface for the inverted Jacobians could be something like:

Using these, we could also produce inverse_jacfwd and inverse_jacrev, one of which could be mapped to inverse_jacobian.
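As a point of reference for what such functions would compute, here is a minimal sketch that simply materializes the Jacobian and solves densely. The names `ijvp`, `ivjp`, `inverse_jacfwd`, and `inverse_jacrev` are hypothetical (they are not part of JAX), and each dense call costs O(n^3), which is exactly the cost the paper's approach aims to avoid.

```python
import jax
import jax.numpy as jnp


def ijvp(f, x, v):
    """Hypothetical inverse JVP: returns (f(x), u) such that J u = v, J = df(x)."""
    jac = jax.jacfwd(f)(x)  # dense n x n Jacobian; the solve below is O(n^3)
    return f(x), jnp.linalg.solve(jac, v)


def ivjp(f, x, w):
    """Hypothetical inverse VJP: returns (f(x), c) such that c J = w."""
    jac = jax.jacfwd(f)(x)
    return f(x), jnp.linalg.solve(jac.T, w)  # c J = w  <=>  J^T c = w


def inverse_jacfwd(f):
    """J^{-1} column by column: one ijvp per standard basis vector."""
    def jac_inv(x):
        basis = jnp.eye(x.shape[0], dtype=x.dtype)
        columns = jax.vmap(lambda e: ijvp(f, x, e)[1])(basis)  # row k holds J^{-1} e_k
        return columns.T
    return jac_inv


def inverse_jacrev(f):
    """J^{-1} row by row: one ivjp per standard basis vector."""
    def jac_inv(x):
        basis = jnp.eye(x.shape[0], dtype=x.dtype)
        return jax.vmap(lambda e: ivjp(f, x, e)[1])(basis)  # row k is e_k^T J^{-1}
    return jac_inv


# Small example function R^2 -> R^2 with an invertible Jacobian at x.
def f(x):
    return jnp.array([x[0] + 2.0 * x[1] ** 2, jnp.sin(x[0]) + 3.0 * x[1]])


x = jnp.array([0.3, -0.5])
```

As with jacfwd and jacrev, materializing the full inverse this way takes n inverse products, one per basis vector.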

Has anyone on the JAX team looked into this?

patrick-kidger commented 2 years ago

It looks like the main approach proposed (top of page 2) is based around scalar operations. If it is applied to the forward computation of a matrix-vector product x -> Ax, I think the proposed (inverse-Jacobian)-vector product, i.e. b -> A^{-1} b, is arithmetically equivalent to performing the linear solve via Gaussian elimination.

I'm not sure I see an application for this in practice: for what computational structures would this beat a traditional linear solve? (Dense Jacobians: I think a standard LU or CG solver would probably be best. Sparse Jacobians: a sparse linear solver probably performs more or less the same operations as those proposed.)

NeilGirdhar commented 2 years ago

(Dense Jacobians: I think a standard LU or CG solver would probably be best.

I read the paper as saying that their algorithm has time complexity similar to that of ordinary (non-inverted) forward and reverse mode, and is therefore much faster than solving for the inverse Jacobian.

patrick-kidger commented 2 years ago

Ah, right: the inversion they're proposing uses the intermediate values saved in additional scratch space. This brings the computational cost of the ijvp down from O(n^3) (Gaussian elimination) to O(n^2) (the cost of a matrix-vector product, i.e. of the jvp).

That is, we can't just consider a matrix-vector product x -> Ax (whose ijvp has the traditional O(n^3) asymptotics). We have to consider (x, A) -> (Ax, additional stuff) to get O(n^2) asymptotics for the ijvp. In this case I think (x, A) -> (Ax, (A_{ij} x_j)_{i,j}) would suffice.

I still don't think this helps in most scenarios, e.g. in Newton's method: typically the values that would be held in the additional scratch space are unknown, so efficient inversion isn't possible.

jakevdp commented 2 years ago

One note: @apaszke added some experimental tools around these same ideas, but removed them earlier this year (https://github.com/google/jax/commit/902fc0c3d2b3ec9b6034c66074984386ec35606f) because they weren't being used.

PhilipVinc commented 2 years ago

I think this might be very useful for computing the inverse of the Fisher Information Matrix / Geometric Tensor, which is needed to perform natural gradient descent.

NeilGirdhar commented 2 years ago

I think this might be very useful to compute the inverse of the Fisher Information Matrix

That's exactly what I want it for. I'm not sure it saves any time, though, since you need to call it n times to build the full inverse unless you're approximating.
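For comparison, a natural-gradient step only needs one product F^{-1} g, so the usual way to avoid n calls is to never form the inverse and instead solve F d = g matrix-free with conjugate gradients. This is a minimal sketch of that traditional alternative, not of inverse accumulation mode. It assumes a squared-error (unit-variance Gaussian) likelihood, for which F = J^T J / m, plus a small damping term; `predict` and `natural_gradient_direction` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)


def predict(params, inputs):
    """Hypothetical tiny model: a tanh unit with scale and bias, params shape (4,)."""
    return jnp.tanh(inputs @ params[:2]) * params[2] + params[3]


def natural_gradient_direction(params, inputs, grad, damping=1e-3):
    """Return d with (J^T J / m + damping * I) d = grad, without forming J or F."""
    m = inputs.shape[0]
    model = lambda p: predict(p, inputs)
    _, vjp_fn = jax.vjp(model, params)

    def fisher_vp(t):
        _, jt = jax.jvp(model, (params,), (t,))  # J t via one forward-mode pass
        (jtjt,) = vjp_fn(jt)                     # J^T (J t) via one reverse-mode pass
        return jtjt / m + damping * t

    direction, _ = cg(fisher_vp, grad, tol=1e-12, maxiter=100)
    return direction


inputs = jnp.array([[0.1, 0.2], [0.5, -0.3], [-0.7, 0.4], [0.9, 0.8], [0.2, -0.6]])
params = jnp.array([0.5, -1.0, 2.0, 0.1])
grad = jnp.array([1.0, -0.5, 0.25, 2.0])
d = natural_gradient_direction(params, inputs, grad)
```

Each CG iteration costs one jvp and one vjp, so a single F^{-1} g product never requires materializing F or its inverse.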

mattjj commented 2 years ago

Thanks to the inverse function theorem, another way to compute the same thing is to compose jvp, vjp, jacfwd, or jacrev with the oryx.core.inverse transformation. That is, for an invertible f : R^n \to R^n and y = f(x), we have ∂(f^{-1})(y) == inv(∂f(x)) pointwise, where inv denotes the dense matrix inverse. I'd be interested to see whether that ends up generating the same computation as the one in that paper (which, I admit, I haven't read beyond the abstract yet!).

I think jacfwd-of-oryx.core.inverse-of-f would in general be a fairly different computation than jnp.linalg.inv-of-jacfwd-of-f, because the former would exploit sparsity structure represented in the program and its dataflow, whereas the latter would just operate on a dense matrix.
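The identity can be checked numerically with a function whose inverse is easy to write by hand. This is a minimal sketch assuming f(x) = exp(L x) elementwise, with a hypothetical invertible matrix L. The inverse is hand-written rather than produced by oryx.core.inverse, so the sketch illustrates the identity only, not how oryx derives inverses.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical invertible (lower-triangular) mixing matrix.
L = jnp.array([[2.0, 0.0, 0.0], [0.5, 1.0, 0.0], [-1.0, 0.3, 1.5]])


def f(x):
    return jnp.exp(L @ x)


def f_inv(y):
    # Undo the elementwise exp, then undo the linear map.
    return jnp.linalg.solve(L, jnp.log(y))


x = jnp.array([0.2, -0.1, 0.4])
y = f(x)

# Differentiate the inverse program directly at y = f(x) ...
jac_of_inverse = jax.jacfwd(f_inv)(y)
# ... versus densely inverting the Jacobian of f at x.
inverse_of_jac = jnp.linalg.inv(jax.jacfwd(f)(x))
```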

qobi commented 2 years ago

Indeed, the advantage of our approach is that it is compositional and thus has running time proportional to that of the primal. But like ordinary reverse mode, it requires a tape whose size is proportional to the running time.

There is no need to do Gaussian elimination. The only inversion is that of the scalar a in (5a right), because the steps in (4) involve only unary and binary operations.

The catch is that each step must be what we call "equasive": it must have the same number of inputs and outputs. (We call a step that has more outputs than inputs "expansive" and a step that has fewer outputs than inputs "contractive".) Equasive steps have square Jacobians; expansive and contractive steps do not, and non-square Jacobians are not invertible.
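The stepwise equasive case can be sketched in a few lines of JAX. This is our own simplified reformulation, not a translation of the R6RS-AD code later in this comment. The program state is a fixed-length vector, and each (hypothetical) step overwrites one slot, s[i] <- op(s[i], s[j]). Every step therefore maps R^n to R^n, and its Jacobian is the identity except in row i, which holds a = ∂op/∂s_i and b = ∂op/∂s_j. Inverting a step's tangent map needs only a division by the scalar a, so, given the recorded tape of (i, j, a, b), both inverse products cost time proportional to the number of steps. The per-step updates loosely mirror the Scheme code: divide by the factor of the one input being replaced, and subtract the other input's contribution.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical step format (op, i, j): overwrite slot i with op(s[i], s[j]).
# j must differ from i; it is None for unary ops.
STEPS = [("mul", 0, 1), ("sin", 1, None), ("add", 2, 0), ("mul", 1, 2), ("add", 0, 2)]


def apply_op(op, si, sj):
    if op == "add":
        return si + sj
    if op == "mul":
        return si * sj
    if op == "sin":
        return jnp.sin(si)
    raise ValueError(op)


def partials(op, si, sj):
    """(a, b) = (d op / d s_i, d op / d s_j) at the current primal values."""
    if op == "add":
        return 1.0, 1.0
    if op == "mul":
        return sj, si
    if op == "sin":
        return jnp.cos(si), 0.0
    raise ValueError(op)


def run(x, steps=STEPS):
    """Primal pass that also records the tape of (i, j, a, b)."""
    s, tape = x, []
    for op, i, j in steps:
        sj = s[j] if j is not None else 0.0
        tape.append((i, j, *partials(op, s[i], sj)))
        s = s.at[i].set(apply_op(op, s[i], sj))
    return s, tape


def inverse_jvp(tape, v):
    """Solve J u = v: undo each step's tangent map, last step first."""
    t = v
    for i, j, a, b in reversed(tape):
        tj = t[j] if j is not None else 0.0  # slot j is unchanged by the step
        t = t.at[i].set((t[i] - b * tj) / a)
    return t


def inverse_vjp(tape, w):
    """Solve c J = w: apply each step's inverse on the right, first step first."""
    c = w
    for i, j, a, b in tape:
        ci = c[i]
        c = c.at[i].set(ci / a)
        if j is not None:
            c = c.at[j].add(-b * ci / a)
    return c


x = jnp.array([0.7, 1.3, -0.4])
y, tape = run(x)
v = jnp.array([1.0, -2.0, 0.5])
```

If any recorded a is zero, that step's Jacobian is singular and the division fails, which is the scalar analogue of a non-invertible Jacobian.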

It is possible that the overall computation is equasive but the individual steps are not. If the dimension of the intermediate state is ever smaller than the input/output dimension, then the Jacobian is not invertible. But it is possible that, as the computation progresses, the dimension increases and then decreases, perhaps more than once, never going below the input/output dimension. It is further possible that the dimension varies over the course of the computation (always staying at least as large as the input/output dimension) but at one or more intermediate points returns to the input/output dimension. If so, the computation can be split at those points into a sequence of equasive chunks that we call "lumps", and the method can be applied to that sequence of lumps. The lumps, however, are no longer steps consisting of unary and binary operations, so instead of (5a) you have (5b), where in (5b right) you have to invert A. The saving grace is that the dimension of A is likely to be (much) smaller than the input/output dimension of the whole problem.
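At the granularity of lumps, the same last-to-first inversion applies, with a small k x k solve against A in place of the scalar division, where k is the number of state slots a lump overwrites. This minimal sketch assumes the program has already been split by hand into hypothetical lumps (`lump_a`, `lump_b`), since automatic lumpification is the hard part. Each lump reads the whole state and may temporarily expand it, but it writes back only k slots.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

W_UP = jnp.array([[1.0, 0.5, 0.0, 0.2], [-0.3, 1.2, 0.4, 0.0], [0.8, -0.7, 0.1, 0.5]])
W_DOWN = jnp.array([[1.0, 0.2, -0.4], [0.1, 0.9, 0.6]])


def lump_a(s):
    """Hypothetical lump: the live state grows to 4 + 3 values, then new slots 0, 1."""
    h = jnp.tanh(W_UP @ s)
    return s[:2] + W_DOWN @ h


def lump_b(s):
    """Hypothetical lump: new values for slots 2, 3."""
    return jnp.array([s[2] * jnp.exp(s[0]), s[3] + jnp.sin(s[2]) * s[1]])


# Each lump: (function returning new slot values, slots it overwrites).
LUMPS = [(lump_a, slice(0, 2)), (lump_b, slice(2, 4)), (lump_a, slice(0, 2))]


def program(x):
    s = x
    for fn, idx in LUMPS:
        s = s.at[idx].set(fn(s))
    return s


def lumpwise_inverse_jvp(x, v):
    """Solve J u = v lump by lump; each lump needs only a k x k solve with A."""
    records, s = [], x
    for fn, idx in LUMPS:
        records.append((idx, jax.jacfwd(fn)(s)))  # (k, n) Jacobian of the new slot values
        s = s.at[idx].set(fn(s))
    t = v
    for idx, jac in reversed(records):
        rest = t.at[idx].set(0.0)                 # tangents of the untouched slots
        rhs = t[idx] - jac @ rest                 # t'_idx - B t_rest
        t = t.at[idx].set(jnp.linalg.solve(jac[:, idx], rhs))  # A is jac[:, idx]
    return t


x = jnp.array([0.3, -0.2, 0.5, 0.1])
v = jnp.array([1.0, -2.0, 0.5, 3.0])
```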

We implemented the stepwise equasive variant in a variant of R6RS-AD.

https://github.com/qobi/R6RS-AD

(Note that JAX is based on HIPS Autograd, which is based on R6RS-AD. R6RS-AD predates HIPS Autograd by about 7 years and predates JAX by about a decade. R6RS-AD was used in D. Wingate, N. Goodman, A. Stuhlmüller, and J. M. Siskind, "Nonstandard Interpretations of Probabilistic Programs for Efficient Inference," NIPS 2011, Granada, Spain, December 12-15, 2011, http://engineering.purdue.edu/~qobi/papers/nips2011.pdf.)

The implementation is straightforward and enclosed. We worked on methods to automatically divide an arbitrary computation graph into lumps (what we call "lumpification"), but that work is not complete. It is complicated because the dimension of the intermediate state depends on how you schedule the operations, so the possible lumpifications depend on scheduling. It appears to be NP-hard to find the schedule that minimizes the dimension of the A of the maximal lump.

The link to Stack Exchange discusses the Moore-Penrose inverse of non-square Jacobians. We spent some time investigating this, as well as a variety of pseudoinverses other than the Moore-Penrose pseudoinverse. We are unaware of any pseudoinverse that is compositional, and compositionality is required to make (3) work. There might be one that is compositional (and useful) that we are unaware of. It might also be the case that a product of Moore-Penrose (or other) pseudoinverses, while not preserving the properties of that pseudoinverse, is still useful. We never got very far along this line of investigation.

#!r6rs

(library
  (tape-AD)
 (export (rename (d+ +))
     (rename (d- -))
     (rename (d* *))
     (rename (d/ /))
     (rename (dsqrt sqrt))
     (rename (dexp exp))
     (rename (dlog log))
     (rename (dexpt expt))
     (rename (dsin sin))
     (rename (dcos cos))
     (rename (datan atan))
     (rename (d= =))
     (rename (d< <))
     (rename (d> >))
     (rename (d<= <=))
     (rename (d>= >=))
     (rename (dzero? zero?))
     (rename (dpositive? positive?))
     (rename (dnegative? negative?))
     (rename (dreal? real?))
     write-real
     j*
     *j
     j*^-1
     *j^-1)
 (import (rnrs))

 (define-record-type tape
  (fields primal
      factors
      tapes
      (mutable fanout)
      (mutable co/tangent)))

 (define (new-tape primal factors tapes)
  (make-tape primal factors tapes 0 0))

 (define (tapify x) (new-tape x '() '()))

 (define (lift-real->real f df/dx)
  (letrec ((self (lambda (x)
          (if (tape? x)
              (new-tape (self (tape-primal x))
                (list (df/dx (tape-primal x)))
                (list x))
              (f x)))))
   self))

 (define (lift-real*real->real f df/dx1 df/dx2)
  (letrec ((self
        (lambda (x1 x2)
         (if (tape? x1)
         (if (tape? x2)
             (new-tape (self (tape-primal x1) (tape-primal x2))
                   (list (df/dx1 (tape-primal x1) (tape-primal x2))
                     (df/dx2 (tape-primal x1) (tape-primal x2)))
                   (list x1 x2))
             (new-tape (self (tape-primal x1) x2)
                   (list (df/dx1 (tape-primal x1) x2))
                   (list x1)))
         (if (tape? x2)
             (new-tape (self x1 (tape-primal x2))
                   (list (df/dx2 x1 (tape-primal x2)))
                   (list x2))
             (f x1 x2))))))
   self))

 (define first car)

 (define second cadr)

 (define rest cdr)

 (define (fold f l)
  (let loop ((l (cdr l)) (c (first l)))
   (if (null? l) c (loop (rest l) (f c (first l))))))

 (define (count-if p l)
  (let loop ((l l) (c 0))
   (cond ((null? l) c)
     ((p (first l)) (loop (rest l) (+ c 1)))
     (else (loop (rest l) c)))))

 (define (map-reduce g i f l . ls)
  (if (null? l)
      i
      (apply map-reduce
         g
         (g i (apply f (first l) (map first ls)))
         f
         (rest l)
         (map rest ls))))

 (define (list-remove-ith l i)
  (if (zero? i) (rest l) (cons (first l) (list-remove-ith (rest l) (- i 1)))))

 (define (position-if p l)
  (let loop ((l l) (i 0))
   (cond ((null? l) #f)
     ((p (first l)) i)
     (else (loop (rest l) (+ i 1))))))

 (define (lift-real^n->real f df/dx1 df/dx2)
  (lambda xs
   (if (null? xs) (f) (fold (lift-real*real->real f df/dx1 df/dx2) xs))))

 (define (lift-real^n+1->real f df/dx df/dx1 df/dx2)
  (lambda xs
   (cond ((null? xs) (f))
     ((null? (rest xs)) ((lift-real->real f df/dx) (first xs)))
     (else (fold (lift-real*real->real f df/dx1 df/dx2) xs)))))

 (define (primal* x) (if (tape? x) (primal* (tape-primal x)) x))

 (define (lift-real^n->boolean f) (lambda xs (apply f (map primal* xs))))

 (define d+ (lift-real^n->real + (lambda (x1 x2) 1) (lambda (x1 x2) 1)))

 (define d- (lift-real^n+1->real
         - (lambda (x) -1) (lambda (x1 x2) 1) (lambda (x1 x2) -1)))

 (define d* (lift-real^n->real * (lambda (x1 x2) x2) (lambda (x1 x2) x1)))

 (define d/ (lift-real^n+1->real
         /
         (lambda (x) (d- (d/ (d* x x))))
         (lambda (x1 x2) (d/ x2))
         (lambda (x1 x2) (d- (d/ x1 (d* x2 x2))))))

 (define dsqrt (lift-real->real sqrt (lambda (x) (d/ (d* 2 (dsqrt x))))))

 (define dexp (lift-real->real exp (lambda (x) (dexp x))))

 (define dlog (lift-real->real log (lambda (x) (d/ x))))

 (define dexpt
  (lift-real*real->real expt
            (lambda (x1 x2) (d* x2 (dexpt x1 (d- x2 1))))
            (lambda (x1 x2) (d* (dlog x1) (dexpt x1 x2)))))

 (define dsin (lift-real->real sin (lambda (x) (dcos x))))

 (define dcos (lift-real->real cos (lambda (x) (d- (dsin x)))))

 (define (datan . xs)
  (cond ((null? xs) (apply atan xs))
    ((null? (rest xs)) (datan (first xs) 1))
    ((null? (rest (rest xs)))
     ((lift-real*real->real
       atan
       (lambda (x1 x2) (d/ x2 (d+ (d* x1 x1) (d* x2 x2))))
       (lambda (x1 x2) (d/ (d- x1) (d+ (d* x1 x1) (d* x2 x2)))))
      (first xs)
      (second xs)))
    (else (apply atan xs))))

 (define d= (lift-real^n->boolean =))

 (define d< (lift-real^n->boolean <))

 (define d> (lift-real^n->boolean >))

 (define d<= (lift-real^n->boolean <=))

 (define d>= (lift-real^n->boolean >=))

 (define dzero? (lift-real^n->boolean zero?))

 (define dpositive? (lift-real^n->boolean positive?))

 (define dnegative? (lift-real^n->boolean negative?))

 (define dreal? (lift-real^n->boolean real?))

 (define (write-real x)
  (cond ((tape? x) (write-real (tape-primal x)) x)
    (else (write x) (newline) x)))

 (define (determine-fanout! tape)
  (tape-fanout-set! tape (+ (tape-fanout tape) 1))
  (when (= (tape-fanout tape) 1)
   (for-each determine-fanout! (tape-tapes tape))))

 (define (initialize-co/tangent! tape)
  (tape-co/tangent-set! tape 0)
  (tape-fanout-set! tape (- (tape-fanout tape) 1))
  (when (zero? (tape-fanout tape))
   (for-each initialize-co/tangent! (tape-tapes tape))))

 (define (forward-accumulation-sweep tape)
  (if (null? (tape-tapes tape))
      (tape-co/tangent tape)
      (map-reduce d+
          0
          (lambda (factor tape)
           (d* factor (forward-accumulation-sweep tape)))
          (tape-factors tape)
          (tape-tapes tape))))

 (define (reverse-accumulation-sweep! tape)
  (when (zero? (tape-fanout tape))
   (for-each (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1)))
         (tape-tapes tape))
   (let ((cotangent (tape-co/tangent tape)))
    (for-each (lambda (factor tape)
           (tape-co/tangent-set!
        tape (d+ (tape-co/tangent tape) (d* cotangent factor))))
          (tape-factors tape)
          (tape-tapes tape)))
   (for-each reverse-accumulation-sweep! (tape-tapes tape))))

 (define (stepwise-equasive-forward-inverse-accumulation-sweep tape)
  (cond
   ((null? (tape-tapes tape)) (tape-co/tangent tape))
   (else
    (for-each (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1)))
          (tape-tapes tape))
    (unless (= (count-if (lambda (tape) (zero? (tape-fanout tape)))
             (tape-tapes tape))
           1)
     (error #f "Not equasive"))
    (let* ((i (position-if (lambda (tape) (zero? (tape-fanout tape)))
               (tape-tapes tape)))
       (factor-i (list-ref (tape-factors tape) i))
       (tape-i (list-ref (tape-tapes tape) i))
       (cotangent-i
        (stepwise-equasive-forward-inverse-accumulation-sweep tape-i)))
     ;;\needswork: commutativity
     (d- (d/ cotangent-i factor-i)
    (map-reduce
     d+
     0
     (lambda (factor tape)
      ;;\needswork: commutativity
      (d/ (d* (stepwise-equasive-forward-inverse-accumulation-sweep tape)
          factor)
          factor-i))
     (list-remove-ith (tape-factors tape) i)
     (list-remove-ith (tape-tapes tape) i)))))))

 (define (stepwise-equasive-reverse-inverse-accumulation-sweep! tape)
  (unless (null? (tape-tapes tape))
   (when (zero? (tape-fanout tape))
    (for-each (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1)))
          (tape-tapes tape))
    (unless (= (count-if (lambda (tape) (zero? (tape-fanout tape)))
             (tape-tapes tape))
           1)
     (error #f "Not equasive"))
    (let* ((i (position-if (lambda (tape) (zero? (tape-fanout tape)))
               (tape-tapes tape)))
       (factor-i (list-ref (tape-factors tape) i))
       (tape-i (list-ref (tape-tapes tape) i))
       (tangent-i (tape-co/tangent tape)))
     (for-each
      (lambda (factor tape)
       (tape-co/tangent-set!
    tape
    ;;\needswork: commutativity
    (d- (tape-co/tangent tape) (d/ (d* tangent-i factor) factor-i))))
      (list-remove-ith (tape-factors tape) i)
      (list-remove-ith (tape-tapes tape) i))
     ;;\needswork: commutativity
     (tape-co/tangent-set! tape-i (d/ tangent-i factor-i))
     (for-each stepwise-equasive-reverse-inverse-accumulation-sweep!
           (tape-tapes tape))))))

 (define (map-walk1 f x)
  ;;\needswork: Not safe for space.
  (cond ((eq? x #t) #t)
    ((eq? x #f) #f)
    ((null? x) '())
    ((char? x) x)
    ((string? x) x)
    ((dreal? x) (f x))
    ((pair? x) (cons (map-walk1 f (car x)) (map-walk1 f (cdr x))))
    ;;\needswork: vectors
    (else (error #f "Not walkable"))))

 (define (map-walk2 f x x-prime)
  ;;\needswork: Not safe for space.
  (cond ((and (eq? x #t) (eq? x-prime #t)) #t)
    ((and (eq? x #f) (eq? x-prime #f)) #f)
    ((and (null? x) (null? x-prime)) '())
    ((and (char? x) (char? x-prime) (char=? x x-prime)) x)
    ((and (string? x) (string? x-prime) (string=? x x-prime)) x)
    ((and (dreal? x) (dreal? x-prime)) (f x x-prime))
    ((and (pair? x) (pair? x-prime))
     (cons (map-walk2 f (car x) (car x-prime))
           (map-walk2 f (cdr x) (cdr x-prime))))
    ;;\needswork: vectors
    (else (error #f "Values don't conform: ~s ~s" x x-prime))))

 (define (for-each-walk1! f x)
  ;;\needswork: Not safe for space.
  (cond ((eq? x #t) #f)
    ((eq? x #f) #f)
    ((null? x) #f)
    ((char? x) #f)
    ((string? x) #f)
    ((dreal? x) (f x))
    ((pair? x) (for-each-walk1! f (car x)) (for-each-walk1! f (cdr x)))
    ;;\needswork: vectors
    (else (error #f "Not walkable"))))

 (define (for-each-walk2! f x x-prime)
  ;;\needswork: Not safe for space.
  (cond ((and (eq? x #t) (eq? x-prime #t)) #f)
    ((and (eq? x #f) (eq? x-prime #f)) #f)
    ((and (null? x) (null? x-prime)) #f)
    ((and (char? x) (char? x-prime) (char=? x x-prime)) #f)
    ((and (string? x) (string? x-prime) (string=? x x-prime)) #f)
    ((and (dreal? x) (dreal? x-prime)) (f x x-prime))
    ((and (pair? x) (pair? x-prime))
     (for-each-walk2! f (car x) (car x-prime))
     (for-each-walk2! f (cdr x) (cdr x-prime)))
    ;;\needswork: vectors
    (else (error #f "Values don't conform: ~s ~s" x x-prime))))

 (define (tape-mode mode f x co/tangent)
  (let* ((x-tape (map-walk1 tapify x))
     (y-tape (f x-tape)))
   (case mode
    ((forward)
     (for-each-walk2! tape-co/tangent-set! x-tape co/tangent)
     (list (map-walk1 tape-primal y-tape)
       (map-walk1 forward-accumulation-sweep y-tape)))
    ((reverse)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1! initialize-co/tangent! y-tape)
     (for-each-walk2! tape-co/tangent-set! y-tape co/tangent)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1!
      (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1))) y-tape)
     (for-each-walk1! reverse-accumulation-sweep! y-tape)
     (list (map-walk1 tape-primal y-tape)
       (map-walk1 tape-co/tangent x-tape)))
    ((forward-inverse)
     (for-each-walk2! tape-co/tangent-set! x-tape co/tangent)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1!
      (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1))) y-tape)
     (list (map-walk1 tape-primal y-tape)
       ;; What you return here will depend on whether it is stepwise
       ;; equasive, expansive, or contractive.
       (map-walk1 stepwise-equasive-forward-inverse-accumulation-sweep
              y-tape)))
    ((reverse-inverse)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1! initialize-co/tangent! y-tape)
     (for-each-walk2! tape-co/tangent-set! y-tape co/tangent)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1!
      (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1))) y-tape)
     ;; What you do here will depend on whether it is stepwise equasive,
     ;; expansive, or contractive.
     (for-each-walk1!
      stepwise-equasive-reverse-inverse-accumulation-sweep! y-tape)
     (list (map-walk1 tape-primal y-tape)
       (map-walk1 tape-co/tangent x-tape)))
    (else (error #f "Unknown mode")))))

 (define (j* f x x-tangent) (tape-mode 'forward f x x-tangent))

 (define (*j f x y-cotangent) (tape-mode 'reverse f x y-cotangent))

 (define (j*^-1 f x x-cotangent) (tape-mode 'forward-inverse f x x-cotangent))

 (define (*j^-1 f x y-tangent) (tape-mode 'reverse-inverse f x y-tangent)))

carlosgmartin commented 1 year ago

Any update on this?

In addition to Newton's method and natural gradient ascent, this could be useful for Local Symplectic Surgery (LSS) and Polymatrix Competitive Gradient Descent (PCGD).

dawidpasterny commented 2 weeks ago

Hi there, I have one more use case where this could be useful, namely training inverse surrogates with physical gradients. Solving a linear system like Jx = y, i.e. computing an inverse-Jacobian-vector product, amounts to evaluating the action of the first-order local approximation of the inverse solver x = P^{-1}(y).
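A minimal sketch of that last point, assuming a smooth forward map P that is locally invertible (`forward_map` and `linearized_inverse` are hypothetical names). A single inverse-Jacobian-vector product turns a small change dy in the target into the first-order update x0 + J^{-1} dy ≈ P^{-1}(P(x0) + dy). A dense solve stands in here for whatever inverse-product primitive JAX might provide.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def forward_map(x):
    """Hypothetical smooth forward "physics" P: R^3 -> R^3 (diagonally dominant Jacobian)."""
    return x + 0.3 * jnp.sin(jnp.roll(x, 1)) + 0.1 * x ** 3


def linearized_inverse(x0, dy):
    """First-order estimate of P^{-1}(P(x0) + dy), i.e. x0 + J(x0)^{-1} dy."""
    jac = jax.jacfwd(forward_map)(x0)
    return x0 + jnp.linalg.solve(jac, dy)


x0 = jnp.array([0.5, -1.0, 0.2])
dy = 1e-3 * jnp.array([1.0, 2.0, -1.0])
x1 = linearized_inverse(x0, dy)
residual = forward_map(x1) - (forward_map(x0) + dy)  # second order in |dy|
```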