kho opened this issue 3 years ago.
cc @froystig, since we were looking to improve reductions in related ways.
Checking back on this. It would really help projects on our side if a `segment_logsumexp` that works with `vmap` and autodiff were available. I would be happy to help implement the necessary pieces once a direction is decided (e.g., do we want to add a new scatter primitive?).
Yes, we should surface a custom-monoid scatter, especially since it's in XLA. A new primitive probably makes sense.
Implementation-wise, it may help to look at generic reductions for an example of how to stage out the custom operation and set up the primitive, as well as for their AD and batching rules.
The various transformation rules don't all need to be implemented in one change. The first step would be to add a primitive with evaluation and type rules.
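As a toy illustration of that first step, here is a minimal sketch of a primitive that has only an evaluation (impl) rule and a type (abstract-eval) rule. It is an analog for illustration, not the proposed custom-monoid scatter: the name `segment_logsumexp_p` is hypothetical, the impl simply reuses `jax.ops.segment_max`/`segment_sum`, it assumes 1-D `values`, and the import location of `Primitive` is an assumption that differs between JAX versions (hence the fallback). Without JVP, transpose, batching, or lowering rules it supports only eager evaluation and shape inference such as `jax.eval_shape`; `grad`, `vmap`, and `jit` would each need their own rule.

```python
import jax
import jax.numpy as jnp
import numpy as np

try:  # newer JAX versions expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX versions
    from jax.core import Primitive

# Hypothetical primitive name, for illustration only.
segment_logsumexp_p = Primitive("segment_logsumexp")

def _segment_logsumexp_impl(values, indices, *, num_segments):
    # Evaluation rule: compute the result on concrete arrays with existing segment ops.
    seg_max = jax.ops.segment_max(values, indices, num_segments)
    shift = jnp.where(jnp.isfinite(seg_max), seg_max, 0.0)
    summed = jax.ops.segment_sum(jnp.exp(values - shift[indices]), indices, num_segments)
    return jnp.log(summed) + shift

def _segment_logsumexp_abstract_eval(values_aval, indices_aval, *, num_segments):
    # Type rule: one output entry per segment, same dtype as `values` (assumed 1-D).
    return values_aval.update(shape=(num_segments,))

segment_logsumexp_p.def_impl(_segment_logsumexp_impl)
segment_logsumexp_p.def_abstract_eval(_segment_logsumexp_abstract_eval)

def segment_logsumexp(values, indices, num_segments):
    # `num_segments` is passed as a static primitive parameter, not an operand.
    return segment_logsumexp_p.bind(values, indices, num_segments=num_segments)

v = jnp.array([1.0, 2.0, 3.0, 4.0])
i = jnp.array([0, 0, 1, 2])
out = segment_logsumexp(v, i, 3)  # eager evaluation uses the impl rule
shape = jax.eval_shape(lambda a, b: segment_logsumexp(a, b, 3), v, i)  # uses the type rule
```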
CI's green now. @froystig and @mattjj, can either of you take a look at PR #12004?
Dear all,
I'd also be interested in `segment_logsumexp`. Currently I'm using the following code:
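```python
import jax
import jax.numpy as jnp

def segment_logsumexp(values: jnp.ndarray, indices: jnp.ndarray, num_segments: int) -> jnp.ndarray:
    # Per-segment maximum, used to shift values for numerical stability.
    max_per_segment = jax.ops.segment_max(values, indices, num_segments)
    # Exponentiate each value relative to the maximum of its own segment.
    adjusted_values = jnp.exp(values - max_per_segment[indices])
    # Sum the shifted exponentials within each segment.
    summed_exp_values = jax.ops.segment_sum(adjusted_values, indices, num_segments)
    # Undo the shift: log(sum(exp(v - m))) + m == log(sum(exp(v))).
    return jnp.log(summed_exp_values) + max_per_segment
```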
This seems to work well enough for my applications, but it is definitely less principled than the custom-reduction scatter proposed in this issue.
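A minimal sketch of a slightly hardened variant of that workaround, assuming the same `jax.ops.segment_max`/`segment_sum` building blocks. It borrows the convention `jax.nn.logsumexp` uses for its shift: a non-finite maximum is replaced by 0, so a segment whose values are all `-inf` yields `-inf` instead of NaN, and `jax.lax.stop_gradient` keeps the shift out of autodiff because it cancels mathematically. Since it is built only from existing ops, it should already compose with `vmap` and `grad`; the usage lines exercise both on made-up inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np

def segment_logsumexp(values, indices, num_segments):
    # Per-segment max; segments with no elements get -inf (the identity of max).
    seg_max = jax.ops.segment_max(values, indices, num_segments)
    # Shift by 0 where the max is not finite (avoids -inf - (-inf) = NaN), and
    # keep the shift out of autodiff since it cancels mathematically.
    shift = jax.lax.stop_gradient(jnp.where(jnp.isfinite(seg_max), seg_max, 0.0))
    summed = jax.ops.segment_sum(jnp.exp(values - shift[indices]), indices, num_segments)
    return jnp.log(summed) + shift

# Made-up example: segment 2 holds only -inf values and segment 3 is empty.
values = jnp.array([1.0, 2.0, 3.0, -jnp.inf, -jnp.inf])
indices = jnp.array([0, 0, 1, 2, 2])
out = segment_logsumexp(values, indices, 4)
print(out, np.logaddexp(1.0, 2.0))  # compare segment 0 against NumPy

# vmap over a leading batch axis; indices and num_segments are closed over.
idx = indices[:3]
batch = jnp.stack([values[:3], values[:3] + 1.0])
batched = jax.vmap(lambda v: segment_logsumexp(v, idx, 2))(batch)

# grad: within each segment this is the softmax of that segment's values.
grads = jax.grad(lambda v: segment_logsumexp(v, idx, 2).sum())(values[:3])
```

For segments with at least one finite value the guard changes nothing; it only matters for all `-inf` segments.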
XLA's Scatter op allows a customizable reduction via a custom `update_computation`. Currently, different scatter reductions are implemented as separate primitives in lax (`lax.scatter` and `lax.scatter_{add,mul,min,max}`). It would be useful to have a version of scatter that accepts a custom reduction, as long as the reduction is a commutative monoid (e.g. `logsumexp`). A concrete application I have is implementing `segment_logsumexp`
.

It is not too hard to create a `scatter` jaxpr with a custom `update_jaxpr` field; e.g., the following seems to work for `logsumexp`:

There are, however, a few remaining issues:
- The `lax.scatter_p` primitive needs to know about the custom reduction too.
- `_reduction_jaxpr()` may return a jaxpr with constants, which will not work. E.g., replacing `logsumexp` above with the following leads to an error:

Stack trace:
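To make the commutative-monoid requirement concrete, here is a sequential reference for what a scatter with a custom reduction computes: the output starts filled with the monoid's identity, and each update is folded into its target slot with the binary operation. XLA does not guarantee the order in which updates hitting the same slot are combined, so the operation must be associative and commutative for the result to be well defined. For `logsumexp`, the monoid is `jnp.logaddexp` with identity `-inf`. This is a sketch for illustration only: `scatter_monoid_reference` is a hypothetical helper, not a lax API, and its per-update loop is far slower than a real scatter.

```python
import jax
import jax.numpy as jnp
import numpy as np

def scatter_monoid_reference(op, identity, num_segments, indices, updates):
    """Hypothetical sequential reference for a 1-D scatter with a custom monoid.

    Computes out[indices[i]] = op(out[indices[i]], updates[i]) for each i,
    starting from an output filled with the monoid identity.
    """
    out = jnp.full((num_segments,), identity, dtype=updates.dtype)

    def body(i, acc):
        j = indices[i]
        return acc.at[j].set(op(acc[j], updates[i]))

    return jax.lax.fori_loop(0, updates.shape[0], body, out)

updates = jnp.array([1.0, 2.0, 3.0, 0.5])
indices = jnp.array([0, 0, 1, 0])
# logsumexp as a monoid: binary op logaddexp, identity -inf (slot 2 gets no updates).
result = scatter_monoid_reference(jnp.logaddexp, -jnp.inf, 3, indices, updates)

# Applying the same updates in a different order should give the same result.
perm = np.array([3, 1, 0, 2])
result_perm = scatter_monoid_reference(jnp.logaddexp, -jnp.inf, 3, indices[perm], updates[perm])
```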