pyxu-org / pyxu

Modular and scalable computational imaging in Python with GPU/out-of-core computing.
https://pyxu-org.github.io/
MIT License

Add Support for Automatic Differentiation via JAX #9

Closed · matthieumeo closed this 1 year ago

matthieumeo commented 3 years ago

Add support for automatic differentiation via JAX. JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
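
For reference, a minimal sketch of these transformations (illustrative only, not code from this issue; f is a placeholder scalar-valued function):

import jax
import jax.numpy as jnp

# Reverse-mode derivative of a scalar-valued function via grad.
f = lambda x: jnp.sum(jnp.tanh(x) ** 2)
df = jax.grad(f)

# Transformations compose to any order, mixing forward and reverse mode.
d2f = jax.jacfwd(jax.grad(f))  # forward-over-reverse Hessian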

Example:

import numpy as np
import cupy as cp
import dask.array as da
import jax.numpy as jnp
import types
from typing import Callable, Optional
from jax import jacfwd
import jax.dlpack as jxdl
from warnings import warn

def infer_array_module(decorated_object_type='method'):
  # Decorator factory: inspect the decorated callable's array argument and
  # forward the matching array module (NumPy/CuPy/Dask/JAX) via the `_xp` kwarg.
  def make_decorator(call_fun: Callable):
    def wrapper(*args, **kwargs):
      if decorated_object_type == 'method':
        arr = args[1]  # First positional argument is self
      else:
        arr = args[0]

      if isinstance(arr, cp.ndarray):
        xp = cp
      elif isinstance(arr, da.core.Array):
        xp = da
      elif isinstance(arr, jnp.ndarray):
        xp = jnp
      else:
        xp = np  # Fall back to the NumPy backend for unknown array types
      kwargs['_xp'] = xp
      return call_fun(*args, **kwargs)
    return wrapper
  return make_decorator

class FFTOp(object):

  @infer_array_module(decorated_object_type='method')
  def __call__(self, x, _xp: Optional[types.ModuleType] = None):
    return _xp.fft.fft(x)

  @infer_array_module(decorated_object_type='method')
  def jacobian(self, x, _xp: Optional[types.ModuleType] = None):
    if _xp is cp:
      # Zero-copy conversion from CuPy to JAX arrays via DLPack only works with float32 dtypes.
      arr = jxdl.from_dlpack(x.astype(_xp.float32).toDlpack())
      warn('Automatic differentiation with CuPy arrays only works with float32 precision.')
    elif _xp is da:
      raise NotImplementedError('Automatic differentiation is not supported for lazy Dask arrays.')
    else:
      arr = jnp.asarray(x)
    jacobian_eval = jacfwd(self.__call__)(arr)
    return _xp.asarray(jacobian_eval)
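
A hypothetical usage of the sketch above (NumPy path; shapes and values are arbitrary):

op = FFTOp()
x = np.random.rand(8).astype(np.float32)
y = op(x)           # dispatched to np.fft.fft by the decorator
J = op.jacobian(x)  # dense (8, 8) complex Jacobian evaluated with jax.jacfwd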

Notes:

  1. Automatic differentiation is an eager transformation and therefore won't work on lazy Dask arrays.
  2. Zero-copy conversion from CuPy to JAX arrays seems to only work with float32 dtypes (to be investigated).
  3. The shape of the nonlinear map should be checked to determine whether the Jacobian matrix should be computed row-by-row (jax.jacrev, cheaper when the map has fewer outputs than inputs) or column-by-column (jax.jacfwd, cheaper when it has fewer inputs than outputs).
  4. Investigate jax.jvp and jax.vjp for evaluating Jacobian-vector products without forming the full Jacobian (see the sketch after this list).
  5. JAX arrays are created by default on the first available GPU. When copying from NumPy arrays stored on the CPU, one should explicitly control where the JAX array is placed to avoid unnecessary data transfers between CPU and GPU (also sketched below).
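
Regarding notes 4 and 5, a minimal sketch of the assumed APIs (jax.jvp, jax.vjp, jax.device_put; the map f below is a placeholder, not an operator from this issue):

import jax
import jax.numpy as jnp
import numpy as np

def f(x):
  # Placeholder nonlinear map R^n -> R^n standing in for an operator's __call__.
  return jnp.sin(x) * x

x = jnp.linspace(0.0, 1.0, 5)
v = jnp.ones_like(x)

# Note 4, forward mode: Jacobian-vector product J(x) @ v without forming J.
y, jvp_out = jax.jvp(f, (x,), (v,))

# Note 4, reverse mode: vector-Jacobian product v^T J(x) without forming J.
y, vjp_fun = jax.vjp(f, x)
(vjp_out,) = vjp_fun(v)

# Note 5: pin a NumPy-backed array to the CPU device explicitly, avoiding an
# implicit host-to-GPU transfer when JAX's default device is a GPU.
x_cpu = jax.device_put(np.ones(5, dtype=np.float32), jax.devices('cpu')[0])
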
SepandKashani commented 1 year ago

This feature has been merged into v2-dev.