jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

scatter_add in float32 may be inaccurate when accumulating large numbers of entries #18440

Open · jakevdp opened 11 months ago

jakevdp commented 11 months ago

Originally reported in #18393; here's an example:

In [1]: import jax

In [2]: vals = jax.random.uniform(jax.random.key(0), (100_000_000,))

In [3]: idx = jax.numpy.zeros(len(vals), dtype=int)

In [4]: jax.numpy.zeros(1).at[idx].add(vals)
Out[4]: Array([16777216.], dtype=float32)

In [5]: vals.sum()
Out[5]: Array(50004288., dtype=float32)

In terms of real-valued math, these two should match.

The mismatch comes from the fact that scatter_add accumulates elements in sequence, so with a large number of elements landing on the same index the running sum can hit catastrophic floating-point rounding errors (topping out at $2^{24}$, because $2^{24} + 1$ is not exactly representable in float32).
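To see the cap directly (continuing the session above): once a float32 value reaches $2^{24}$, adding 1 rounds back to the same value, so a sequential sum of values near 1 cannot grow past it:

In [6]: jax.numpy.array(2**24, dtype='float32') + 1
Out[6]: Array(16777216., dtype=float32)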

This affects all APIs that make use of scatter_add, including jnp.bincount, weighted jnp.histogram, and many others.
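For instance, weighted bincount lowers to the same scatter-add as In [4], so on the arrays above it should plateau at the same saturated value:

In [7]: jax.numpy.bincount(idx, weights=vals, length=1)
Out[7]: Array([16777216.], dtype=float32)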

We should think about how to mitigate this issue. Some possible options:

  1. Make scatter_add accumulate at higher precision. Probably not feasible, because float64 is usually not available (it is disabled by default in JAX, and unsupported on TPU).
  2. Sort the updates to scatter_add so that smaller entries are accumulated first (a single-device sketch follows this list). This would address the issue in some (but not all) cases, but could make distributed/batched code less efficient, since it could require an all-to-all for the global sort.
  3. Depending on the dtype and the number of accumulated values, raise a warning when we think the result might be inaccurate. Unfortunately, it would be hard to do this without frequent false positives.
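To make option 2 concrete, here is a minimal single-device sketch (an illustration only; note also that XLA does not actually guarantee the order in which colliding scatter updates are combined):

import jax
import jax.numpy as jnp

def sorted_scatter_add(target, idx, vals):
    # Present updates in ascending order of magnitude, so small values
    # combine with each other before meeting large partial sums. This
    # helps when summand magnitudes vary widely, but not when the exact
    # total itself exceeds what float32 can represent (as in this issue).
    order = jnp.argsort(jnp.abs(vals))
    return target.at[idx[order]].add(vals[order])

vals = jax.random.uniform(jax.random.key(0), (100_000_000,))
idx = jnp.zeros(len(vals), dtype=int)
out = sorted_scatter_add(jnp.zeros(1), idx, vals)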

None of these strikes me as a good route to fixing the issue in general; it may be that the best we can do is document the potential problem.

hrbigelow commented 11 months ago

Hi Jake,

I was thinking you could do this with a staging 'accu' tensor of the same shape as the destination, and then perform the summation in a nested fashion: compute the scatter_add on chunks of vals and idx, then sum the per-chunk results element-wise.

I would guess it could be implemented very efficiently at a low level, and would also shard across devices without any particular difficulty.

I tested it, and a chunk size of 1 million works. The safe chunk size would depend on how widely the exponents of the summands vary; if they are more or less uniformly distributed at any given scale, I think it would still work.

This JAX scan-based code works, but only if I keep the print statement, which presumably prevents the compiler from optimizing out the accu buffer.

import jax

def staged_scatter_add(vals, idx, chunk_size):
    def scan_fn(carry, item):
        vals, idx = item
        # Scatter-add one chunk into a fresh buffer, then fold it into
        # the running element-wise sum; no single accumulation chain gets
        # long enough to saturate float32.
        accu = jax.numpy.zeros_like(carry).at[idx].add(vals)
        out = carry + accu
        # This print is needed, I guess to keep 'accu' from being optimized out.
        jax.debug.print('accu: {}', accu)
        return out, None

    # Assumes len(vals) is divisible by chunk_size.
    num_chunks = len(vals) // chunk_size
    vals = vals.reshape(num_chunks, chunk_size)
    idx = idx.reshape(num_chunks, chunk_size)
    out, _ = jax.lax.scan(scan_fn, jax.numpy.zeros(1), (vals, idx))
    return out

vals = jax.random.uniform(jax.random.key(0), (100_000_000,))
idx = jax.numpy.zeros(len(vals), dtype=int)
staged = staged_scatter_add(vals, idx, 1_000_000)
print(vals.sum(), staged[0])
# 50004290.0 50004410.0
hawkinsp commented 11 months ago

I'll note that we have support for bucketed accumulation in segment_sum to solve a variant of this problem: https://github.com/google/jax/blob/03cae62c78349f4e8a0ede08c0955fdb32070052/jax/_src/ops/scatter.py#L224
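For reference, that bucketing is exposed via the bucket_size argument of jax.ops.segment_sum; applied to the example from this thread, it would look roughly like this (a sketch, with an arbitrarily chosen bucket size):

import jax
import jax.numpy as jnp

vals = jax.random.uniform(jax.random.key(0), (100_000_000,))
idx = jnp.zeros(len(vals), dtype=int)

# Sum within buckets of one million entries first, then combine the
# per-bucket partial sums, so no single float32 accumulation chain
# grows long enough to saturate at 2**24.
out = jax.ops.segment_sum(vals, idx, num_segments=1, bucket_size=1_000_000)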