j-towns opened this issue 5 months ago
Hi, thanks for the report! I think this performance discrepancy comes from the fact that the batching rule for dynamic_slice is implemented in terms of gather, and gather is much less performant on TPU than dynamic_slice. Fixing that is likely difficult because gather is a much more general operation than dynamic_slice.
It might be worth raising this in the OpenXLA repository – it may be possible to make the TPU compiler recognize and optimize the batched-dynamic-slice case of gather.
I agree that XLA improving this kind of scatter performance seems like the right long term solution.
As a workaround that doesn't require XLA changes, would it be feasible to introduce your own DUS-like primitive here, with the batching rule that you prefer (unrolled loop plus stacking)?
We could also consider changing the DUS batching rule in JAX, but that would probably take more testing to get right in general.
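For concreteness, here is a minimal sketch of what such a primitive could look like. It assumes the standard custom-primitive machinery (a Primitive object, the batching.primitive_batchers registry, and mlir.lower_fun for the unbatched lowering), whose exact module paths vary across JAX versions; the name my_dus and the simplifying assumption that every argument is mapped along axis 0 are purely illustrative:

```python
import jax
import jax.numpy as jnp
from jax import core, lax
from jax.interpreters import batching, mlir

# A DUS-like primitive whose batching rule unrolls a Python loop and stacks
# the results, instead of becoming a scatter.
my_dus_p = core.Primitive("my_dus")

def my_dus(operand, update, *start_indices):
    return my_dus_p.bind(operand, update, *start_indices)

def _my_dus_impl(operand, update, *start_indices):
    return lax.dynamic_update_slice(operand, update, start_indices)

def _my_dus_abstract_eval(operand, update, *start_indices):
    # The output has the same shape and dtype as the operand.
    return core.ShapedArray(operand.shape, operand.dtype)

my_dus_p.def_impl(_my_dus_impl)
my_dus_p.def_abstract_eval(_my_dus_abstract_eval)

# Lower the unbatched primitive via its implementation so plain jit still works.
mlir.register_lowering(my_dus_p, mlir.lower_fun(_my_dus_impl, multiple_results=False))

def _my_dus_batch_rule(batched_args, batch_dims):
    # Simplifying assumption: every argument is mapped along axis 0.
    assert all(d == 0 for d in batch_dims)
    operand, update, *start_indices = batched_args
    outs = [lax.dynamic_update_slice(operand[i], update[i],
                                     [s[i] for s in start_indices])
            for i in range(operand.shape[0])]
    return jnp.stack(outs), 0

batching.primitive_batchers[my_dus_p] = _my_dus_batch_rule
```

With this in place, jax.jit(jax.vmap(my_dus)) should emit the unrolled sequence of dynamic-update-slices rather than a scatter, i.e. the Python-loop-plus-stack workaround packaged behind vmap; whether that actually helps TPU performance would still need to be measured.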
Thanks for having a look. It sounds to me like this should be fixed at the XLA level, so I'll raise an issue there and close this for now.
I'll try my own DUS primitive with custom batching as @mattjj suggested and see if that improves things.
Sorry to bother you @mattjj and @jakevdp, just wondering whether there's any chance you can ping someone internally to ask them to have a look at https://github.com/openxla/xla/issues/12982, since it doesn't look as though anyone has yet. The Python loop workaround discussed above (as well as a bunch of other workarounds I've tried) is still way too slow in my use-case, and my project is bottlenecking badly on this atm.
Re-opened this since the XLA issue https://github.com/openxla/xla/issues/12982 doesn't appear to have received any attention, and this is severely affecting a project I'm working on.
My colleague @ivanipenburg also discovered that this doesn't appear to be an issue on GPU. The compiled HLO on a GPU Colab is below, and it has a scatter as opposed to a while loop containing dynamic-update-slices.
HloModule jit_dynamic_update_slice, is_scheduled=true, entry_computation_layout={(f32[100,32]{1,0}, f32[100,2]{1,0}, s32[100,1]{1,0})->f32[100,32]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="b7e39acee486b5054fd3a46daf8da60a"}
%region_0.21 (Arg_0.22.0: f32[], Arg_1.23.0: f32[]) -> f32[] {
ROOT %Arg_1.23.0 = f32[] parameter(1)
%Arg_0.22.0 = f32[] parameter(0)
}
%fused_scatter (param_0: f32[100,32], param_1.12: f32[100,2], param_2.12: s32[100,1], param_3.8: s32[2]) -> f32[100,32] {
%param_0 = f32[100,32]{1,0} parameter(0)
%constant_4_1 = s32[] constant(0)
%broadcast.4.1 = s32[100,2]{1,0} broadcast(s32[] %constant_4_1), dimensions={}
%iota.2.1 = s32[100,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(dynamic_update_slice)/jit(main)/iota[dtype=int32 shape=(100, 1) dimension=0]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
%param_2.12 = s32[100,1]{1,0} parameter(2)
%bitcast.45.9 = s32[100]{0} bitcast(s32[100,1]{1,0} %param_2.12)
%broadcast.6.1 = s32[100]{0} broadcast(s32[] %constant_4_1), dimensions={}
%compare.1.3 = pred[100]{0} compare(s32[100]{0} %bitcast.45.9, s32[100]{0} %broadcast.6.1), direction=LT, metadata={op_name="jit(dynamic_update_slice)/jit(main)/lt" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
%constant_7_1 = s32[] constant(32)
%broadcast.7.1 = s32[100]{0} broadcast(s32[] %constant_7_1), dimensions={}
%add.1.3 = s32[100]{0} add(s32[100]{0} %bitcast.45.9, s32[100]{0} %broadcast.7.1), metadata={op_name="jit(dynamic_update_slice)/jit(main)/add" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
%select.1.3 = s32[100]{0} select(pred[100]{0} %compare.1.3, s32[100]{0} %add.1.3, s32[100]{0} %bitcast.45.9), metadata={op_name="jit(dynamic_update_slice)/jit(main)/select_n" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
%bitcast.57.5 = s32[100,1]{1,0} bitcast(s32[100]{0} %select.1.3)
%concatenate.1.5 = s32[100,2]{1,0} concatenate(s32[100,1]{1,0} %iota.2.1, s32[100,1]{1,0} %bitcast.57.5), dimensions={1}, metadata={op_name="jit(dynamic_update_slice)/jit(main)/concatenate[dimension=1]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
%param_3.8 = s32[2]{0} parameter(3)
%broadcast.9.1 = s32[100,2]{1,0} broadcast(s32[2]{0} %param_3.8), dimensions={1}, metadata={op_name="jit(dynamic_update_slice)/jit(main)/broadcast_in_dim[shape=(100, 2) broadcast_dimensions=(1,)]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
%clamp.1.3 = s32[100,2]{1,0} clamp(s32[100,2]{1,0} %broadcast.4.1, s32[100,2]{1,0} %concatenate.1.5, s32[100,2]{1,0} %broadcast.9.1), metadata={op_name="jit(dynamic_update_slice)/jit(main)/clamp" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
%param_1.12 = f32[100,2]{1,0} parameter(1)
%bitcast.66.1 = f32[100,1,2]{2,0,1} bitcast(f32[100,2]{1,0} %param_1.12)
ROOT %scatter.24.1 = f32[100,32]{1,0} scatter(f32[100,32]{1,0} %param_0, s32[100,2]{1,0} %clamp.1.3, f32[100,1,2]{2,0,1} %bitcast.66.1), update_window_dims={1,2}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=%region_0.21, metadata={op_name="jit(dynamic_update_slice)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
}
%wrapped_copy_computation (param_0.13: f32[100,32]) -> f32[100,32] {
%param_0.13 = f32[100,32]{1,0} parameter(0)
ROOT %copy.3 = f32[100,32]{1,0} copy(f32[100,32]{1,0} %param_0.13)
}
ENTRY %main.25 (Arg_0.1.0: f32[100,32], Arg_1.2.0: f32[100,2], Arg_2.3.0: s32[100,1]) -> f32[100,32] {
%constant_6_0 = s32[2]{0} constant({99, 30})
%Arg_2.3.0 = s32[100,1]{1,0} parameter(2)
%Arg_1.2.0 = f32[100,2]{1,0} parameter(1)
%Arg_0.1.0 = f32[100,32]{1,0} parameter(0)
%wrapped_copy = f32[100,32]{1,0} fusion(f32[100,32]{1,0} %Arg_0.1.0), kind=kLoop, calls=%wrapped_copy_computation
ROOT %input_scatter_fusion = f32[100,32]{1,0} fusion(f32[100,32]{1,0} %wrapped_copy, f32[100,2]{1,0} %Arg_1.2.0, s32[100,1]{1,0} %Arg_2.3.0, s32[2]{0} %constant_6_0), kind=kInput, calls=%fused_scatter, metadata={op_name="jit(dynamic_update_slice)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-2-1a7d87e36410>" source_line=20}
}
Description
I've found that vmapping lax.dynamic_update_slice leads to surprising slow-downs. This issue affects the TPU backend. The following code benchmarks vmap(dynamic_update_slice) against an equivalent Python loop (plus jnp.stack):
On a TPU v4-8 VM I get:
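The benchmark code and timings are not reproduced here; a minimal sketch of the kind of comparison described, using the array shapes visible in the HLO above and hypothetical names (f_loop is assumed; f_vmapped matches the name used later in the report), might look like:

```python
import timeit
import jax
import jax.numpy as jnp
from jax import lax

N, D, K = 100, 32, 2  # batch size, row length, update length (from the HLO above)
key = jax.random.PRNGKey(0)
operands = jax.random.normal(key, (N, D))
updates = jax.random.normal(key, (N, K))
starts = jax.random.randint(key, (N, 1), 0, D - K)

def f(operand, update, start):
    # Per-example: update a length-K window of a length-D row at position start.
    return lax.dynamic_update_slice(operand, update, (start[0],))

f_vmapped = jax.jit(jax.vmap(f))

@jax.jit
def f_loop(operands, updates, starts):
    # Unrolled Python loop over the batch, then stack the results.
    return jnp.stack([f(operands[i], updates[i], starts[i]) for i in range(N)])

# Warm up (compile), then time both versions.
f_vmapped(operands, updates, starts).block_until_ready()
f_loop(operands, updates, starts).block_until_ready()
print(timeit.timeit(lambda: f_vmapped(operands, updates, starts).block_until_ready(), number=100))
print(timeit.timeit(lambda: f_loop(operands, updates, starts).block_until_ready(), number=100))
```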
I spent some time digging into this because it is significantly affecting the runtime of the project I'm working on...
I profiled f_vmapped by running the following, also on the v4-8 VM:
The resulting log directory is here: dynamic-slice-profile.zip. Inspecting the trace, it seems almost all of the runtime is spent in a while loop:
The while loop appears to have 100 iterations (matching the size of the mapped axis), and, zooming in, each iteration looks like this:
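The profiling code itself is not reproduced here; a minimal sketch of this kind of profiling run, assuming the standard jax.profiler tracing context manager and an illustrative log directory, would be:

```python
import jax

# Trace the compiled, vmapped function and write a TensorBoard-readable profile.
with jax.profiler.trace("/tmp/dynamic-slice-profile"):
    f_vmapped(operands, updates, starts).block_until_ready()
```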
I was surprised to see this, as I had assumed that vmap(dynamic_update_slice) would lower to a scatter, not a while loop with dynamic_update_slices in each iteration.
To find out where the while loop came from I ran make_jaxpr(f_vmapped)(operands, updates, starts), which results in:
As expected, the dynamic_update_slice is transformed into a scatter (with some other ops).
Looking at the compilation, running print(f_vmapped.lower(operands, updates, starts).as_text()) gives:
Again there's a scatter op and no loop, as expected.
However, the compiled HLO does have a while loop. It's much longer, so I've put it below. It looks like the while loop iterates over the mapped axis and performs a dynamic_update_slice at each iteration. So some scatter ops are compiling to a while loop with dynamic_update_slice ops in the body. Presumably the compiler can see that the unrolled Python-loop version can be parallelized, whereas the while loop in the vmap version prevents any parallelization? This could be why vmap is slower than the Python loop.
I have already discussed this in #17844, which concerns a slow-down on GPU that is likely related to this one.
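For reference, the three inspection steps described above (jaxpr, lowered IR, and compiled HLO) can all be obtained from the jitted function; a minimal sketch, reusing the f_vmapped and argument names from above:

```python
import jax

# 1. The jaxpr: shows a scatter, not a loop.
print(jax.make_jaxpr(f_vmapped)(operands, updates, starts))

# 2. The lowered (pre-compilation) IR: again a scatter, no loop.
lowered = f_vmapped.lower(operands, updates, starts)
print(lowered.as_text())

# 3. The compiled HLO: on TPU this is where the while loop of
#    dynamic-update-slices shows up.
print(lowered.compile().as_text())
```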
System info (python version, jaxlib version, accelerator, etc.)
Python version: 3.12.1, jaxlib version: 0.4.28, running on a TPU v4-8 VM.