Add support for automatic differentiation via JAX. JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
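As a minimal illustration of composing transformations (a standalone sketch, not part of the proposed API), reverse-mode grad can be nested to take higher-order derivatives:

```python
import jax

# Third derivative of x**3 by nesting reverse-mode grad:
# 3*x**2 -> 6*x -> 6, so d3f(1.0) evaluates to 6.0.
f = lambda x: x ** 3
d3f = jax.grad(jax.grad(jax.grad(f)))
```

Any of the nested grad calls could equally be jax.jacfwd, since forward and reverse mode compose freely.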
Example:
import types
from typing import Callable, Optional
from warnings import warn

import numpy as np
import cupy as cp
import dask.array as da
import jax.numpy as jnp
import jax.dlpack as jxdl
from jax import jacfwd
def infer_array_module(decorated_object_type='method'):
    def make_decorator(call_fun: Callable):
        def wrapper(*args, **kwargs):
            if decorated_object_type == 'method':
                arr = args[1]  # First positional argument after self
            else:
                arr = args[0]
            if isinstance(arr, cp.ndarray):
                xp = cp
            elif isinstance(arr, da.core.Array):
                xp = da
            elif isinstance(arr, jnp.ndarray):
                xp = jnp
            else:
                xp = np  # Fall back to the NumPy backend for unknown array types
            kwargs['_xp'] = xp
            return call_fun(*args, **kwargs)
        return wrapper
    return make_decorator
class FFTOp(object):
    @infer_array_module(decorated_object_type='method')
    def __call__(self, x, _xp: Optional[types.ModuleType] = None):
        return _xp.fft.fft(x)

    @infer_array_module(decorated_object_type='method')
    def jacobian(self, x, _xp: Optional[types.ModuleType] = None):
        if _xp == cp:
            # Zero-copy conversion from CuPy to JAX arrays only works with float32 dtypes.
            arr = jxdl.from_dlpack(x.astype(_xp.float32).toDlpack())
            warn('Automatic differentiation with CuPy arrays only works with float32 precision.')
        elif _xp == da:
            raise NotImplementedError('Automatic differentiation is not supported with lazy Dask arrays.')
        else:
            arr = jnp.asarray(x)
        jacobian_eval = jacfwd(self.__call__)(arr)
        return _xp.asarray(jacobian_eval)
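As a quick sanity check of the jacobian method above (a standalone sketch): since the FFT is a linear map, its forward-mode Jacobian is the constant DFT matrix.

```python
import jax
import jax.numpy as jnp
import numpy as np

# The FFT is linear, so its Jacobian is the (constant) DFT matrix.
x = jnp.arange(4, dtype=jnp.float32)
J = jax.jacfwd(jnp.fft.fft)(x)

# Reference: the DFT matrix, i.e. the FFT applied to the identity columns.
F = np.fft.fft(np.eye(4), axis=0)
```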
Notes:
Automatic differentiation is an eager transformation and therefore won't work on lazy Dask arrays.
Zero-copy conversion from Cupy to JAX arrays seems to only work with float32 dtypes (to be investigated).
The size of the nonlinear map should be checked to determine whether the Jacobian matrix should be computed row-by-row (jax.jacrev) or column-by-column (jax.jacfwd).
Investigate jax.jvp or jax.vjp for evaluations of Jacobian-vector products without forming the Jacobian.
JAX arrays are created by default on the first GPU available. When copying from Numpy arrays stored on the CPU, one should control explicitly where the JAX array is stored to avoid unnecessary data transfers between CPU and GPU.
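A sketch of the jax.jvp route mentioned above: the Jacobian-vector product J @ v can be evaluated in a single forward pass without ever forming the Jacobian.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4, dtype=jnp.float32)
v = jnp.ones(4, dtype=jnp.float32)

# Forward-mode Jacobian-vector product: returns f(x) and J @ v together,
# without materializing the Jacobian matrix.
y, jv = jax.jvp(jnp.fft.fft, (x,), (v,))
```

Because the FFT is linear, jv here equals jnp.fft.fft(v); jax.vjp provides the reverse-mode counterpart for vector-Jacobian products.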
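On device placement, one possible approach (a sketch; which device the array should live on depends on the application) is jax.device_put, which lets the caller pin the converted array to a chosen device instead of the default one:

```python
import jax
import numpy as np

x_np = np.arange(4, dtype=np.float32)

# Explicitly keep the JAX copy on the host CPU to avoid an implicit
# transfer to the default (possibly GPU) device.
cpu = jax.devices('cpu')[0]
x_cpu = jax.device_put(x_np, device=cpu)
```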