jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Customizable reduction in jax.lax.scatter #6265

Open kho opened 3 years ago

kho commented 3 years ago

XLA's Scatter op allows a customizable reduction via a custom update_computation. Currently, the different scatter reductions are implemented as separate primitives in lax (lax.scatter and lax.scatter_{add,mul,min,max}). It would be useful to have a version of scatter that accepts a custom reduction, as long as the reduction is a commutative monoid (e.g. logsumexp). A concrete application I have in mind is implementing segment_logsumexp.
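Why the monoid requirement matters: XLA leaves the order in which updates that hit the same index are combined unspecified, so the result is only well defined if the reduction is associative and commutative. A quick numerical check for the binary logsumexp, using the public `jnp.logaddexp` (which computes log(exp(a) + exp(b))) and assuming -inf as the identity, might look like this:

```python
import jax.numpy as jnp

identity = -jnp.inf  # assumed identity: logaddexp(x, -inf) == x
a, b, c = jnp.float32(0.5), jnp.float32(-1.0), jnp.float32(2.0)

# Associativity: how duplicate-index updates are grouped does not matter.
assoc_left = jnp.logaddexp(jnp.logaddexp(a, b), c)
assoc_right = jnp.logaddexp(a, jnp.logaddexp(b, c))

# Commutativity: the order in which updates arrive does not matter.
comm_ab = jnp.logaddexp(a, b)
comm_ba = jnp.logaddexp(b, a)

# Identity: combining with -inf leaves a value unchanged, and
# combining the identity with itself stays at the identity.
with_identity = jnp.logaddexp(a, identity)
both_identity = jnp.logaddexp(identity, identity)
```

Associativity only holds up to floating-point rounding, so results with many duplicate indices can differ slightly between backends.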

It is not too hard to bind scatter_p with a custom update_jaxpr parameter; for example, the following seems to work for logsumexp:

import jax
import jax.numpy as jnp
from jax._src.lax.lax import _abstractify, _const, _reduction_jaxpr, scatter_p

logsumexp_0 = float('-inf')

def logsumexp(a, b):
  c = jnp.maximum(a, b)
  return jnp.where(
      jnp.logical_or(a == logsumexp_0, b == logsumexp_0), c,
      c + jnp.log(jnp.exp(a - c) + jnp.exp(b - c)))

def make_scatter(reduction):

  def scatter(operand,
              scatter_indices,
              updates,
              dimension_numbers,
              *,
              indices_are_sorted=False,
              unique_indices=False):
    jaxpr, consts = _reduction_jaxpr(
        reduction,
        # This doesn't have to be the monoid identity. It's only used to trace the call to `reduction`.
        _abstractify(_const(operand, 0)))
    return scatter_p.bind(
        operand,
        scatter_indices,
        updates,
        update_jaxpr=jaxpr,
        update_consts=consts,
        dimension_numbers=dimension_numbers,
        indices_are_sorted=indices_are_sorted,
        unique_indices=unique_indices)

  return scatter

operand = jnp.ones([3, 3])
scatter_indices = jnp.array([[1, 1], [2, 2]])
updates = jnp.array([1., 2.])
dimension_numbers = jax.lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0, 1),
    scatter_dims_to_operand_dims=(0, 1))
make_scatter(logsumexp)(operand, scatter_indices, updates, dimension_numbers)
# Output:
# DeviceArray([[1.       , 1.       , 1.       ],
#             [1.       , 1.6931472, 1.       ],
#             [1.       , 1.       , 2.3132615]], dtype=float32)
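As a sanity check on the output above using only public API: when every scatter index is unique, each target element is combined with exactly one update, so gathering, combining with `jnp.logaddexp`, and writing back with `.at[].set` should give the same numbers. This is only a cross-check under that assumption; with repeated indices `.at[].set` applies just one of the duplicate updates (which one is unspecified), so it cannot stand in for a custom-reduction scatter.

```python
import jax.numpy as jnp

operand = jnp.ones([3, 3])
rows = jnp.array([1, 2])  # same targets as scatter_indices [[1, 1], [2, 2]]
cols = jnp.array([1, 2])
updates = jnp.array([1., 2.])

# Gather the current values, combine them with the updates, scatter back.
combined = jnp.logaddexp(operand[rows, cols], updates)
expected = operand.at[rows, cols].set(combined)
```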

There are, however, a few remaining issues:

  1. vmap doesn't work, because the batching rule associated with the lax.scatter_p primitive would also need to know about the custom reduction.
  2. Auto-diff doesn't work either.
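  3. _reduction_jaxpr() may return a jaxpr with constants, which does not work: the constants are passed through the update_consts parameter, and primitive parameters must be hashable (see the TypeError below). For example, replacing logsumexp above with the following leads to an error: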
def logsumexp_bad(a, b):
  # This won't work because the jnp.where line introduces a 0 constant.
  c = jnp.maximum(a, b)
  c = jnp.where(jax.lax.is_finite(c), c, 0)
  return c + jnp.log(jnp.exp(a - c) + jnp.exp(b - c))

operand = jnp.ones([3, 3])
scatter_indices = jnp.array([[1, 1], [2, 2]])
updates = jnp.array([1., 2.])
dimension_numbers = jax.lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0, 1),
    scatter_dims_to_operand_dims=(0, 1))
make_scatter(logsumexp_bad)(operand, scatter_indices, updates,
                            dimension_numbers)

Stack trace:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-52-2a8b6bc91a46> in <module>()
      7     scatter_dims_to_operand_dims=(0, 1))
      8 make_scatter(logsumexp_bad)(operand, scatter_indices, updates,
----> 9                             dimension_numbers)

[... 5 frames collapsed ...]
<ipython-input-26-dc65a421cee8> in scatter(operand, scatter_indices, updates, dimension_numbers, indices_are_sorted, unique_indices)
     18         dimension_numbers=dimension_numbers,
     19         indices_are_sorted=indices_are_sorted,
---> 20         unique_indices=unique_indices)
     21 
     22   return scatter

google3/third_party/py/jax/core.py in bind(self, *args, **params)
    280     top_trace = find_top_trace(args)
    281     tracers = map(top_trace.full_raise, args)
--> 282     out = top_trace.process_primitive(self, tracers, params)
    283     return map(full_lower, out) if self.multiple_results else full_lower(out)
    284 

google3/third_party/py/jax/core.py in process_primitive(self, primitive, tracers, params)
    626 
    627   def process_primitive(self, primitive, tracers, params):
--> 628     return primitive.impl(*tracers, **params)
    629 
    630   def process_call(self, primitive, f, tracers, params):

google3/third_party/py/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    240 def apply_primitive(prim, *args, **params):
    241   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 242   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
    243   return compiled_fun(*args)
    244 

google3/third_party/py/jax/_src/util.py in wrapper(*args, **kwargs)
    196         return f(*args, **kwargs)
    197       else:
--> 198         return cached(bool(config.x64_enabled), *args, **kwargs)
    199 
    200     wrapper.cache_clear = cached.cache_clear

google3/third_party/py/jax/interpreters/xla.py in __hash__(self)
   1268 
   1269   def __hash__(self):
-> 1270     raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
   1271 
   1272   setattr(device_array, "__hash__", __hash__)

TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.
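The stack trace above points at the mechanism: xla_primitive_callable goes through a cached wrapper (the `cached(...)` call in util.py), so the primitive's parameters become part of a cache key and must be hashable, while update_consts here holds a DeviceArray. The sketch below reproduces that failure mode with `functools.lru_cache`; `cached_compile` is a hypothetical stand-in for that cache, not a JAX function.

```python
import functools

import jax.numpy as jnp


@functools.lru_cache(maxsize=None)
def cached_compile(prim_name, params):
    # Hypothetical stand-in for a compile cache keyed on (primitive, params).
    return f"compiled {prim_name} with {len(params)} params"


# Hashable params (Python scalars, strings, tuples) work as cache keys.
ok = cached_compile("scatter", (("unique_indices", False),))

# A JAX array among the params (like a traced constant in update_consts) does not.
try:
    cached_compile("scatter", (("update_consts", (jnp.array(0.0),)),))
    hash_failed = False
except TypeError:
    hash_failed = True
```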

mattjj commented 3 years ago

cc @froystig because we were looking to improve reductions in related ways

kho commented 3 years ago

Checking back on this. It would really help projects on our side if a segment_logsumexp that works with vmap and autodiff were available. I would be happy to help implement the necessary pieces once a direction is decided (e.g. do we want to add a new scatter primitive?).

froystig commented 3 years ago

Yes, we should surface a custom-monoid scatter, especially since it's in XLA. A new primitive probably makes sense.

Implementation-wise, it may help to look at the generic reductions (lax.reduce and its reduce_p primitive) as an example of how to stage out the custom operation and set up the primitive, as well as for their AD and batching rules.
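For reference, the public entry point to the generic reductions is `jax.lax.reduce(operands, init_values, computation, dimensions)`, which accepts an arbitrary binary combiner and stages it out as the reduction computation when it isn't one of the built-in monoids. A minimal sketch with a logaddexp combiner, written with lax ops only and assuming -inf as the init value:

```python
import jax
import jax.numpy as jnp
from jax import lax


def logaddexp(a, b):
    # Binary logsumexp built from lax primitives, mirroring jnp.logaddexp.
    amax = lax.max(a, b)
    delta = lax.sub(a, b)
    # (-inf) - (-inf) is NaN: both inputs are the identity, so return a + b = -inf.
    is_nan = lax.ne(delta, delta)
    stable = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))
    return lax.select(is_nan, lax.add(a, b), stable)


x = jnp.array([[0.0, 1.0, 2.0],
               [-1.0, -jnp.inf, 3.0]])
init = jnp.array(-jnp.inf, dtype=x.dtype)  # identity of the logaddexp monoid

# Reduce along axis 1 with the custom combiner.
row_lse = lax.reduce(x, init, logaddexp, (1,))
```

The result should agree with `jax.nn.logsumexp(x, axis=1)` up to rounding. A custom-monoid scatter would need the analogous staging of `logaddexp` into the scatter's update computation.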

The various transformation rules don't all need to be implemented in one change. The first step would be to add a primitive with evaluation and type rules.
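A minimal sketch of that first step, assuming a hypothetical `segment_reduce_p` primitive (not part of JAX) with only an evaluation rule and a type rule. The evaluation rule here is a deliberately simple sequential fold with `lax.fori_loop` rather than a lowering to XLA's Scatter, and there is no lowering, batching, or AD rule yet, so it only runs eagerly or under shape-only tracing such as `jax.eval_shape`. The try/except covers JAX versions that expose `Primitive` in `jax.extend.core` versus `jax.core`. Note that the monoid is passed as a parameter, so it has to be hashable (a plain Python function is).

```python
import jax
import jax.numpy as jnp

try:  # newer JAX
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive

# Hypothetical primitive; the name and parameters are illustrative only.
segment_reduce_p = Primitive("segment_reduce")


def segment_reduce(values, segment_ids, *, monoid, identity, num_segments):
    """Reduce 1-D `values` per segment with a commutative monoid."""
    return segment_reduce_p.bind(values, segment_ids, monoid=monoid,
                                 identity=identity, num_segments=num_segments)


def _segment_reduce_impl(values, segment_ids, *, monoid, identity, num_segments):
    # Evaluation rule: start each segment at the identity, then fold every
    # value into its segment's accumulator one element at a time.
    init = jnp.full((num_segments,), identity, dtype=values.dtype)

    def body(i, acc):
        s = segment_ids[i]
        return acc.at[s].set(monoid(acc[s], values[i]))

    return jax.lax.fori_loop(0, values.shape[0], body, init)


def _segment_reduce_abstract_eval(values, segment_ids, *, monoid, identity,
                                  num_segments):
    # Type rule: a 1-D result of length num_segments with the values' dtype.
    return values.update(shape=(num_segments,))


segment_reduce_p.def_impl(_segment_reduce_impl)
segment_reduce_p.def_abstract_eval(_segment_reduce_abstract_eval)


def logaddexp(a, b):
    # Plain (hashable) Python function, usable as a primitive parameter.
    return jnp.logaddexp(a, b)


values = jnp.array([1.0, 2.0, 3.0, 0.0])
segment_ids = jnp.array([0, 0, 1, 1])
out = segment_reduce(values, segment_ids, monoid=logaddexp,
                     identity=float("-inf"), num_segments=3)
```

Segment 2 receives no values, so it stays at the identity -inf.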

kho commented 2 years ago

CI's green now. @froystig and @mattjj, can either of you take a look at PR #12004?

pawel-czyz commented 10 months ago

Dear all, I'd also be interested in segment_logsumexp. Currently I'm using the following code:

import jax
import jax.numpy as jnp

def segment_logsumexp(values: jnp.ndarray, indices: jnp.ndarray, num_segments: int) -> jnp.ndarray:
    max_per_segment = jax.ops.segment_max(values, indices, num_segments)
    adjusted_values = jnp.exp(values - max_per_segment[indices])
    summed_exp_values = jax.ops.segment_sum(adjusted_values, indices, num_segments)
    return jnp.log(summed_exp_values) + max_per_segment

which seems to work well enough for my applications, but is definitely less principled than the solution described above.
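One caveat with the snippet above: for a segment whose values are all -inf, `values - max_per_segment[indices]` evaluates (-inf) - (-inf) and the result is NaN rather than -inf. A sketch of the same segment_max/segment_sum approach with two tweaks, assuming the usual logsumexp conventions: shift by 0 wherever the segment max is non-finite, and wrap the shift in `stop_gradient`, since it cancels out of the value.

```python
import jax
import jax.numpy as jnp


def segment_logsumexp(values: jnp.ndarray, indices: jnp.ndarray,
                      num_segments: int) -> jnp.ndarray:
    max_per_segment = jax.ops.segment_max(values, indices, num_segments)
    # Empty and all -inf segments have a non-finite max; shifting by 0
    # there avoids computing (-inf) - (-inf) = NaN below.
    shift = jnp.where(jnp.isfinite(max_per_segment), max_per_segment, 0.0)
    # The shift cancels analytically, so no gradient needs to flow through it.
    shift = jax.lax.stop_gradient(shift)
    summed = jax.ops.segment_sum(jnp.exp(values - shift[indices]),
                                 indices, num_segments)
    # log(0) = -inf for segments with no finite contributions.
    return jnp.log(summed) + shift


values = jnp.array([1.0, 2.0, 3.0, -jnp.inf])
indices = jnp.array([0, 0, 1, 2])
out = segment_logsumexp(values, indices, num_segments=4)
```

Here segment 2 holds only -inf and segment 3 is empty; both come out as -inf instead of NaN. Gradients through `jnp.log` can still be NaN for such segments, which this sketch does not address.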