jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`vmap` with `scatter_add` extremely slow when using `xla_gpu_deterministic_ops` #17844

Open BrunoKM opened 1 year ago

BrunoKM commented 1 year ago

Description

This issue is about two things: 1) a rather significant slow-down of the scatter_add operation when running JAX with the xla_gpu_deterministic_ops=true flag, and 2) a further, disproportionately large slow-down when scatter_add is wrapped in vmap.

Below is the code to reproduce the issue. The timings were measured with and without setting os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" at the start of the script.

Firstly, just a regular scatter_add benchmark:

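import os
# Uncomment to enable deterministic GPU ops. Set this before JAX initializes
# its GPU backend (e.g. at the very top of the script, before importing jax).
# os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

import jax
import jax.numpy as jnp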

def scatter_add(
    operand,  # [operand_size]
    updates,  # [updates_size]
    indices,  # [updates_size, 1]
):
    # Define dimension numbers
    update_window_dims = tuple()
    inserted_window_dims = (0,)
    scatter_dims_to_operand_dims = (0,)
    res = jax.lax.scatter_add(
        operand,
        indices,
        updates,
        dimension_numbers=jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims),
        mode="drop",
    )
    return res

operand_size = 64 * 64  # e.g. a 64x64 image

operand = jnp.zeros((operand_size,))
updates = jnp.ones((operand_size * 4,))
rng = jax.random.PRNGKey(0)
indices = jax.random.randint(rng, shape=(operand_size * 4, 1), minval=0, maxval=operand_size)

scatter_add_jit = jax.jit(scatter_add)
scatter_add_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_jit(operand, updates, indices).block_until_ready()
# Without: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 25.3 µs ± 81 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 46.1 ms ± 4.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Secondly, a scatter_add benchmark with vmap:

n_batches = 100

operand = jnp.zeros((n_batches, operand_size,))
updates = jnp.ones((n_batches, operand_size * 4))
rng = jax.random.PRNGKey(0)
indices = jax.random.randint(rng, shape=(n_batches, operand_size * 4, 1), minval=0, maxval=operand_size)

scatter_add_batched = jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)

scatter_add_batched_jit = jax.jit(scatter_add_batched)
scatter_add_batched_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_batched_jit(operand, updates, indices).block_until_ready()
# Without: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 79.7 µs ± 173 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 17.4 s ± 61.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

It seems pretty unexpected that, when the xla_gpu_deterministic_ops flag is set to true, calling scatter_add under vmap with a batch size of 100 makes the runtime 377x longer than a single call, i.e. about 3.8x slower than just using a manual Python for-loop (100 × 46.1 ms ≈ 4.6 s).
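The ratios above follow directly from the reported timings. The Python-loop baseline is an assumption: 100 sequential calls, each at the single-call deterministic time of 46.1 ms.

```python
# Timings reported above, converted to seconds (flag enabled in both cases).
t_single_det = 46.1e-3   # jit(scatter_add)
t_batched_det = 17.4     # jit(vmap(scatter_add)) with 100 batches
n_batches = 100

# vmapped call vs. one unbatched call: ~377x (a batch of 100 "should" cost ~100x)
vmap_vs_single = t_batched_det / t_single_det
# vmapped call vs. 100 sequential unbatched calls: ~3.8x
vmap_vs_python_loop = t_batched_det / (n_batches * t_single_det)

print(f"{vmap_vs_single:.0f}x vs. one call, {vmap_vs_python_loop:.2f}x vs. a Python loop")
```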

Separately, although some slow-down of scatter_add is to be expected when enforcing determinism, it is rather severe (almost 2000x slower without vmap, and over 200000x slower with vmap). This operation probably doesn't come up very often, but it does appear, for example, in the backward pass through a bilinear interpolation of an image (e.g. when using jax.scipy.ndimage.map_coordinates). Even if the vmap issue gets resolved, it would be very helpful if JAX also showed a warning about the potential runtime impact when executing code with --xla_gpu_deterministic_ops=true.
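A quick way to see the map_coordinates case: the gradient of a bilinear interpolation with respect to the image is the transpose of a gather, which JAX expresses as a scatter-add. This is a minimal sketch; the image size and the 1000 random sample points are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def sample_sum(image, rows, cols):
    # Bilinear (order=1) interpolation of `image` at (rows[k], cols[k]), summed.
    return map_coordinates(image, [rows, cols], order=1).sum()

image = jnp.zeros((64, 64))
# Hypothetical sample points inside the image.
key_r, key_c = jax.random.split(jax.random.PRNGKey(0))
rows = jax.random.uniform(key_r, (1000,), maxval=63.0)
cols = jax.random.uniform(key_c, (1000,), maxval=63.0)

# jax.grad differentiates w.r.t. the first argument (the image) by default.
grad_jaxpr = jax.make_jaxpr(jax.grad(sample_sum))(image, rows, cols)
has_scatter_add = "scatter-add" in str(grad_jaxpr)
print(has_scatter_add)
```

With --xla_gpu_deterministic_ops=true, that scatter-add is the op affected by the slow-down described in this issue.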

What jax/jaxlib version are you using?

jax v0.4.16, jaxlib v0.4.16+cuda12.cudnn89

Which accelerator(s) are you using?

GPU

Additional system info

Linux, Ubuntu 22.04.3 LTS, Python 3.11.3

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090         On | 00000000:01:00.0 Off |                  Off |
| 30%   43C    P2              139W / 500W|  20400MiB / 24564MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

Reproduced on a 3090 as well.

hawkinsp commented 1 year ago

At least at the moment I think this is expected: deterministic scatters are much slower on GPU because they eliminate any parallelism. XLA would need to emit different code for a faster deterministic scatter.
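For context on why determinism costs parallelism here: a fast GPU scatter-add lets many threads add into the same output element concurrently using atomics (the atomics are mentioned later in this thread), so the order of accumulation depends on thread scheduling. Because floating-point addition is not associative, a different order can give a different result, as this plain-Python illustration shows:

```python
# The same three values summed in two different orders.
a, b, c = 0.1, 0.2, 0.3
left_to_right = (a + b) + c
right_to_left = a + (b + c)
print(left_to_right, right_to_left, left_to_right == right_to_left)
# 0.6000000000000001 0.6 False
```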

BrunoKM commented 1 year ago

@hawkinsp thanks for the response, it absolutely makes sense that deterministic scatters should be slower.

Do you think it's expected that you should get an additional slow-down from vmap? I.e. that this:

scatter_add_batched = jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)

scatter_add_batched_jit = jax.jit(scatter_add_batched)
scatter_add_batched_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_batched_jit(operand, updates, indices).block_until_ready()
>>> 17.4 s

is about 3.8x as slow as this:

scatter_add_jit = jax.jit(scatter_add)
scatter_add_jit(operand, updates, indices).block_until_ready()
%timeit [scatter_add_jit(operand[i], updates[i], indices[i]).block_until_ready() for i in range(len(operand))]
>>> 4.6 s

with os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" enabled? I can't think of a reason this should be the case: both are deterministic, and I'd expect the compiled XLA version to be at least as fast as doing the loop in Python.

j-towns commented 4 months ago

I encountered what I'm fairly confident is the same vmap-related slowdown on TPU. I profiled it and found that, while vmap of my function produces a single scatter in the jaxpr, in the compiled XLA program that op becomes a while loop over the mapped axis, and the body of the loop contains a dynamic-update-slice, not a scatter. Presumably, since XLA while loops cannot in general be parallelized, the compiler is unable to see this potential optimization. I don't know whether this issue is common to the lowering of any 'ragged' scatter.

@BrunoKM I will try your Python loop workaround and see whether that improves my use-case for now.
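Another possible workaround, not proposed in this thread and not benchmarked: fold the batch dimension into the indices, so the batched 1-D scatter-add becomes a single unbatched 1-D scatter-add, like the non-vmapped case. The helper name is hypothetical, and whether this avoids the per-batch loop or is faster under --xla_gpu_deterministic_ops=true is an assumption that would need measuring.

```python
import jax
import jax.numpy as jnp

def scatter_add_flat_batched(operand, updates, indices):
    """Hypothetical helper: batched 1-D scatter-add done as ONE unbatched scatter-add.

    Shapes follow the benchmark in this issue:
      operand: [batch, size], updates: [batch, n], indices: [batch, n, 1]
    """
    batch, size = operand.shape
    idx = indices[..., 0]
    # Out-of-range indices are dropped, matching mode="drop" in the original
    # scatter_add; they must not spill into a neighboring batch's block.
    in_bounds = (idx >= 0) & (idx < size)
    # Shift batch b's indices into the block [b * size, (b + 1) * size).
    offsets = (jnp.arange(batch) * size)[:, None]
    flat_idx = jnp.where(in_bounds, idx + offsets, batch * size)
    flat = operand.reshape(-1).at[flat_idx.reshape(-1)].add(
        updates.reshape(-1), mode="drop"
    )
    return flat.reshape(batch, size)
```

It can be wrapped in jax.jit and called with the same (operand, updates, indices) arrays as scatter_add_batched_jit above.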

Setup:

In [1]: from jax import jit, lax, vmap, make_jaxpr
In [2]: import jax.numpy as jnp
In [3]: operand = jnp.ones((3, 4, 5))
In [4]: updates = jnp.ones((3, 2, 5))
In [5]: starts = jnp.ones((3,), dtype='int32')
In [6]: from functools import partial
In [7]: f = partial(lax.dynamic_update_slice_in_dim, axis=0)

Printing the jaxpr shows that there is a single scatter op:

In [8]: make_jaxpr(vmap(f))(operand, updates, starts)
Out[8]: 
{ lambda ; a:f32[3,4,5] b:f32[3,2,5] c:i32[3]. let
    d:bool[3] = lt c 0
    e:i32[3] = add c 4
    f:i32[3] = select_n d c e
    g:i32[] = add 0 5
    h:i32[] = select_n False 0 g
    i:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] f
    j:i32[3,1] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 1)] h
    k:i32[3,2] = concatenate[dimension=1] i j
    l:i32[3,1] = iota[dimension=0 dtype=int32 shape=(3, 1)] 
    m:i32[3,3] = concatenate[dimension=1] l k
    n:f32[3,4,5] = scatter[
      dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2))
      indices_are_sorted=True
      mode=GatherScatterMode.CLIP
      unique_indices=True
      update_consts=()
      update_jaxpr=None
    ] a m b
  in (n,) }

Printing the compiled HLO; because it is long and hard to read, I've surrounded the relevant parts with ########:

In [9]: print(jit(vmap(f)).lower(operand, updates, starts).compile().as_text())
HloModule jit__unnamed_function_, entry_computation_layout={(f32[3,4,5]{2,1,0}, f32[3,2,5]{2,1,0}, s32[3]{0})->f32[3,4,5]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

#############################################################
%fused_computation (param_0: f32[3,4,5], param_1.3: f32[3,2,5], param_2.7: s32[], param_3.8: pred[], param_4.8: s32[3,3]) -> f32[3,4,5] {
  %param_0 = f32[3,4,5]{2,1,0} parameter(0)
  %param_3.8 = pred[] parameter(3)
  %broadcast.18 = pred[1,2,5]{2,1,0} broadcast(pred[] %param_3.8), dimensions={}
  %param_1.3 = f32[3,2,5]{2,1,0} parameter(1)
  %param_2.7 = s32[] parameter(2)
  %constant.23 = s32[] constant(0)
  %dynamic-slice.7 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,2,5]{2,1,0} %param_1.3, s32[] %param_2.7, s32[] %constant.23, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
  %param_4.8 = s32[3,3]{1,0} parameter(4)
  %dynamic-slice.8 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_4.8, s32[] %param_2.7, s32[] %constant.23), dynamic_slice_sizes={1,3}
  %slice.23 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.8), slice={[0:1], [0:1]}
  %bitcast.6 = s32[] bitcast(s32[1,1]{1,0} %slice.23)
  %bitcast.7 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.8)
  %slice.22 = s32[1]{0} slice(s32[3]{0} %bitcast.7), slice={[1:2]}
  %bitcast.5 = s32[] bitcast(s32[1]{0} %slice.22)
  %dynamic-slice.6 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,4,5]{2,1,0} %param_0, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
  %select.1 = f32[1,2,5]{2,1,0} select(pred[1,2,5]{2,1,0} %broadcast.18, f32[1,2,5]{2,1,0} %dynamic-slice.7, f32[1,2,5]{2,1,0} %dynamic-slice.6)
  ###########################################################
  # Note the shape of the update array, it is 1 in the batch dimension
  ROOT %dynamic-update-slice.2 = f32[3,4,5]{2,1,0} dynamic-update-slice(f32[3,4,5]{2,1,0} %param_0, f32[1,2,5]{2,1,0} %select.1, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23)
  ###########################################################
}
#############################################################

%and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] {
  %lhs = pred[] parameter(0)
  %rhs = pred[] parameter(1)
  ROOT %and = pred[] and(pred[] %lhs, pred[] %rhs)
}

%fused_computation.1 (param_0.4: s32[3]) -> pred[] {
  %constant.26 = s32[] constant(0)
  %broadcast.19 = s32[3]{0} broadcast(s32[] %constant.26), dimensions={}
  %param_0.4 = s32[3]{0} parameter(0)
  %compare.4 = pred[3]{0} compare(s32[3]{0} %broadcast.19, s32[3]{0} %param_0.4), direction=LE
  %constant.25 = s32[3]{0} constant({2, 2, 0})
  %compare.3 = pred[3]{0} compare(s32[3]{0} %constant.25, s32[3]{0} %param_0.4), direction=GE
  %and.2 = pred[3]{0} and(pred[3]{0} %compare.4, pred[3]{0} %compare.3)
  %constant.24 = pred[] constant(true)
  ROOT %reduce.1 = pred[] reduce(pred[3]{0} %and.2, pred[] %constant.24), dimensions={0}, to_apply=%and.reduce_sub_computation
}

%fused_computation.2 (param_0.7: s32[3,3], param_1.15: s32[]) -> s32[3] {
  %param_0.7 = s32[3,3]{1,0} parameter(0)
  %param_1.15 = s32[] parameter(1)
  %constant.27 = s32[] constant(0)
  %dynamic-slice.9 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_0.7, s32[] %param_1.15, s32[] %constant.27), dynamic_slice_sizes={1,3}
  %slice.25 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.9), slice={[0:1], [0:1]}
  %bitcast.9 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.25)
  %bitcast.8 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.9)
  %slice.24 = s32[2]{0} slice(s32[3]{0} %bitcast.8), slice={[1:3]}
  ROOT %concatenate.2 = s32[3]{0} concatenate(s32[1]{0} %bitcast.9, s32[2]{0} %slice.24), dimensions={0}
}

#############################################################
%while_body (param.1: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> (s32[], f32[3,4,5], s32[3,3], f32[3,2,5]) {
  %param.1 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
  %get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=0
  %copy.3 = s32[] copy(s32[] %get-tuple-element.12)
  %constant.10 = s32[] constant(1)
  %add = s32[] add(s32[] %copy.3, s32[] %constant.10)
  %get-tuple-element.13 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=1
  %get-tuple-element.19 = f32[3,2,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=3
  %get-tuple-element.18 = s32[3,3]{1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=2
  %fusion.2 = s32[3]{0} fusion(s32[3,3]{1,0} %get-tuple-element.18, s32[] %copy.3), kind=kLoop, calls=%fused_computation.2
  %fusion.1 = pred[] fusion(s32[3]{0} %fusion.2), kind=kLoop, calls=%fused_computation.1
  ###########################################################
  %fusion = f32[3,4,5]{2,1,0} fusion(f32[3,4,5]{2,1,0} %get-tuple-element.13, f32[3,2,5]{2,1,0} %get-tuple-element.19, s32[] %copy.3, pred[] %fusion.1, s32[3,3]{1,0} %get-tuple-element.18), kind=kLoop, calls=%fused_computation
  ###########################################################
  ROOT %tuple.5 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %add, f32[3,4,5]{2,1,0} %fusion, s32[3,3]{1,0} %get-tuple-element.18, f32[3,2,5]{2,1,0} %get-tuple-element.19)
}
#############################################################

%while_cond (param.0: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> pred[] {
  %param.0 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.0), index=0
  %constant.1 = s32[] constant(3)
  ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.1), direction=LT
}

%fused_computation.3 (param_0.10: s32[3]) -> s32[3,3] {
  %constant.30 = s32[] constant(0)
  %broadcast.25 = s32[3,3]{1,0} broadcast(s32[] %constant.30), dimensions={}
  %iota.1 = s32[3,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(<unnamed function>)/jit(main)/iota[dtype=int32 shape=(3, 1) dimension=0]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %param_0.10 = s32[3]{0} parameter(0)
  %broadcast.24 = s32[3]{0} broadcast(s32[] %constant.30), dimensions={}
  %compare.5 = pred[3]{0} compare(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.24), direction=LT, metadata={op_name="jit(<unnamed function>)/jit(main)/lt" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %constant.29 = s32[] constant(4)
  %broadcast.23 = s32[3]{0} broadcast(s32[] %constant.29), dimensions={}
  %add.1 = s32[3]{0} add(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.23), metadata={op_name="jit(<unnamed function>)/jit(main)/add" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %select.2 = s32[3]{0} select(pred[3]{0} %compare.5, s32[3]{0} %add.1, s32[3]{0} %param_0.10), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %bitcast.10 = s32[3,1]{1,0} bitcast(s32[3]{0} %select.2), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %broadcast.22 = s32[3,1]{1,0} broadcast(s32[] %constant.30), dimensions={}
  %concatenate.3 = s32[3,3]{1,0} concatenate(s32[3,1]{1,0} %iota.1, s32[3,1]{1,0} %bitcast.10, s32[3,1]{1,0} %broadcast.22), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/concatenate[dimension=1]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %constant.28 = s32[3]{0} constant({2, 2, 0})
  %broadcast.21 = s32[3,3]{1,0} broadcast(s32[3]{0} %constant.28), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=(1,)]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  ROOT %clamp.0 = s32[3,3]{1,0} clamp(s32[3,3]{1,0} %broadcast.25, s32[3,3]{1,0} %concatenate.3, s32[3,3]{1,0} %broadcast.21), metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}

#############################################################
ENTRY %main.26 (Arg_0.1: f32[3,4,5], Arg_1.2: f32[3,2,5], Arg_2.3: s32[3]) -> f32[3,4,5] {
  %constant.4 = s32[] constant(0)
  %copy.8 = s32[] copy(s32[] %constant.4)
  %Arg_0.1 = f32[3,4,5]{2,1,0} parameter(0), sharding={replicated}
  %copy.7 = f32[3,4,5]{2,1,0} copy(f32[3,4,5]{2,1,0} %Arg_0.1)
  %Arg_2.3 = s32[3]{0} parameter(2), sharding={replicated}
  %fusion.3 = s32[3,3]{1,0} fusion(s32[3]{0} %Arg_2.3), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %Arg_1.2 = f32[3,2,5]{2,1,0} parameter(1), sharding={replicated}
  %tuple.3 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %copy.8, f32[3,4,5]{2,1,0} %copy.7, s32[3,3]{1,0} %fusion.3, f32[3,2,5]{2,1,0} %Arg_1.2)
  ###########################################################
  %while = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) while((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %tuple.3), condition=%while_cond, body=%while_body, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  ###########################################################
  ROOT %get-tuple-element.5 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %while), index=1, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}
#############################################################
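To check whether a batched scatter was turned into a loop on a given backend without reading the whole dump, a small helper can count opcodes in the compiled HLO text. This is a sketch: the helper name and the string-matching heuristic are my own, and the counts depend on the backend, the XLA version, and flags such as --xla_gpu_deterministic_ops.

```python
import jax
import jax.numpy as jnp
from jax import lax

def hlo_op_counts(fn, *args, ops=("while", "scatter", "dynamic-update-slice")):
    """Count selected opcodes in the optimized HLO that XLA compiles for fn(*args)."""
    hlo_text = jax.jit(fn).lower(*args).compile().as_text()
    # An HLO instruction reads `%name = <shape> opcode(operands)`, so matching
    # " opcode(" skips computation names such as %while_body or %while_cond.
    return {op: hlo_text.count(f" {op}(") for op in ops}

operand = jnp.ones((3, 4, 5))
updates = jnp.ones((3, 2, 5))
starts = jnp.ones((3,), dtype=jnp.int32)
f = lambda o, u, s: lax.dynamic_update_slice_in_dim(o, u, s, axis=0)

print(hlo_op_counts(jax.vmap(f), operand, updates, starts))
```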
j-towns commented 4 months ago

Here's a more minimal, self-contained snippet that reproduces the vmap slowdown. On TPU the vmap version is roughly 2.4x slower than using a Python loop and stacking the results with jnp.stack.

from timeit import timeit

from jax import jit, lax, vmap
import jax.numpy as jnp

# For f which outputs a single array, this simulates vmap using Python map
pymap = lambda f: lambda *args: jnp.stack(list(map(f, *args)))

operands = jnp.ones((100, 32))
updates = jnp.ones((100, 2))
starts = jnp.ones((100, 1), dtype='int32')

f = lax.dynamic_update_slice

f_vmapped = jit(vmap(f))
f_pymapped = jit(pymap(f))

# Ensure compiled
f_vmapped(operands, updates, starts)
f_pymapped(operands, updates, starts)

t_vmapped = timeit(
    lambda: f_vmapped(operands, updates, starts).block_until_ready(), number=100
) / 100

t_pymapped = timeit(
    lambda: f_pymapped(operands, updates, starts).block_until_ready(), number=100
) / 100

print(f"Time vmap(f): {t_vmapped:.2}s")
print(f"Time pymap(f): {t_pymapped:.2}s")

On a TPU v4-8 VM I get:

Time vmap(f): 0.00088s
Time pymap(f): 0.00036s

Running the script on CPU on my laptop, the Python loop version is instead slower than the vmap version:

Time vmap(f): 1.3e-05s
Time pymap(f): 3.3e-05s
jaro-sevcik commented 2 months ago

I realize this is an older issue, but one option is to roll your own deterministic scatter_add (using prefix sums):

import jax
import jax.numpy as jp

def add_segment(iv, jt):
  # Combine two (index, value) pairs of a segmented sum: keep the right-hand
  # index, and carry the left value over only if both indices are equal.
  i, v = iv
  j, t = jt
  return j, v * jp.equal(i, j) + t

@jax.jit
def scatter_add_det(operand, updates, indices):
  indices = jp.reshape(indices, updates.shape)
  # Sort the indices and the values by the indices.
  indices, sorted_updates = jax.lax.sort_key_val(indices, updates, dimension=-1)
  # Sum up runs of the same index - the sum for each index will be at the end of each run.
  _, sums = jax.lax.associative_scan(add_segment, (indices, sorted_updates))
  # Produce an array of bools - if an element is set then the position
  # is the end of a run of the same index.
  end_of_run = jp.concatenate([jp.not_equal(indices[1:], indices[:-1]), jp.array([True])])
  # Set all positions that are not at the end of a run to an out-of-bounds index.
  indices = jp.where(end_of_run, indices, operand.shape[-1])
  # Now do a scatter-add where we know the (in-bounds) indices are unique.
  # That is still fast on GPUs (no non-determinism from atomics).
  return operand.at[indices].add(sums, mode='drop', unique_indices=True)

This is 5-15x slower than the non-deterministic version (depending on the shapes involved), but at least it's not multiple orders of magnitude slower. It would be nice if XLA could lower to something like this automatically.
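The same sort, sum-each-run, scatter-once idea is easier to follow as a sequential NumPy reference. This sketch uses np.add.reduceat in place of jax.lax.associative_scan, so it illustrates the algorithm rather than the parallel GPU version; the function name is hypothetical and it assumes all indices are in bounds.

```python
import numpy as np

def scatter_add_sorted_runs(operand, updates, indices):
    """Deterministic 1-D scatter-add: sort, sum runs of equal indices, scatter once.

    Assumes every index is in bounds for `operand` (no mode="drop" handling).
    """
    out = np.array(operand, copy=True)
    indices = np.asarray(indices).reshape(-1)
    updates = np.asarray(updates).reshape(-1)
    if indices.size == 0:
        return out
    # Stable sort so updates that share an index keep their original order.
    order = np.argsort(indices, kind="stable")
    sorted_idx = indices[order]
    sorted_upd = updates[order]
    # A run of equal indices starts at position 0 and wherever the index changes.
    is_run_start = np.concatenate(([True], sorted_idx[1:] != sorted_idx[:-1]))
    run_starts = np.flatnonzero(is_run_start)
    # Sum each run in a fixed left-to-right order.
    run_sums = np.add.reduceat(sorted_upd, run_starts)
    # Each target index now appears exactly once, so plain fancy indexing is safe.
    out[sorted_idx[run_starts]] += run_sums
    return out
```

Because the summation order is fixed by the sort rather than by thread scheduling, repeated calls give bit-identical results.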