jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Remove unnecessary broadcast from `jnp.vectorize` #14751

Closed: Gattocrucco closed this issue 1 year ago

Gattocrucco commented 1 year ago

Summary: an improvement to `jnp.vectorize` that reduces memory consumption and can increase speed substantially (about 100x in the example below).

Currently `jnp.vectorize` partially broadcasts the arguments against each other before passing them to `jax.vmap`. As a comment in the code notes:

# TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
# doing all vectorization with vmap? This *might* be a little more
# efficient but would require more careful book-keeping.

That "might" is an understatement: with explicit broadcasting, performance can drop dramatically when the broadcast arrays become very large and fill up memory. I modified the implementation to fix this:

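# Reuse the private helpers from the current implementation, so only the
# wrapper logic changes.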
from jax._src.numpy.vectorize import (
  functools,
  _apply_excluded,
  map,
  jnp,
  _parse_gufunc_signature,
  _parse_input_dimensions,
  _check_output_dims,
  zip,
  api,
)

def modified_vectorize(pyfunc, *, excluded=frozenset(), signature=None):
  """Define a vectorized function with broadcasting.

  :func:`vectorize` is a convenience wrapper for defining vectorized
  functions with broadcasting, in the style of NumPy's
  `generalized universal functions <https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html>`_.
  It allows for defining functions that are automatically repeated across
  any leading dimensions, without the implementation of the function needing to
  be concerned about how to handle higher dimensional inputs.

  :func:`jax.numpy.vectorize` has the same interface as
  :class:`numpy.vectorize`, but it is syntactic sugar for an auto-batching
  transformation (:func:`vmap`) rather than a Python loop. This should be
  considerably more efficient, but the implementation must be written in terms
  of functions that act on JAX arrays.

  Args:
    pyfunc: function to vectorize.
    excluded: optional set of integers representing positional arguments for
      which the function will not be vectorized. These will be passed directly
      to ``pyfunc`` unmodified.
    signature: optional generalized universal function signature, e.g.,
      ``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If
      provided, ``pyfunc`` will be called with (and expected to return) arrays
      with shapes given by the size of corresponding core dimensions. By
      default, ``pyfunc`` is assumed to take scalar arrays as input and output.

  Returns:
    Vectorized version of the given function.

  Here are a few examples of how one could write vectorized linear algebra
  routines using :func:`vectorize`:

  >>> from functools import partial

  >>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
  ... def cross_product(a, b):
  ...   assert a.shape == b.shape and a.ndim == b.ndim == 1
  ...   return jnp.array([a[1] * b[2] - a[2] * b[1],
  ...                     a[2] * b[0] - a[0] * b[2],
  ...                     a[0] * b[1] - a[1] * b[0]])

  >>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
  ... def matrix_vector_product(matrix, vector):
  ...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
  ...   return matrix @ vector

  These functions are only written to handle 1D or 2D arrays (the ``assert``
  statements will never be violated), but with vectorize they support
  arbitrary dimensional inputs with NumPy style broadcasting, e.g.,

  >>> cross_product(jnp.ones(3), jnp.ones(3)).shape
  (3,)
  >>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
  (2, 3)
  >>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
  (2, 2, 3)
  >>> matrix_vector_product(jnp.ones(3), jnp.ones(3))  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
  ValueError: input with shape (3,) does not have enough dimensions for all
  core dimensions ('n', 'm') on vectorized function with excluded=frozenset()
  and signature='(n,m),(m)->(n)'
  >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
  (2,)
  >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
  (4, 2)

  Note that this has different semantics than `jnp.matmul`:

  >>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
  TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
  """
  if any(not isinstance(exclude, int) for exclude in excluded):
    raise TypeError("jax.numpy.vectorize can only exclude integer arguments, "
                    "but excluded={!r}".format(excluded))
  if excluded and min(excluded) < 0:
    raise ValueError(f"excluded={excluded!r} contains negative numbers")

  @functools.wraps(pyfunc)
  def wrapped(*args):
    error_context = ("on vectorized function with excluded={!r} and "
                     "signature={!r}".format(excluded, signature))
    excluded_func, args = _apply_excluded(pyfunc, excluded, args)
    args = tuple(map(jnp.asarray, args))

    if signature is not None:
      input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
    else:
      input_core_dims = [()] * len(args)
      output_core_dims = None

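    # broadcast_shape: broadcast of the arguments' non-core (loop) shapes;
    # dim_sizes: mapping from named core dimensions to their sizes.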
    broadcast_shape, dim_sizes = _parse_input_dimensions(
        args, input_core_dims, error_context)

    checked_func = _check_output_dims(
        excluded_func, dim_sizes, output_core_dims, error_context)

    # Rather than broadcasting all arguments to full broadcast shapes, prefer
    # expanding dimensions using vmap. By pushing broadcasting
    # into vmap, we can make use of more efficient batching rules for
    # primitives where only some arguments are batched (e.g., for
    # lax_linalg.triangular_solve), and avoid instantiating large broadcasted
    # arrays.

    squeezed_args = []
    rev_filled_shapes = []

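    # For each argument, record its non-core (loop) shape padded with leading
    # 1s to the full broadcast rank, and squeeze out size-1 loop dimensions so
    # that the vmap calls below can broadcast them via in_axes=None instead.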
    for arg, core_dims in zip(args, input_core_dims):
      noncore_shape = arg.shape[:arg.ndim - len(core_dims)]

      pad_ndim = len(broadcast_shape) - len(noncore_shape)
      filled_shape = pad_ndim * (1,) + noncore_shape
      assert len(filled_shape) == len(broadcast_shape)
      assert jnp.broadcast_shapes(broadcast_shape, filled_shape) == broadcast_shape
      rev_filled_shapes.append(filled_shape[::-1])

      squeeze_indices = tuple(i for i, size in enumerate(noncore_shape) if size == 1)
      squeezed_arg = jnp.squeeze(arg, axis=squeeze_indices)
      squeezed_args.append(squeezed_arg)

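    # Wrap with one vmap per loop dimension, starting from the trailing one.
    # Arguments that have size 1 (or no dimension at all) at that position get
    # in_axes=None, so vmap broadcasts them instead of materializing copies.
    # Dimensions of size 1 in every argument are skipped entirely and restored
    # with expand_dims at the end.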
    vectorized_func = checked_func
    dims_to_expand = []
    for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)):
      in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)
      if all(axis is None for axis in in_axes):
        dims_to_expand.append(len(broadcast_shape) - 1 - negdim)
      else:
        vectorized_func = api.vmap(vectorized_func, in_axes)
    result = vectorized_func(*squeezed_args)
    return jnp.expand_dims(result, axis=dims_to_expand)

  return wrapped

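Correctness check against the current implementation, computing all pairwise dot products between the rows of an (n, p) array:
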
import functools
import numpy as np
from jax import numpy as jnp

p = 60
n = 3000

rng = np.random.default_rng(202303021320)

@functools.partial(jnp.vectorize, signature='(p),(p)->()')
def func1(x, y):
    return x @ y

@functools.partial(modified_vectorize, signature='(p),(p)->()')
def func2(x, y):
    return x @ y

x = rng.standard_normal((n, p))
args = (x[None, :, :], x[:, None, :])
out1 = func1(*args)
out2 = func2(*args)

np.testing.assert_allclose(out1, out2, atol=0, rtol=0)

Benchmark:

%timeit func1(*args)
%timeit func2(*args)
685 ms ± 9.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.02 ms ± 33.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
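
For context, a rough estimate (my own illustration, not part of the change) of why the current implementation is so much slower here: both inputs already have the full loop rank, so each is broadcast to shape (n, n, p) before vmap, which under JAX's default float32 materializes about 2 GB per argument, whereas the modified version never allocates more than the original (n, p) inputs and the (n, n) output.

n, p = 3000, 60
bytes_per_broadcast_arg = n * n * p * 4  # float32: 4 bytes per element
print(f"{bytes_per_broadcast_arg / 1e9:.2f} GB")  # ~2.16 GB, for each of the two arguments
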
jakevdp commented 1 year ago

Thanks - are you interested in putting together a PR with your updated implementation?

Gattocrucco commented 1 year ago

Yup

jakevdp commented 1 year ago

Awesome - I'll look out for it. Thanks!

Gattocrucco commented 1 year ago

PR #14782