Findus23 opened this issue 9 months ago
Hi - sorry, it's hard to say much here without a more complete reproduction. Given that the error comes from `dynamic_update_slice`, I'd like to see where that operation is coming from (the line you pasted, `carry[tuple(split)][..., 0]`, will lower to `gather`, not to `dynamic_update_slice`).
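(For reference, a minimal sketch of how one can check what such a line lowers to; the shapes and the `split` value here are made up, since the original code isn't shown:)

```python
import jax
import jax.numpy as jnp

def indexing(carry, split):
    # read-only fancy indexing like the pasted line
    return carry[tuple(split)][..., 0]

carry = jnp.zeros((8, 8, 3))
split = (jnp.array([1, 2]), jnp.array([3, 4]))  # hypothetical index arrays
# the lowered IR contains a gather and no dynamic_update_slice
print(jax.jit(indexing).lower(carry, split).as_text())
```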
Hi @Findus23 and @jakevdp.
I ran into this same bug and was able to put together the following MRE. It's for CPU because I don't have easy access to multiple GPUs, but maybe @Findus23 can check that the code errors out on GPUs too?
This is with `jax` and `jaxlib` version 0.4.26. I've confirmed the error on Linux, Windows, and WSL. Also of note is that I only get the error when setting `jax_enable_x64` to `True`. Not sure how to square that with the above error, where it wasn't enabled.
```python
# use two devices
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

# use 64-bit on CPU
import jax
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')

# configure a simple sharding
from jax.sharding import Mesh, PartitionSpec, NamedSharding
mesh = Mesh(jax.devices(), ('a',))
sharding = NamedSharding(mesh, PartitionSpec('a'))

def f(y):
    return y - jax.lax.map(g, y)

def g(y):
    return y

x = jax.numpy.ones(2)
print(f(x))                                     # [0. 0.]
print(jax.jit(f)(x))                            # [0. 0.]
print(f(jax.device_put(x, sharding)))           # [0. 0.]
print(jax.jit(f)(jax.device_put(x, sharding)))  # error
```
The error seems to be the same as the one above.
```
Traceback (most recent call last):
  File "sharding_bug_mre.py", line 25, in <module>
    print(jax.jit(f)(jax.device_put(x, sharding)))  # error
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: during context [hlo verifier]: Binary op compare with different element types: s64[] and s32[].
, for instruction %compare.3 = pred[] compare(s64[] %select.1, s32[] %multiply), direction=GE, metadata={op_name="jit(f)/jit(main)/while/body/dynamic_update_slice" source_file="sharding_bug_mre.py" source_line=16}
Failed after spmd-partitioning
```
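If it helps with triage: the failure appears to happen at XLA compile time rather than during tracing, so it can (I believe) be isolated with the explicit lower/compile API — a sketch reusing `f`, `x`, and `sharding` from the MRE above:

```python
# Tracing and lowering succeed; only the XLA compile step (which runs
# spmd-partitioning) should raise the hlo-verifier error.
lowered = jax.jit(f).lower(jax.device_put(x, sharding))
print(lowered.as_text())  # inspect the StableHLO handed to XLA
lowered.compile()         # expected to raise XlaRuntimeError: INVALID_ARGUMENT
```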
Hi,
Thanks to @jeffgortmaker for coming up with the reproducible issue (I tested it with 2 NVIDIA A40 GPUs and can reproduce it there as well). Originally I was unsure whether you hadn't just found another bug that by chance causes the same error; after all, the example doesn't even contain integer data, while the part where things break in my code does some explicit integer dtype changes. But (take everything from here on with a grain of salt, as I don't have much time for testing right now) maybe I misunderstood the original issue and made some mistakes when debugging and explicitly setting dtypes everywhere.
It turns out that when I run my code with single precision everywhere and no `jax_enable_x64`, it succeeds without this issue. With `jax_enable_x64` one would expect that explicitly making all integers `np.int64` would then also work, but it doesn't. So my theory is that my bug is mostly unrelated to my actual code and is indeed exactly the thing @jeffgortmaker discovered: that combining `jax_enable_x64`, `jit`, some map, and sharding breaks inside the implicit resharding.
This would also explain the great comment by @jakevdp: my code doesn't do a `dynamic_update_slice`, but the sharding code created inside `jit` does. And maybe also why the error message points to the `lax.scan`/`lax.map` line instead of the actual code.
So I guess the workaround, until the jit-compiler bug from the example is found, is just using single precision everywhere, or explicitly setting the dtype of every single np array created in the code to `np.int64`/`np.float64` instead of using `jax_enable_x64` (https://github.com/google/jax/issues/8178 style).
Thanks @Findus23 for confirming on GPUs and checking on `jax_enable_x64`!
For what it's worth, in the above example adding lots of `.astype(jnp.float64)`s didn't seem to eliminate the error for me. So for code needing 64-bit precision, I'm not sure there's an unblocking workaround yet, other than just not doing sharding in this way.
@jeffgortmaker I mistakenly thought one could still use explicit 64-bit types without `jax_enable_x64`, and that the flag would just change the default. But as https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision explains, that doesn't work, so you are right: there doesn't seem to be any workaround for this issue (apart from not using sharding, not using 64-bit anywhere, or potentially not using scan/map this way).
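(For anyone else landing here, a quick illustration of that gotcha: with `jax_enable_x64` left at its default, explicitly requested 64-bit dtypes are silently truncated to 32-bit:)

```python
import jax.numpy as jnp  # jax_enable_x64 not set, i.e. default False

x = jnp.array([1.0], dtype=jnp.float64)
print(x.dtype)  # float32 -- JAX warns that float64 was truncated to float32
```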
Description
I have been struggling with this bug for a very long time, and while I have failed to reproduce it in a trivial example, I still want to report it and provide as much detail as possible here. Maybe someone has an idea just based on the context and the error message.
I am in the process of changing a large simulation codebase to make it work with multi-GPU and sharded data arrays. It contains a function that does something similar to distributing particle positions `(2097152, 1, 3)` to a mesh `(256, 256, 256)`, into their own cell and neighbouring ones. This is a small part of a large simulation code that is wrapped in a `jit`. Without sharding this gives correct results, and removing the `jit` also makes the function work with sharding (but of course incredibly slowly). Calling just this part of the code jitted on sharded data also works. But the combination of sharded data and having the whole codebase jitted makes the compilation fail with the following error:
I am struggling with how to debug this, as in theory no 64-bit object should show up in the whole simulation (there is no `jax_enable_x64`). Also, I don't see any mention of `Binary op compare with different element types` in other jax/xla bug reports, so it doesn't seem to be a common issue.

`painting.py:228` contains a scan over another function, so I assume the compilation of the inner function fails. Based on `dynamic_update_slice` and the fact that this only happens when one specific part of the code is included, I am pretty sure the error boils down to this line of code inside the `loop_body`: `carry[tuple(split)][..., 0]`. Still, none of the variables here are of a 64-bit datatype, so I am really unsure what causes this (apart from possibly the usage of `tuple`, which I don't know how to avoid here).
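To make the structure concrete, here is a heavily simplified, hypothetical sketch of what such a `scan` over a `loop_body` could look like; everything except the quoted indexing pattern is invented:

```python
import jax
import jax.numpy as jnp

def loop_body(carry, positions):
    # (3, n) integer cell indices for this batch of particles
    split = jnp.floor(positions).astype(jnp.int32).T
    previous = carry[tuple(split)][..., 0]   # the quoted line: lowers to gather
    carry = carry.at[tuple(split)].add(1.0)  # deposit into the mesh: a scatter-add
    return carry, previous

mesh = jnp.zeros((64, 64, 64, 1))      # trailing axis so [..., 0] applies
positions = jnp.full((8, 16, 3), 3.5)  # 8 scan steps of 16 particles each
final_mesh, _ = jax.lax.scan(loop_body, mesh, positions)
```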
Also, as a slightly related question: is there any way to read the compiler output when an `hlo verifier` error occurs? It would be interesting to see the lines before and after the part that fails. Even the output of `XLA_FLAGS=--xla_dump_to=` doesn't contain the line that is mentioned in the error message.

Sorry that this is a bit vague, but I feel like this might be an XLA compiler issue that only occurs under these specific circumstances, so I don't know how to break it down to reproducible code I can share. But maybe someone has an idea how to fix this.
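(One hedged suggestion for the dumping question, assuming the standard XLA debug flags: `--xla_dump_hlo_pass_re` additionally dumps the HLO around individual optimization passes, which should include the post-`spmd-partitioning` module that fails verification:)

```python
import os

# Must be set before JAX initializes its backends.
os.environ['XLA_FLAGS'] = (
    '--xla_dump_to=/tmp/xla_dump '
    '--xla_dump_hlo_pass_re=.*'  # also dump HLO before/after each pass
)
import jax
```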
What jax/jaxlib version are you using?
0.4.23 (cuda11_pip). Also tested with `jax==0.4.24` and `jaxlib==0.4.24.dev20240206+cuda11.cudnn86`, with identical output.
Which accelerator(s) are you using?
GPU (8x NVIDIA A40 on 4 hosts)
Additional system info?
Python 3.11.3, Linux