Open · jakevdp opened this issue 11 months ago
Hi Jake,
I was thinking you could do this using a staging 'accu' tensor of the same shape as the destination, and then do the summation in nested fashion. That is, compute the scatter_add on chunks of `vals` and `idx`, then sum the per-chunk accumulators element-wise.
I would guess it could be implemented very efficiently at a low level, and would also shard across devices without any particular difficulty.
I tested it, and a chunk size of 1 million works. The safe chunk size would depend on how widely the exponents of the summands vary; if they are more or less uniformly distributed on any scale, I think it would still work.
This JAX scan-based pseudocode works, but only if I insert the print statement, which presumably prevents the compiler from optimizing out the `accu` buffer:
```python
import jax
import jax.numpy as jnp

def staged_scatter_add(vals, idx, chunk_size):
    def scan_fn(carry, item):
        vals, idx = item
        # Scatter-add this chunk into a fresh accumulator, then fold it
        # into the running total element-wise.
        accu = jnp.zeros_like(carry).at[idx].add(vals)
        out = carry + accu
        # this print statement is needed (I guess to prevent 'accu' from
        # being optimized out?)
        # jax.debug.print('accu: {}', accu)
        return out, 0

    num_chunks = len(vals) // chunk_size
    vals = vals.reshape(num_chunks, chunk_size)
    idx = idx.reshape(num_chunks, chunk_size)
    out, _ = jax.lax.scan(scan_fn, jnp.zeros(1), (vals, idx))
    return out

vals = jax.random.uniform(jax.random.key(0), (100_000_000,))
idx = jnp.zeros(len(vals), dtype=int)
staged = staged_scatter_add(vals, idx, 1_000_000)
print(vals.sum(), staged[0])
# 50004290.0 50004410.0
```
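As an aside, the accuracy benefit of this kind of nested summation can be illustrated with a NumPy-only sketch (my addition, not from the thread): `cumsum` emulates strictly sequential float32 accumulation, and the chunk layout here is arbitrary.

```python
import numpy as np

ones = np.ones(2**25, dtype=np.float32)  # true sum is 2**25 = 33554432

# Sequential float32 accumulation stalls once the running total reaches
# 2**24: adding 1.0 to 2**24 rounds back to 2**24.
naive = np.cumsum(ones)[-1]

# Chunked accumulation: each chunk's partial sum stays well below 2**24,
# and the 32 partial sums are combined without loss.
partials = np.cumsum(ones.reshape(32, 2**20), axis=1)[:, -1]
chunked = partials.sum()

print(naive, chunked)  # 16777216.0 33554432.0
```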
I'll note that we have support for bucketed accumulation in `segment_sum` to solve a variant of this problem: https://github.com/google/jax/blob/03cae62c78349f4e8a0ede08c0955fdb32070052/jax/_src/ops/scatter.py#L224
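For reference, that bucketing is exposed via the `bucket_size` argument of `jax.ops.segment_sum`; a quick sketch (the sizes here are arbitrary, and the float64 reference is computed in NumPy to avoid needing `jax_enable_x64`):

```python
import jax
import jax.numpy as jnp
import numpy as np

data = jax.random.uniform(jax.random.key(0), (1_048_576,), dtype=jnp.float32)
segment_ids = jnp.zeros(len(data), dtype=jnp.int32)

# bucket_size splits the reduction into fixed-size buckets that are summed
# independently and then combined, limiting sequential float32 accumulation.
bucketed = jax.ops.segment_sum(data, segment_ids, num_segments=1,
                               bucket_size=1024)

# float64 reference, computed outside JAX
reference = np.asarray(data).sum(dtype=np.float64)
print(float(bucketed[0]), reference)
```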
Originally reported in #18393; here's an example:
In terms of real-valued math, these two should match. The mismatch comes from the fact that `scatter_add` adds elements in sequence, and so for large numbers of elements accumulating at the same index it can hit catastrophic floating point rounding errors (topping out at $2^{24}$, because $2^{24} + 1$ is not exactly representable in `float32`).

This affects all APIs that make use of `scatter_add`, including `jnp.bincount`, weighted `jnp.histogram`, and many others.

We should think about how to mitigate this issue. Some possible options:

- Have `scatter_add` accumulate at higher precision. Probably not feasible because `float64` is usually not available.
- Reorder the inputs to `scatter_add` so that smaller entries are accumulated first. This would address the issue in some (but not all) cases, but could make distributed/batched code less efficient as it could require an all-to-all for the global sort.

None of these fixes strikes me as a good route to fixing this issue in general; it might be that the best we can do is document the potential problem.