jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Vmap of dynamic_update_slice very slow on TPU #21367

Open j-towns opened 5 months ago

j-towns commented 5 months ago

Description

I've found that vmapping lax.dynamic_update_slice leads to surprising slow-downs. This issue affects the TPU backend. The following code benchmarks vmap(dynamic_update_slice) vs. an equivalent Python loop (plus jnp.stack):

from timeit import timeit

from jax import jit, lax, vmap, make_jaxpr
import jax.numpy as jnp

# For f which outputs a single array, this simulates vmap using Python map
pymap = lambda f: lambda *args: jnp.stack(list(map(f, *args)))

operands = jnp.ones((100, 32))
updates = jnp.ones((100, 2))
starts = jnp.ones((100, 1), dtype='int32')

f = lax.dynamic_update_slice

f_vmapped = jit(vmap(f))
f_pymapped = jit(pymap(f))

# Ensure compiled
f_vmapped(operands, updates, starts)
f_pymapped(operands, updates, starts)

t_vmapped = timeit(
    lambda: f_vmapped(operands, updates, starts).block_until_ready(), number=100
) / 100

t_pymapped = timeit(
    lambda: f_pymapped(operands, updates, starts).block_until_ready(), number=100
) / 100

print(f"Time vmap(f): {t_vmapped:.2}s")
print(f"Time pymap(f): {t_pymapped:.2}s")

On a TPU v4-8 VM I get:

Time vmap(f): 0.00088s
Time pymap(f): 0.00036s

I spent some time digging into this because it is significantly affecting the runtime of the project I'm working on.

I profiled f_vmapped by running the following, also on the v4-8 VM:

from jax import profiler 

with profiler.trace("dynamic-slice-profile"):
    f_vmapped(operands, updates, starts).block_until_ready()

The resulting log directory is here: dynamic-slice-profile.zip. Inspecting the trace, it seems almost all of the runtime is spent in a while loop:

[Screenshot: profiler trace, with almost all of the runtime inside a while loop]

The while loop appears to have 100 iterations (matching the size of the mapped axis), and, zooming in, each iteration looks like this:

[Screenshot: a single iteration of the while loop]

I was surprised to see this, as I had assumed that vmap(dynamic_update_slice) would lower to a scatter, not a while-loop with dynamic_update_slices in each iteration.

To find out where the while loop came from I ran make_jaxpr(f_vmapped)(operands, updates, starts), which results in

{ lambda ; a:f32[100,32] b:f32[100,2] c:i32[100,1]. let
    d:f32[100,32] = pjit[
      name=dynamic_update_slice
      jaxpr={ lambda ; e:f32[100,32] f:f32[100,2] g:i32[100,1]. let
          h:i32[100,1] = slice[
            limit_indices=(100, 1)
            start_indices=(0, 0)
            strides=(1, 1)
          ] g
          i:i32[100] = squeeze[dimensions=(1,)] h
          j:bool[100] = lt i 0
          k:i32[100] = add i 32
          l:i32[100] = select_n j i k
          m:i32[100,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(100, 1)
          ] l
          n:i32[100,1] = iota[dimension=0 dtype=int32 shape=(100, 1)] 
          o:i32[100,2] = concatenate[dimension=1] n m
          p:f32[100,32] = scatter[
            dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1))
            indices_are_sorted=True
            mode=GatherScatterMode.CLIP
            unique_indices=True
            update_consts=()
            update_jaxpr=None
          ] e o f

        in (p,) }
    ] a b c
  in (d,) }

As expected, the dynamic_update_slice is transformed into a scatter (with some other ops).
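As an aside, the same batched update can be written directly at the jnp level as an indexed .at[...].set, which also lowers to a scatter. This is only an illustrative sketch of what the batched program amounts to (out-of-bounds starts are handled differently: dynamic_update_slice clamps the start so the whole window stays in bounds, whereas per-index scatter semantics apply here):

import jax.numpy as jnp

# Sketch: express the batched update as an indexed set, which lowers to a scatter.
def batched_dus_via_at(operands, updates, starts):
    rows = jnp.arange(operands.shape[0])[:, None]           # (100, 1) batch index per row
    cols = starts + jnp.arange(updates.shape[1])[None, :]   # (100, 2) target columns per row
    return operands.at[rows, cols].set(updates)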

Looking at the lowered StableHLO, running print(f_vmapped.lower(operands, updates, starts).as_text()) gives

module @jit_dynamic_update_slice attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<100x32xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<100x2xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<100x1xi32> {mhlo.layout_mode = "default"}) -> (tensor<100x32xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.slice %arg2 [0:100, 0:1] : (tensor<100x1xi32>) -> tensor<100x1xi32>
    %1 = stablehlo.reshape %0 : (tensor<100x1xi32>) -> tensor<100xi32>
    %c = stablehlo.constant dense<0> : tensor<i32>
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<100xi32>
    %3 = stablehlo.compare  LT, %1, %2,  SIGNED : (tensor<100xi32>, tensor<100xi32>) -> tensor<100xi1>
    %c_0 = stablehlo.constant dense<32> : tensor<i32>
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<100xi32>
    %5 = stablehlo.add %1, %4 : tensor<100xi32>
    %6 = stablehlo.select %3, %5, %1 : tensor<100xi1>, tensor<100xi32>
    %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<100xi32>) -> tensor<100x1xi32>
    %8 = stablehlo.iota dim = 0 : tensor<100x1xi32>
    %9 = stablehlo.concatenate %8, %7, dim = 1 : (tensor<100x1xi32>, tensor<100x1xi32>) -> tensor<100x2xi32>
    %c_1 = stablehlo.constant dense<99> : tensor<i32>
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %c_2 = stablehlo.constant dense<30> : tensor<i32>
    %11 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %12 = stablehlo.concatenate %10, %11, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %c_3 = stablehlo.constant dense<2147483647> : tensor<ui32>
    %13 = stablehlo.convert %c_3 : (tensor<ui32>) -> tensor<i32>
    %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i32>) -> tensor<2xi32>
    %15 = stablehlo.minimum %12, %14 : tensor<2xi32>
    %16 = stablehlo.broadcast_in_dim %15, dims = [1] : (tensor<2xi32>) -> tensor<100x2xi32>
    %c_4 = stablehlo.constant dense<0> : tensor<i32>
    %17 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<i32>) -> tensor<100x2xi32>
    %18 = stablehlo.clamp %17, %9, %16 : tensor<100x2xi32>
    %19 = "stablehlo.scatter"(%arg0, %18, %arg1) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
      stablehlo.return %arg4 : tensor<f32>
    }) : (tensor<100x32xf32>, tensor<100x2xi32>, tensor<100x2xf32>) -> tensor<100x32xf32>
    return %19 : tensor<100x32xf32>
  }
}

Again there's a scatter op and no loop, as expected.

However, the compiled HLO does contain a while loop. It's much longer, so I've put it below. The while loop iterates over the mapped axis and performs a dynamic_update_slice at each iteration, so the scatter op is being compiled to a while loop with dynamic_update_slice ops in its body. Presumably the compiler can parallelize the unrolled Python map version, whereas the while loop in the vmap version prevents any parallelization, which could explain why vmap is slower than the Python loop.
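For comparison, the same batched update can also be expressed without any scatter, by selecting between operand and update values with a mask over a column-index grid. This is just a sketch, not benchmarked here, and it assumes in-bounds start indices rather than reproducing dynamic_update_slice's clamping:

import jax.numpy as jnp

# Scatter-free sketch: for each row, columns j with start <= j < start + k take
# updates[:, j - start]; all other columns keep the operand value.
def batched_dus_masked(operands, updates, starts):
    n = operands.shape[1]                                     # operand width (32)
    k = updates.shape[1]                                      # update width (2)
    offsets = jnp.arange(n)[None, :] - starts                 # (100, 32): column index minus start
    in_window = (offsets >= 0) & (offsets < k)                # columns covered by the update
    picked = jnp.take_along_axis(updates, jnp.clip(offsets, 0, k - 1), axis=1)
    return jnp.where(in_window, picked, operands)

Whether XLA compiles this into something faster on TPU would of course need measuring.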

I have already discussed this in #17844, which describes a GPU slow-down that is likely related to this one.

>>> print(f_vmapped.lower(operands, updates, starts).compile().as_text())
HloModule jit_dynamic_update_slice, is_scheduled=true, entry_computation_layout={(f32[100,32]{0,1:T(8,128)}, f32[100,2]{0,1:T(2,128)}, s32[100,1]{0,1:T(1,128)})->f32[100,32]{0,1:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.3: s32[100,1]) -> s32[100,1] {
  %param_0.3 = s32[100,1]{0,1:T(1,128)S(3)} parameter(0)
  %constant.24 = s32[]{:T(128)} constant(0)
  %broadcast.20 = s32[100,1]{0,1:T(1,128)} broadcast(s32[]{:T(128)} %constant.24), dimensions={}
  %compare.4 = pred[100,1]{0,1:T(4,128)(4,1)} compare(s32[100,1]{0,1:T(1,128)S(3)} %param_0.3, s32[100,1]{0,1:T(1,128)} %broadcast.20), direction=LT, metadata={op_name="jit(dynamic_update_slice)/jit(main)/lt" source_file="/home/jamietownsend/issue.py" source_line=21}
  %constant.23 = s32[]{:T(128)} constant(32)
  %broadcast.18 = s32[100,1]{0,1:T(1,128)} broadcast(s32[]{:T(128)} %constant.23), dimensions={}
  %add.6 = s32[100,1]{0,1:T(1,128)} add(s32[100,1]{0,1:T(1,128)S(3)} %param_0.3, s32[100,1]{0,1:T(1,128)} %broadcast.18), metadata={op_name="jit(dynamic_update_slice)/jit(main)/add" source_file="/home/jamietownsend/issue.py" source_line=21}
  ROOT %select.2 = s32[100,1]{0,1:T(1,128)S(3)} select(pred[100,1]{0,1:T(4,128)(4,1)} %compare.4, s32[100,1]{0,1:T(1,128)} %add.6, s32[100,1]{0,1:T(1,128)S(3)} %param_0.3), metadata={op_name="jit(dynamic_update_slice)/jit(main)/select_n" source_file="/home/jamietownsend/issue.py" source_line=21}
}

%fused_computation.1 (param_0.4: s32[100,1], param_1.9: s32[100,1], param_2.10: s32[2]) -> s32[100,2] {
  %constant.25 = s32[]{:T(128)} constant(0)
  %broadcast.17 = s32[100,2]{0,1:T(2,128)} broadcast(s32[]{:T(128)} %constant.25), dimensions={}
  %param_1.9 = s32[100,1]{0,1:T(1,128)S(3)} parameter(1)
  %pad.7 = s32[100,2]{0,1:T(2,128)} pad(s32[100,1]{0,1:T(1,128)S(3)} %param_1.9, s32[]{:T(128)} %constant.25), padding=0_0x0_1, metadata={op_name="jit(dynamic_update_slice)/jit(main)/concatenate[dimension=1]" source_file="/home/jamietownsend/issue.py" source_line=21}
  %param_0.4 = s32[100,1]{0,1:T(1,128)S(3)} parameter(0)
  %pad.6 = s32[100,2]{0,1:T(2,128)} pad(s32[100,1]{0,1:T(1,128)S(3)} %param_0.4, s32[]{:T(128)} %constant.25), padding=0_0x1_0, metadata={op_name="jit(dynamic_update_slice)/jit(main)/concatenate[dimension=1]" source_file="/home/jamietownsend/issue.py" source_line=21}
  %add.5 = s32[100,2]{0,1:T(2,128)} add(s32[100,2]{0,1:T(2,128)} %pad.7, s32[100,2]{0,1:T(2,128)} %pad.6), metadata={op_name="jit(dynamic_update_slice)/jit(main)/concatenate[dimension=1]" source_file="/home/jamietownsend/issue.py" source_line=21}
  %param_2.10 = s32[2]{0:T(128)S(3)} parameter(2)
  %broadcast.16 = s32[100,2]{0,1:T(2,128)} broadcast(s32[2]{0:T(128)S(3)} %param_2.10), dimensions={1}, metadata={op_name="jit(dynamic_update_slice)/jit(main)/broadcast_in_dim[shape=(100, 2) broadcast_dimensions=(1,)]" source_file="/home/jamietownsend/issue.py" source_line=21}
  ROOT %clamp.0 = s32[100,2]{0,1:T(2,128)S(3)} clamp(s32[100,2]{0,1:T(2,128)} %broadcast.17, s32[100,2]{0,1:T(2,128)} %add.5, s32[100,2]{0,1:T(2,128)} %broadcast.16), metadata={op_name="jit(dynamic_update_slice)/jit(main)/clamp" source_file="/home/jamietownsend/issue.py" source_line=21}
}

%dynamic-slice.reduce_sub_computation (lhs.1: s32[], rhs.1: s32[]) -> s32[] {
  %rhs.1 = s32[] parameter(1)
  %lhs.1 = s32[] parameter(0)
  ROOT %add.2 = s32[] add(s32[] %lhs.1, s32[] %rhs.1)
}

%fused_computation.5.clone (param_0.15: s32[100,2], param_1.23: s32[]) -> (s32[2], s32[1,2]) {
  %param_0.15 = s32[100,2]{0,1:T(2,128)S(3)} parameter(0)
  %param_1.23 = s32[]{:T(128)} parameter(1)
  %constant.30.clone.2 = s32[]{:T(128)} constant(0)
  %dynamic-slice.8.clone.2 = s32[1,2]{0,1:T(2,128)S(3)} dynamic-slice(s32[100,2]{0,1:T(2,128)S(3)} %param_0.15, s32[]{:T(128)} %param_1.23, s32[]{:T(128)} %constant.30.clone.2), dynamic_slice_sizes={1,2}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[]}
  %constant.38 = s32[] constant(0)
  %reduce.4 = s32[2]{0:T(128)S(3)} reduce(s32[1,2]{0,1:T(2,128)S(3)} %dynamic-slice.8.clone.2, s32[] %constant.38), dimensions={0}, to_apply=%dynamic-slice.reduce_sub_computation
  ROOT %tuple.8 = (s32[2]{0:T(128)S(3)}, s32[1,2]{0,1:T(2,128)S(3)}) tuple(s32[2]{0:T(128)S(3)} %reduce.4, s32[1,2]{0,1:T(2,128)S(3)} %dynamic-slice.8.clone.2)
}

%and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] {
  %rhs = pred[]{:T(512)} parameter(1)
  %lhs = pred[]{:T(512)} parameter(0)
  ROOT %and = pred[]{:T(512)} and(pred[]{:T(512)} %lhs, pred[]{:T(512)} %rhs)
}

%fused_computation.2.clone (param_0.16: s32[2], param_1.24: s32[1], param_2.24: s32[1]) -> pred[] {
  %constant.40 = s32[]{:T(128)} constant(0)
  %broadcast.23 = s32[2]{0:T(128)} broadcast(s32[]{:T(128)} %constant.40), dimensions={}
  %param_2.24 = s32[1]{0:T(128)} parameter(2)
  %pad.11 = s32[2]{0:T(128)} pad(s32[1]{0:T(128)} %param_2.24, s32[]{:T(128)} %constant.40), padding=0_1
  %param_1.24 = s32[1]{0:T(128)} parameter(1)
  %pad.10 = s32[2]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.24, s32[]{:T(128)} %constant.40), padding=1_0
  %add.9 = s32[2]{0:T(128)} add(s32[2]{0:T(128)} %pad.11, s32[2]{0:T(128)} %pad.10)
  %compare.9 = pred[2]{0:T(512)(128)(4,1)} compare(s32[2]{0:T(128)} %broadcast.23, s32[2]{0:T(128)} %add.9), direction=LE
  %param_0.16 = s32[2]{0:T(128)S(3)} parameter(0)
  %compare.8 = pred[2]{0:T(512)(128)(4,1)} compare(s32[2]{0:T(128)S(3)} %param_0.16, s32[2]{0:T(128)} %add.9), direction=GE
  %and.3 = pred[2]{0:T(512)(128)(4,1)} and(pred[2]{0:T(512)(128)(4,1)} %compare.9, pred[2]{0:T(512)(128)(4,1)} %compare.8)
  %constant.41 = pred[]{:T(512)} constant(true)
  ROOT %reduce.5 = pred[]{:T(512)} reduce(pred[2]{0:T(512)(128)(4,1)} %and.3, pred[]{:T(512)} %constant.41), dimensions={0}, to_apply=%and.reduce_sub_computation
}

%fused_computation.3.clone (param_0.17: f32[100,2], param_1.25: s32[], param_2.25: f32[100,32], param_3.16: s32[], param_4.12: s32[], param_5.5: pred[]) -> f32[1,2] {
  %param_5.5 = pred[]{:T(512)} parameter(5)
  %broadcast.24 = pred[1,2]{0,1:T(4,128)(4,1)} broadcast(pred[]{:T(512)} %param_5.5), dimensions={}
  %param_0.17 = f32[100,2]{0,1:T(2,128)S(3)} parameter(0)
  %param_1.25 = s32[]{:T(128)} parameter(1)
  %constant.42 = s32[]{:T(128)} constant(0)
  %dynamic-slice.9 = f32[1,2]{0,1:T(2,128)} dynamic-slice(f32[100,2]{0,1:T(2,128)S(3)} %param_0.17, s32[]{:T(128)} %param_1.25, s32[]{:T(128)} %constant.42), dynamic_slice_sizes={1,2}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[]}
  %param_2.25 = f32[100,32]{0,1:T(8,128)S(3)} parameter(2)
  %param_3.16 = s32[]{:T(128)} parameter(3)
  %param_4.12 = s32[]{:T(128)} parameter(4)
  %dynamic-slice.10 = f32[1,2]{0,1:T(2,128)} dynamic-slice(f32[100,32]{0,1:T(8,128)S(3)} %param_2.25, s32[]{:T(128)} %param_3.16, s32[]{:T(128)} %param_4.12), dynamic_slice_sizes={1,2}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"0","ones":"0","bitwidth":"32"}]},"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[]}
  ROOT %select.4 = f32[1,2]{0,1:T(2,128)S(3)} select(pred[1,2]{0,1:T(4,128)(4,1)} %broadcast.24, f32[1,2]{0,1:T(2,128)} %dynamic-slice.9, f32[1,2]{0,1:T(2,128)} %dynamic-slice.10)
}

%wide.while_body (wide.param.1: (s32[], f32[100,32], s32[100,2], f32[100,2], s32[], /*index=5*/s32[2])) -> (s32[], f32[100,32], s32[100,2], f32[100,2], s32[], /*index=5*/s32[2]) {
  %constant.33..sunk = s32[]{:T(128)} constant(1)
  %wide.param.1 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) parameter(0)
  %get-tuple-element.47 = s32[]{:T(128)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=0
  %get-tuple-element.57 = s32[]{:T(128)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=4
  %get-tuple-element.48 = f32[100,32]{0,1:T(8,128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=1
  %get-tuple-element.55 = s32[100,2]{0,1:T(2,128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=2
  %get-tuple-element.56 = f32[100,2]{0,1:T(2,128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=3
  %get-tuple-element.58 = s32[2]{0:T(128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=5
  %fusion.7 = (s32[2]{0:T(128)S(3)}, s32[1,2]{0,1:T(2,128)S(3)}) fusion(s32[100,2]{0,1:T(2,128)S(3)} %get-tuple-element.55, s32[]{:T(128)} %get-tuple-element.47), kind=kLoop, calls=%fused_computation.5.clone, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"589","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"9216"}],"retry_config":{"retry_count":"0"}}
  %get-tuple-element.32 = s32[2]{0:T(128)S(3)} get-tuple-element((s32[2]{0:T(128)S(3)}, s32[1,2]{0,1:T(2,128)S(3)}) %fusion.7), index=0
  %slice.12 = s32[1]{0:T(128)} slice(s32[2]{0:T(128)S(3)} %get-tuple-element.32), slice={[1:2]}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1"],"input_window_bounds":[],"estimated_cycles":"586","iteration_bounds":["1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1024"}],"retry_config":{"retry_count":"0"}}
  %get-tuple-element.31 = s32[1,2]{0,1:T(2,128)S(3)} get-tuple-element((s32[2]{0:T(128)S(3)}, s32[1,2]{0,1:T(2,128)S(3)}) %fusion.7), index=1
  %slice.11 = s32[1,1]{0,1:T(1,128)} slice(s32[1,2]{0,1:T(2,128)S(3)} %get-tuple-element.31), slice={[0:1], [0:1]}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"587","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1024"}],"retry_config":{"retry_count":"0"}}
  %bitcast.5 = s32[1]{0:T(128)} bitcast(s32[1,1]{0,1:T(1,128)} %slice.11)
  %fusion.8 = pred[]{:T(512)} fusion(s32[2]{0:T(128)S(3)} %get-tuple-element.58, s32[1]{0:T(128)} %slice.12, s32[1]{0:T(128)} %bitcast.5), kind=kLoop, calls=%fused_computation.2.clone, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1"],"input_window_bounds":[],"estimated_cycles":"590","iteration_bounds":["1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"4096"}],"retry_config":{"retry_count":"0"}}
  %bitcast.3 = s32[]{:T(128)} bitcast(s32[1,1]{0,1:T(1,128)} %slice.11)
  %bitcast.4 = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %slice.12)
  %fusion.9 = f32[1,2]{0,1:T(2,128)S(3)} fusion(f32[100,2]{0,1:T(2,128)S(3)} %get-tuple-element.56, s32[]{:T(128)} %get-tuple-element.47, f32[100,32]{0,1:T(8,128)S(3)} %get-tuple-element.48, s32[]{:T(128)} %bitcast.3, s32[]{:T(128)} %bitcast.4, /*index=5*/pred[]{:T(512)} %fusion.8), kind=kLoop, calls=%fused_computation.3.clone, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"598","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"17408"}],"retry_config":{"retry_count":"0"}}
  %add.8 = s32[]{:T(128)} add(s32[]{:T(128)} %get-tuple-element.47, s32[]{:T(128)} %constant.33..sunk)
  %dynamic-update-slice.1 = f32[100,32]{0,1:T(8,128)S(3)} dynamic-update-slice(f32[100,32]{0,1:T(8,128)S(3)} %get-tuple-element.48, f32[1,2]{0,1:T(2,128)S(3)} %fusion.9, s32[]{:T(128)} %bitcast.3, s32[]{:T(128)} %bitcast.4), backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"0","ones":"0","bitwidth":"32"}]},"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"40960"}],"retry_config":{"retry_count":"0"}}
  ROOT %tuple.14 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) tuple(s32[]{:T(128)} %add.8, f32[100,32]{0,1:T(8,128)S(3)} %dynamic-update-slice.1, s32[100,2]{0,1:T(2,128)S(3)} %get-tuple-element.55, f32[100,2]{0,1:T(2,128)S(3)} %get-tuple-element.56, s32[]{:T(128)} %get-tuple-element.57, /*index=5*/s32[2]{0:T(128)S(3)} %get-tuple-element.58)
}

%wide.while_cond (wide.param.0: (s32[], f32[100,32], s32[100,2], f32[100,2], s32[], /*index=5*/s32[2])) -> pred[] {
  %constant.36 = s32[]{:T(128)} constant(100)
  %wide.param.0 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) parameter(0)
  %get-tuple-element.22 = s32[]{:T(128)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.0), index=0
  ROOT %compare.7 = pred[]{:T(512)} compare(s32[]{:T(128)} %get-tuple-element.22, s32[]{:T(128)} %constant.36), direction=LT
}

ENTRY %main.25 (Arg_0.1: f32[100,32], Arg_1.2: f32[100,2], Arg_2.3: s32[100,1]) -> f32[100,32] {
  %constant.33 = s32[]{:T(128)} constant(1)
  %constant.4 = s32[]{:T(128)} constant(0)
  %constant.34 = s32[2]{0:T(128)} constant({99, 30})
  %copy-start.2 = (s32[2]{0:T(128)S(3)}, s32[2]{0:T(128)}, u32[]{:S(2)}) copy-start(s32[2]{0:T(128)} %constant.34)
  %constant.6 = s32[2]{0:T(128)} constant({99, 30})
  %copy-start.4 = (s32[2]{0:T(128)S(3)}, s32[2]{0:T(128)}, u32[]{:S(2)}) copy-start(s32[2]{0:T(128)} %constant.6)
  %Arg_2.3 = s32[100,1]{0,1:T(1,128)} parameter(2)
  %copy-start.3 = (s32[100,1]{0,1:T(1,128)S(3)}, s32[100,1]{0,1:T(1,128)}, u32[]{:S(2)}) copy-start(s32[100,1]{0,1:T(1,128)} %Arg_2.3)
  %Arg_1.2 = f32[100,2]{0,1:T(2,128)} parameter(1)
  %copy-start.1 = (f32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)}, u32[]{:S(2)}) copy-start(f32[100,2]{0,1:T(2,128)} %Arg_1.2)
  %Arg_0.1 = f32[100,32]{0,1:T(8,128)} parameter(0)
  %copy.9 = f32[100,32]{0,1:T(8,128)S(3)} copy(f32[100,32]{0,1:T(8,128)} %Arg_0.1), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["2","1"],"input_window_bounds":[],"estimated_cycles":"621","iteration_bounds":["2","1"]},"megacore_config":{"use_single_core":false,"core_id":"0","megacore_split_dim":"0","megacore_allreduce_bytes":"0"},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"16384"}],"retry_config":{"retry_count":"0"}}
  %copy.10 = s32[]{:T(128)} copy(s32[]{:T(128)} %constant.4)
  %copy-done.3 = s32[100,1]{0,1:T(1,128)S(3)} copy-done((s32[100,1]{0,1:T(1,128)S(3)}, s32[100,1]{0,1:T(1,128)}, u32[]{:S(2)}) %copy-start.3)
  %fusion = s32[100,1]{0,1:T(1,128)S(3)} fusion(s32[100,1]{0,1:T(1,128)S(3)} %copy-done.3), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(dynamic_update_slice)/jit(main)/select_n" source_file="/home/jamietownsend/issue.py" source_line=21}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"589","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1024"}],"retry_config":{"retry_count":"0"}}
  %iota = s32[100,1]{0,1:T(1,128)S(3)} iota(), iota_dimension=0, metadata={op_name="jit(dynamic_update_slice)/jit(main)/iota[dtype=int32 shape=(100, 1) dimension=0]" source_file="/home/jamietownsend/issue.py" source_line=21}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"592","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1024"}],"retry_config":{"retry_count":"0"}}
  %copy-done.4 = s32[2]{0:T(128)S(3)} copy-done((s32[2]{0:T(128)S(3)}, s32[2]{0:T(128)}, u32[]{:S(2)}) %copy-start.4)
  %fusion.1 = s32[100,2]{0,1:T(2,128)S(3)} fusion(s32[100,1]{0,1:T(1,128)S(3)} %fusion, s32[100,1]{0,1:T(1,128)S(3)} %iota, s32[2]{0:T(128)S(3)} %copy-done.4), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(dynamic_update_slice)/jit(main)/clamp" source_file="/home/jamietownsend/issue.py" source_line=21}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"609","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"2048"}],"retry_config":{"retry_count":"0"}}
  %copy-done.2 = s32[2]{0:T(128)S(3)} copy-done((s32[2]{0:T(128)S(3)}, s32[2]{0:T(128)}, u32[]{:S(2)}) %copy-start.2)
  %copy-done.1 = f32[100,2]{0,1:T(2,128)S(3)} copy-done((f32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)}, u32[]{:S(2)}) %copy-start.1)
  %tuple.17 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) tuple(s32[]{:T(128)} %copy.10, f32[100,32]{0,1:T(8,128)S(3)} %copy.9, s32[100,2]{0,1:T(2,128)S(3)} %fusion.1, f32[100,2]{0,1:T(2,128)S(3)} %copy-done.1, s32[]{:T(128)} %constant.33, /*index=5*/s32[2]{0:T(128)S(3)} %copy-done.2)
  %while.1 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) while((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %tuple.17), condition=%wide.while_cond, body=%wide.while_body
  %get-tuple-element.59 = f32[100,32]{0,1:T(8,128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %while.1), index=1
  %copy-start = (f32[100,32]{0,1:T(8,128)}, f32[100,32]{0,1:T(8,128)S(3)}, u32[]{:S(2)}) copy-start(f32[100,32]{0,1:T(8,128)S(3)} %get-tuple-element.59)
  ROOT %copy-done = f32[100,32]{0,1:T(8,128)} copy-done((f32[100,32]{0,1:T(8,128)}, f32[100,32]{0,1:T(8,128)S(3)}, u32[]{:S(2)}) %copy-start)
}

System info (python version, jaxlib version, accelerator, etc.)

Python version: 3.12.1, jaxlib version: 0.4.28, running on a TPU v4-8 VM.

jakevdp commented 5 months ago

Hi, thanks for the report! I think this performance discrepancy comes from the fact that the batching rule for dynamic_update_slice is implemented in terms of scatter, and scatter is much less performant on TPU than dynamic_update_slice. Fixing that is likely difficult, because scatter is a much more general operation than dynamic_update_slice.

It might be worth raising this in the OpenXLA repository – it may be possible to make the TPU compiler recognize and optimize the batched-dynamic-update-slice case of scatter.

mattjj commented 5 months ago

I agree that XLA improving this kind of scatter performance seems like the right long term solution.

As a workaround that doesn't require XLA changes, would it be feasible to introduce your own DUS-like primitive here, with the batching rule that you prefer (unrolled loop plus stacking)?
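One way to prototype that without defining a full primitive might be jax.custom_batching.custom_vmap, which lets you attach your own batching rule to a function. A minimal sketch, assuming the (experimental) custom_vmap API with a rule of the form (axis_size, in_batched, *args) -> (out, out_batched), and only handling the case where all arguments are batched along axis 0:

import jax.numpy as jnp
from jax import custom_batching, lax

@custom_batching.custom_vmap
def my_dus(operand, update, start):
    return lax.dynamic_update_slice(operand, update, start)

@my_dus.def_vmap
def my_dus_vmap_rule(axis_size, in_batched, operands, updates, starts):
    # Unrolled loop plus stack, mirroring the pymap version above.
    assert all(in_batched)
    out = jnp.stack([
        lax.dynamic_update_slice(operands[i], updates[i], starts[i])
        for i in range(axis_size)
    ])
    return out, True

Then jit(vmap(my_dus))(operands, updates, starts) should trace through the unrolled rule, though I haven't checked what code XLA generates for it.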

We could also consider changing the DUS batching rule in JAX, but that would probably take more testing to get right in general.

j-towns commented 5 months ago

Thanks for having a look. It sounds to me like this should be fixed on the XLA level, so I'll raise an issue there and close this for now.

I'll try my own DUS primitive with custom batching as @mattjj suggested and see if that improves things.

j-towns commented 3 months ago

Sorry to bother you @mattjj and @jakevdp, just wondering whether there's any chance you can ping someone internally to ask them to have a look at https://github.com/openxla/xla/issues/12982, since it doesn't look as though anyone has yet. The Python loop workaround discussed above (as well as a bunch of other workarounds I've tried) is still way too slow in my use case, and my project is currently badly bottlenecked on this.

j-towns commented 2 months ago

Re-opened this since the XLA issue https://github.com/openxla/xla/issues/12982 doesn't appear to have received any attention, and this is severely affecting a project I'm working on.

My colleague @ivanipenburg also discovered that this doesn't appear to be an issue on GPU. The compiled HLO from a GPU Colab is below; it contains a scatter rather than a while loop of dynamic-update-slices.

HloModule jit_dynamic_update_slice, is_scheduled=true, entry_computation_layout={(f32[100,32]{1,0}, f32[100,2]{1,0}, s32[100,1]{1,0})->f32[100,32]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="b7e39acee486b5054fd3a46daf8da60a"}

%region_0.21 (Arg_0.22.0: f32[], Arg_1.23.0: f32[]) -> f32[] {
  ROOT %Arg_1.23.0 = f32[] parameter(1)
  %Arg_0.22.0 = f32[] parameter(0)
}

%fused_scatter (param_0: f32[100,32], param_1.12: f32[100,2], param_2.12: s32[100,1], param_3.8: s32[2]) -> f32[100,32] {
  %param_0 = f32[100,32]{1,0} parameter(0)
  %constant_4_1 = s32[] constant(0)
  %broadcast.4.1 = s32[100,2]{1,0} broadcast(s32[] %constant_4_1), dimensions={}
  %iota.2.1 = s32[100,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(dynamic_update_slice)/jit(main)/iota[dtype=int32 shape=(100, 1) dimension=0]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
  %param_2.12 = s32[100,1]{1,0} parameter(2)
  %bitcast.45.9 = s32[100]{0} bitcast(s32[100,1]{1,0} %param_2.12)
  %broadcast.6.1 = s32[100]{0} broadcast(s32[] %constant_4_1), dimensions={}
  %compare.1.3 = pred[100]{0} compare(s32[100]{0} %bitcast.45.9, s32[100]{0} %broadcast.6.1), direction=LT, metadata={op_name="jit(dynamic_update_slice)/jit(main)/lt" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
  %constant_7_1 = s32[] constant(32)
  %broadcast.7.1 = s32[100]{0} broadcast(s32[] %constant_7_1), dimensions={}
  %add.1.3 = s32[100]{0} add(s32[100]{0} %bitcast.45.9, s32[100]{0} %broadcast.7.1), metadata={op_name="jit(dynamic_update_slice)/jit(main)/add" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
  %select.1.3 = s32[100]{0} select(pred[100]{0} %compare.1.3, s32[100]{0} %add.1.3, s32[100]{0} %bitcast.45.9), metadata={op_name="jit(dynamic_update_slice)/jit(main)/select_n" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
  %bitcast.57.5 = s32[100,1]{1,0} bitcast(s32[100]{0} %select.1.3)
  %concatenate.1.5 = s32[100,2]{1,0} concatenate(s32[100,1]{1,0} %iota.2.1, s32[100,1]{1,0} %bitcast.57.5), dimensions={1}, metadata={op_name="jit(dynamic_update_slice)/jit(main)/concatenate[dimension=1]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
  %param_3.8 = s32[2]{0} parameter(3)
  %broadcast.9.1 = s32[100,2]{1,0} broadcast(s32[2]{0} %param_3.8), dimensions={1}, metadata={op_name="jit(dynamic_update_slice)/jit(main)/broadcast_in_dim[shape=(100, 2) broadcast_dimensions=(1,)]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
  %clamp.1.3 = s32[100,2]{1,0} clamp(s32[100,2]{1,0} %broadcast.4.1, s32[100,2]{1,0} %concatenate.1.5, s32[100,2]{1,0} %broadcast.9.1), metadata={op_name="jit(dynamic_update_slice)/jit(main)/clamp" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
  %param_1.12 = f32[100,2]{1,0} parameter(1)
  %bitcast.66.1 = f32[100,1,2]{2,0,1} bitcast(f32[100,2]{1,0} %param_1.12)
  ROOT %scatter.24.1 = f32[100,32]{1,0} scatter(f32[100,32]{1,0} %param_0, s32[100,2]{1,0} %clamp.1.3, f32[100,1,2]{2,0,1} %bitcast.66.1), update_window_dims={1,2}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=%region_0.21, metadata={op_name="jit(dynamic_update_slice)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
}

%wrapped_copy_computation (param_0.13: f32[100,32]) -> f32[100,32] {
  %param_0.13 = f32[100,32]{1,0} parameter(0)
  ROOT %copy.3 = f32[100,32]{1,0} copy(f32[100,32]{1,0} %param_0.13)
}

ENTRY %main.25 (Arg_0.1.0: f32[100,32], Arg_1.2.0: f32[100,2], Arg_2.3.0: s32[100,1]) -> f32[100,32] {
  %constant_6_0 = s32[2]{0} constant({99, 30})
  %Arg_2.3.0 = s32[100,1]{1,0} parameter(2)
  %Arg_1.2.0 = f32[100,2]{1,0} parameter(1)
  %Arg_0.1.0 = f32[100,32]{1,0} parameter(0)
  %wrapped_copy = f32[100,32]{1,0} fusion(f32[100,32]{1,0} %Arg_0.1.0), kind=kLoop, calls=%wrapped_copy_computation
  ROOT %input_scatter_fusion = f32[100,32]{1,0} fusion(f32[100,32]{1,0} %wrapped_copy, f32[100,2]{1,0} %Arg_1.2.0, s32[100,1]{1,0} %Arg_2.3.0, s32[2]{0} %constant_6_0), kind=kInput, calls=%fused_scatter, metadata={op_name="jit(dynamic_update_slice)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
}