BrunoKM opened this issue 1 year ago (status: Open)
At least at the moment I think this is expected: deterministic scatters are much slower on GPU because they eliminate any parallelism. XLA would need to emit different code for a faster deterministic scatter.
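For context, the flag under discussion is passed to XLA through an environment variable, which (as the original report notes) must be set before JAX initializes the backend. A minimal sketch:

```python
import os

# Must be set before jax is imported / the backend is initialized,
# otherwise XLA will not pick up the flag.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

import jax  # noqa: E402  (import deliberately after setting the flag)
```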
@hawkinsp thanks for the response; it absolutely makes sense that deterministic scatters should be slower.
Do you think it's expected that you should get an additional slow-down from `vmap`? I.e. that this:
scatter_add_batched = jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)
scatter_add_batched_jit = jax.jit(scatter_add_batched)
scatter_add_batched_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_batched_jit(operand, updates, indices).block_until_ready()
>>> 17.4 s
is 2.7x as slow as this:
scatter_add_jit = jax.jit(scatter_add)
scatter_add_jit(operand, updates, indices).block_until_ready()
%timeit [scatter_add_jit(operand[i], updates[i], indices[i]).block_until_ready() for i in range(len(operand))]
>>> 4.6 s
with `os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"` enabled? I couldn't think of a reason this should be the case; both are deterministic, and I'd think compiling to XLA should be at least as fast as doing the loop in Python.
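For reference, the `scatter_add` being benchmarked is not defined in the snippets above; a minimal stand-in (my assumption, not the original benchmark code) would be an index-add via `.at[].add()`:

```python
import jax
import jax.numpy as jnp

def scatter_add(operand, updates, indices):
    # operand[indices[i]] += updates[i] -- a scatter-add along the leading axis.
    return operand.at[indices].add(updates)

# Tiny illustrative shapes; the benchmark above uses much larger arrays.
operand = jnp.zeros(8)
updates = jnp.ones(5)
indices = jnp.array([0, 1, 1, 2, 7])

out = jax.jit(scatter_add)(operand, updates, indices)
print(out)  # [1. 2. 1. 0. 0. 0. 0. 1.]  -- index 1 receives two updates
```

The batched version in the timings then follows as `jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)` over a leading batch dimension.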
I encountered what I'm fairly confident is the same vmap-related slowdown on TPU. Profiling it, I discovered that while vmap of my function produces a `scatter` in the jaxpr, that op lowers to a `while` loop over the mapped axis in XLA, with the body of the loop containing a `dynamic-update-slice`, not a `scatter`. Presumably, since XLA `while` loops cannot in general be parallelized, the compiler is unable to see this potential optimization. I don't know whether this issue is common to the lowering of any 'ragged' scatter.
@BrunoKM I will try your Python-loop workaround and see whether that improves my use case for now.
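One way to check whether your own vmapped function hits this lowering (a sketch; the exact HLO is backend- and version-dependent) is to search the compiled HLO text for a `while` op, using the same setup as below:

```python
from functools import partial

from jax import jit, lax, vmap
import jax.numpy as jnp

f = partial(lax.dynamic_update_slice_in_dim, axis=0)
operand = jnp.ones((3, 4, 5))
updates = jnp.ones((3, 2, 5))
starts = jnp.ones((3,), dtype='int32')

hlo = jit(vmap(f)).lower(operand, updates, starts).compile().as_text()
# On TPU the batched scatter shows up as a `while` loop over the mapped
# axis; other backends may lower it differently, so just grep for it:
print("while" in hlo)
```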
Setup:
In [1]: from jax import jit, lax, vmap, make_jaxpr
In [2]: import jax.numpy as jnp
In [3]: operand = jnp.ones((3, 4, 5))
In [4]: updates = jnp.ones((3, 2, 5))
In [5]: starts = jnp.ones((3,), dtype='int32')
In [6]: from functools import partial
In [7]: f = partial(lax.dynamic_update_slice_in_dim, axis=0)
Printing the jaxpr, note there is a single scatter op:
In [8]: make_jaxpr(vmap(f))(operand, updates, starts)
Out[8]:
{ lambda ; a:f32[3,4,5] b:f32[3,2,5] c:i32[3]. let
d:bool[3] = lt c 0
e:i32[3] = add c 4
f:i32[3] = select_n d c e
g:i32[] = add 0 5
h:i32[] = select_n False 0 g
i:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] f
j:i32[3,1] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 1)] h
k:i32[3,2] = concatenate[dimension=1] i j
l:i32[3,1] = iota[dimension=0 dtype=int32 shape=(3, 1)]
m:i32[3,3] = concatenate[dimension=1] l k
n:f32[3,4,5] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2))
indices_are_sorted=True
mode=GatherScatterMode.CLIP
unique_indices=True
update_consts=()
update_jaxpr=None
] a m b
in (n,) }
Printing the HLO, because it is long and hard to read I've surrounded the relevant parts in ########:
In [9]: print(jit(vmap(f)).lower(operand, updates, starts).compile().as_text())
HloModule jit__unnamed_function_, entry_computation_layout={(f32[3,4,5]{2,1,0}, f32[3,2,5]{2,1,0}, s32[3]{0})->f32[3,4,5]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
#############################################################
%fused_computation (param_0: f32[3,4,5], param_1.3: f32[3,2,5], param_2.7: s32[], param_3.8: pred[], param_4.8: s32[3,3]) -> f32[3,4,5] {
%param_0 = f32[3,4,5]{2,1,0} parameter(0)
%param_3.8 = pred[] parameter(3)
%broadcast.18 = pred[1,2,5]{2,1,0} broadcast(pred[] %param_3.8), dimensions={}
%param_1.3 = f32[3,2,5]{2,1,0} parameter(1)
%param_2.7 = s32[] parameter(2)
%constant.23 = s32[] constant(0)
%dynamic-slice.7 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,2,5]{2,1,0} %param_1.3, s32[] %param_2.7, s32[] %constant.23, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
%param_4.8 = s32[3,3]{1,0} parameter(4)
%dynamic-slice.8 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_4.8, s32[] %param_2.7, s32[] %constant.23), dynamic_slice_sizes={1,3}
%slice.23 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.8), slice={[0:1], [0:1]}
%bitcast.6 = s32[] bitcast(s32[1,1]{1,0} %slice.23)
%bitcast.7 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.8)
%slice.22 = s32[1]{0} slice(s32[3]{0} %bitcast.7), slice={[1:2]}
%bitcast.5 = s32[] bitcast(s32[1]{0} %slice.22)
%dynamic-slice.6 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,4,5]{2,1,0} %param_0, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
%select.1 = f32[1,2,5]{2,1,0} select(pred[1,2,5]{2,1,0} %broadcast.18, f32[1,2,5]{2,1,0} %dynamic-slice.7, f32[1,2,5]{2,1,0} %dynamic-slice.6)
###########################################################
# Note the shape of the update array, it is 1 in the batch dimension
ROOT %dynamic-update-slice.2 = f32[3,4,5]{2,1,0} dynamic-update-slice(f32[3,4,5]{2,1,0} %param_0, f32[1,2,5]{2,1,0} %select.1, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23)
###########################################################
}
#############################################################
%and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] {
%lhs = pred[] parameter(0)
%rhs = pred[] parameter(1)
ROOT %and = pred[] and(pred[] %lhs, pred[] %rhs)
}
%fused_computation.1 (param_0.4: s32[3]) -> pred[] {
%constant.26 = s32[] constant(0)
%broadcast.19 = s32[3]{0} broadcast(s32[] %constant.26), dimensions={}
%param_0.4 = s32[3]{0} parameter(0)
%compare.4 = pred[3]{0} compare(s32[3]{0} %broadcast.19, s32[3]{0} %param_0.4), direction=LE
%constant.25 = s32[3]{0} constant({2, 2, 0})
%compare.3 = pred[3]{0} compare(s32[3]{0} %constant.25, s32[3]{0} %param_0.4), direction=GE
%and.2 = pred[3]{0} and(pred[3]{0} %compare.4, pred[3]{0} %compare.3)
%constant.24 = pred[] constant(true)
ROOT %reduce.1 = pred[] reduce(pred[3]{0} %and.2, pred[] %constant.24), dimensions={0}, to_apply=%and.reduce_sub_computation
}
%fused_computation.2 (param_0.7: s32[3,3], param_1.15: s32[]) -> s32[3] {
%param_0.7 = s32[3,3]{1,0} parameter(0)
%param_1.15 = s32[] parameter(1)
%constant.27 = s32[] constant(0)
%dynamic-slice.9 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_0.7, s32[] %param_1.15, s32[] %constant.27), dynamic_slice_sizes={1,3}
%slice.25 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.9), slice={[0:1], [0:1]}
%bitcast.9 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.25)
%bitcast.8 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.9)
%slice.24 = s32[2]{0} slice(s32[3]{0} %bitcast.8), slice={[1:3]}
ROOT %concatenate.2 = s32[3]{0} concatenate(s32[1]{0} %bitcast.9, s32[2]{0} %slice.24), dimensions={0}
}
#############################################################
%while_body (param.1: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> (s32[], f32[3,4,5], s32[3,3], f32[3,2,5]) {
%param.1 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
%get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=0
%copy.3 = s32[] copy(s32[] %get-tuple-element.12)
%constant.10 = s32[] constant(1)
%add = s32[] add(s32[] %copy.3, s32[] %constant.10)
%get-tuple-element.13 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=1
%get-tuple-element.19 = f32[3,2,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=3
%get-tuple-element.18 = s32[3,3]{1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=2
%fusion.2 = s32[3]{0} fusion(s32[3,3]{1,0} %get-tuple-element.18, s32[] %copy.3), kind=kLoop, calls=%fused_computation.2
%fusion.1 = pred[] fusion(s32[3]{0} %fusion.2), kind=kLoop, calls=%fused_computation.1
###########################################################
%fusion = f32[3,4,5]{2,1,0} fusion(f32[3,4,5]{2,1,0} %get-tuple-element.13, f32[3,2,5]{2,1,0} %get-tuple-element.19, s32[] %copy.3, pred[] %fusion.1, s32[3,3]{1,0} %get-tuple-element.18), kind=kLoop, calls=%fused_computation
###########################################################
ROOT %tuple.5 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %add, f32[3,4,5]{2,1,0} %fusion, s32[3,3]{1,0} %get-tuple-element.18, f32[3,2,5]{2,1,0} %get-tuple-element.19)
}
#############################################################
%while_cond (param.0: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> pred[] {
%param.0 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.0), index=0
%constant.1 = s32[] constant(3)
ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.1), direction=LT
}
%fused_computation.3 (param_0.10: s32[3]) -> s32[3,3] {
%constant.30 = s32[] constant(0)
%broadcast.25 = s32[3,3]{1,0} broadcast(s32[] %constant.30), dimensions={}
%iota.1 = s32[3,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(<unnamed function>)/jit(main)/iota[dtype=int32 shape=(3, 1) dimension=0]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%param_0.10 = s32[3]{0} parameter(0)
%broadcast.24 = s32[3]{0} broadcast(s32[] %constant.30), dimensions={}
%compare.5 = pred[3]{0} compare(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.24), direction=LT, metadata={op_name="jit(<unnamed function>)/jit(main)/lt" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%constant.29 = s32[] constant(4)
%broadcast.23 = s32[3]{0} broadcast(s32[] %constant.29), dimensions={}
%add.1 = s32[3]{0} add(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.23), metadata={op_name="jit(<unnamed function>)/jit(main)/add" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%select.2 = s32[3]{0} select(pred[3]{0} %compare.5, s32[3]{0} %add.1, s32[3]{0} %param_0.10), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%bitcast.10 = s32[3,1]{1,0} bitcast(s32[3]{0} %select.2), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%broadcast.22 = s32[3,1]{1,0} broadcast(s32[] %constant.30), dimensions={}
%concatenate.3 = s32[3,3]{1,0} concatenate(s32[3,1]{1,0} %iota.1, s32[3,1]{1,0} %bitcast.10, s32[3,1]{1,0} %broadcast.22), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/concatenate[dimension=1]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%constant.28 = s32[3]{0} constant({2, 2, 0})
%broadcast.21 = s32[3,3]{1,0} broadcast(s32[3]{0} %constant.28), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=(1,)]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
ROOT %clamp.0 = s32[3,3]{1,0} clamp(s32[3,3]{1,0} %broadcast.25, s32[3,3]{1,0} %concatenate.3, s32[3,3]{1,0} %broadcast.21), metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}
#############################################################
ENTRY %main.26 (Arg_0.1: f32[3,4,5], Arg_1.2: f32[3,2,5], Arg_2.3: s32[3]) -> f32[3,4,5] {
%constant.4 = s32[] constant(0)
%copy.8 = s32[] copy(s32[] %constant.4)
%Arg_0.1 = f32[3,4,5]{2,1,0} parameter(0), sharding={replicated}
%copy.7 = f32[3,4,5]{2,1,0} copy(f32[3,4,5]{2,1,0} %Arg_0.1)
%Arg_2.3 = s32[3]{0} parameter(2), sharding={replicated}
%fusion.3 = s32[3,3]{1,0} fusion(s32[3]{0} %Arg_2.3), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%Arg_1.2 = f32[3,2,5]{2,1,0} parameter(1), sharding={replicated}
%tuple.3 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %copy.8, f32[3,4,5]{2,1,0} %copy.7, s32[3,3]{1,0} %fusion.3, f32[3,2,5]{2,1,0} %Arg_1.2)
###########################################################
%while = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) while((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %tuple.3), condition=%while_cond, body=%while_body, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
###########################################################
ROOT %get-tuple-element.5 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %while), index=1, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}
#############################################################
Here's a more minimal, self-contained snippet which reproduces the vmap slowdown. On TPU the vmap version is around 2x slower than using a Python loop and `jnp.stack`ing the result.
from timeit import timeit
from jax import jit, lax, vmap
import jax.numpy as jnp
# For f which outputs a single array, this simulates vmap using Python map
pymap = lambda f: lambda *args: jnp.stack(list(map(f, *args)))
operands = jnp.ones((100, 32))
updates = jnp.ones((100, 2))
starts = jnp.ones((100, 1), dtype='int32')
f = lax.dynamic_update_slice
f_vmapped = jit(vmap(f))
f_pymapped = jit(pymap(f))
# Ensure compiled
f_vmapped(operands, updates, starts)
f_pymapped(operands, updates, starts)
t_vmapped = timeit(
lambda: f_vmapped(operands, updates, starts).block_until_ready(), number=100
) / 100
t_pymapped = timeit(
lambda: f_pymapped(operands, updates, starts).block_until_ready(), number=100
) / 100
print(f"Time vmap(f): {t_vmapped:.2}s")
print(f"Time pymap(f): {t_pymapped:.2}s")
On a TPU v4-8 VM I get:
Time vmap(f): 0.00088s
Time pymap(f): 0.00036s
Running the script on CPU on my laptop, the Python loop version is slower than the vmap version:
Time vmap(f): 1.3e-05s
Time pymap(f): 3.3e-05s
I realize this is an older issue, but one option is to roll your own deterministic scatter_add (using prefix sums):
import jax
import jax.numpy as jp

def add_segment(iv, jt):
    i, v = iv
    j, t = jt
    return j, v * jp.equal(i, j) + t

@jax.jit
def scatter_add_det(operand, updates, indices):
    indices = jp.reshape(indices, updates.shape)
    # Sort the indices, and the values by the indices.
    indices, sorted_updates = jax.lax.sort_key_val(indices, updates, dimension=-1)
    # Sum up runs of the same index - the sum for each index will be at the end of each run.
    _, sums = jax.lax.associative_scan(add_segment, (indices, sorted_updates))
    # Produce an array of bools - an element is set iff its position
    # is the end of a run of the same index.
    end_of_run = jp.concatenate([jp.not_equal(indices[1:], indices[:-1]), jp.array([True])])
    # Set all positions that are not at the end of a run to an out-of-bounds index.
    indices = jp.where(end_of_run, indices, operand.shape[-1])
    # Now do a scatter-add where we know the (in-bounds) indices are unique.
    # That is still fast on GPUs (no non-determinism from atomics).
    return operand.at[indices].add(sums, mode='drop', unique_indices=True)
This is 5-15x slower than the non-deterministic one (depending on shape of things), but at least it's not multiple orders of magnitude. It would be nice if XLA could lower to something like this automatically.
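As a quick sanity check (mine, not from the comment above; the definition is repeated so the snippet runs standalone), the prefix-sum version agrees with a plain `.at[].add()` on a small example with duplicate indices:

```python
import jax
import jax.numpy as jp

def add_segment(iv, jt):
    # Segmented-sum combiner: accumulate v into t only within a run of equal keys.
    i, v = iv
    j, t = jt
    return j, v * jp.equal(i, j) + t

@jax.jit
def scatter_add_det(operand, updates, indices):
    indices = jp.reshape(indices, updates.shape)
    indices, sorted_updates = jax.lax.sort_key_val(indices, updates, dimension=-1)
    _, sums = jax.lax.associative_scan(add_segment, (indices, sorted_updates))
    end_of_run = jp.concatenate([jp.not_equal(indices[1:], indices[:-1]), jp.array([True])])
    indices = jp.where(end_of_run, indices, operand.shape[-1])
    return operand.at[indices].add(sums, mode='drop', unique_indices=True)

operand = jp.zeros(6)
updates = jp.array([1.0, 2.0, 3.0, 4.0])
indices = jp.array([2, 0, 2, 5])  # note index 2 appears twice

deterministic = scatter_add_det(operand, updates, indices)
reference = operand.at[indices].add(updates)
print(deterministic)  # [2. 0. 4. 0. 0. 4.]
print(reference)      # [2. 0. 4. 0. 0. 4.]
```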
Description
The issue is 1) about a rather significant slow-down to the `scatter_add` operation when running JAX with the `xla_gpu_deterministic_ops=true` flag, and 2) about a further, disproportionately large slow-down when using `vmap` around a `scatter_add` operation.
Below is the code to reproduce the issue. The timings are run with and without prepending `os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"` at the start of the script. Firstly, just a regular `scatter_add` benchmark; secondly, a `scatter_add` benchmark with `vmap`.
It seems pretty unexpected that, when the `xla_gpu_deterministic_ops` flag is set to true, calling `scatter_add` with `vmap` with a batch size of 100 makes the runtime 377x longer, i.e. 3.7 times slower than just using a manual Python `for`-loop.
Unrelatedly, although the slow-down of `scatter_add` is to be expected when enforcing determinism, it is rather severe (almost 2000x slower without `vmap`, and over 200000x slower with `vmap`). I guess this operation doesn't come up very regularly, but it appears, for example, in the backward pass through a bilinear interpolation of an image (e.g. when using `jax.scipy.ndimage.map_coordinates`). Even if the `vmap` issue gets resolved, it would be absolutely fantastic if, in addition, some kind of warning about the potential runtime impact were shown when executing code with `--xla_gpu_deterministic_ops=true`.
What jax/jaxlib version are you using?
jax v0.4.16, jax v0.4.16+cuda12.cudnn89
Which accelerator(s) are you using?
GPU
Additional system info
Linux, Ubuntu 22.04.3 LTS, Python 3.11.3
NVIDIA GPU info
Reproduced on a 3090 as well.