jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Binary op compare with different element types in sharded, jitted function call #19691

Open Findus23 opened 9 months ago

Findus23 commented 9 months ago

Description

I have been struggling with this bug for a very long time. Although I have not managed to reproduce it in a trivial example, I still want to report it and provide as much detail as possible here. Maybe someone has an idea based on the context and the error message alone.

I am in the process of changing a large simulation codebase to work with multiple GPUs and sharded data arrays. It contains a function that does something similar to distributing particle positions (shape (2097152, 1, 3)) onto a mesh (shape (256, 256, 256)), into each particle's own cell and the neighbouring ones. This is a small part of a large simulation code that is wrapped in a jit. Without sharding, this gives correct results. Removing the jit also makes the function work with sharding (but, of course, incredibly slowly). Calling just this part of the code, jitted, on sharded data also works.

But the combination of sharded data and jitting the whole codebase makes compilation fail with the following error:

jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: during context [hlo verifier]: Binary op compare with different element types: s64[] and s32[].  
       , for instruction %compare.162 = pred[] compare(s64[] %select.102, s32[] %multiply.67), direction=GE, metadata={op_name="jit(step_fct_with_args)/jit(main)/while/body/dynamic_update_slice" source_file="/home/fs72085/lwinkler/DISCO-DJ/src/discodj/core/painting.py" source_line=228}

I am struggling to debug this because, in theory, no 64-bit object should appear anywhere in the simulation (jax_enable_x64 is not set). Also, I don't see any mention of "Binary op compare with different element types" in other JAX/XLA bug reports, so it doesn't seem to be a common issue.
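One way to check the "no 64-bit values" assumption on the JAX side is to trace the function with `jax.make_jaxpr` and walk the resulting jaxpr, including the bodies of scans, loops and nested jits. The helper below, `find_64bit_values`, is hypothetical (not a JAX API) and relies on the internal layout of jaxpr objects (`eqns`, `outvars`, `params`), which may change between versions. It only sees the traced program, so index arithmetic that XLA adds later (for example during SPMD partitioning) will not show up here.

```python
import numpy as np
import jax
import jax.numpy as jnp


def find_64bit_values(closed_jaxpr):
    """Hypothetical helper: list (primitive, dtype) for every 64-bit int/float value in a jaxpr."""
    hits = []

    def visit(jaxpr):
        for eqn in jaxpr.eqns:
            for var in eqn.outvars:
                dtype = getattr(var.aval, "dtype", None)
                # Only plain NumPy dtypes; skips tokens and extended dtypes such as PRNG keys.
                if isinstance(dtype, np.dtype) and dtype.kind in "iuf" and dtype.itemsize == 8:
                    hits.append((eqn.primitive.name, dtype))
            # Recurse into sub-jaxprs held in params (scan/while bodies, cond branches, nested jit).
            for param in eqn.params.values():
                items = param if isinstance(param, (tuple, list)) else (param,)
                for item in items:
                    inner = getattr(item, "jaxpr", item)  # unwrap a ClosedJaxpr
                    if hasattr(inner, "eqns"):
                        visit(inner)

    visit(closed_jaxpr.jaxpr)
    return hits


def step(x):
    def body(carry, xi):
        return carry + xi, carry * xi
    return jax.lax.scan(body, 0.0, x)


# Expected to print [] while jax_enable_x64 is off.
print(find_64bit_values(jax.make_jaxpr(step)(jnp.ones(4, dtype=jnp.float32))))
```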

painting.py:228 contains a scan over another function, so I assume the compilation of the inner function fails.

return np.reshape(jax.lax.scan(loop_body,
    mesh,  # initial carry
    (np.reshape(positions, (n_chunks, -1, 1, dim)),  # xs[0]: positions
     np.split(weight, n_chunks, axis=0)
     if weights_provided else np.ones(n_chunks, dtype=dtype),  # xs[1]: weights
     np.arange(n_chunks)),  # xs[2]: chunk indices
    )[1], -1)  # return ys, reshaped from n_chunk x chunk_size to npart_tot

Based on the dynamic_update_slice in the error and the fact that this only happens when one specific part of the code is included, I am fairly sure the error boils down to this line of code inside loop_body:

return carry, (carry[tuple(split)][..., 0] * kernel).sum(axis=-1)

Still, none of the variables here have a 64-bit dtype, so I am really unsure what causes this (apart, possibly, from the use of tuple, which I don't know how to avoid here):

dtype = np.int32
split = [np.zeros((65536, 8, 1), dtype=dtype)] * 3
╭────── split[0] ──────╮  
│ shape: (65536, 8, 1) │  
│ dtype: int32         │  
│ size: 2.0 MiB        │  
│ called in jit        │  
│ NamedSharding: P()   │  
╰──────────────────────╯  
╭────── kernel ──────╮  
│ shape: (65536, 8)  │  
│ dtype: float32     │  
│ size: 2.0 MiB      │  
│ called in jit      │  
│ NamedSharding: P() │  
╰────────────────────╯  
╭───────────────────── carry ─────────────────────╮  
│ shape: (256, 256, 256)                          │  
│ dtype: float32                                  │  
│ size: 64.0 MiB                                  │  
│ called in jit                                   │  
│ NamedSharding: P(None, 'gpus')                  │  
│ axis 1 is sharded: GPU 0 contains 0:32 (of 256) │  
╰─────────────────────────────────────────────────╯
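For experimenting with the pattern described above, a self-contained, hypothetical reconstruction of the chunked gather may be useful. Everything in it (`interpolate_chunked`, the `(npart, 8, 3)` index layout, the tiny 4x4x4 mesh) is an assumption based on the shapes shown, not the actual DISCO-DJ code, and on a single device it is not expected to trigger the error.

```python
import jax
import jax.numpy as jnp


def interpolate_chunked(mesh, cell_idx, kernel, n_chunks):
    """Hypothetical sketch: gather mesh values at each particle's 8 neighbouring cells.

    mesh:     (n, n, n) float32 grid, passed through the scan as an unchanged carry
    cell_idx: (npart, 8, 3) int32 cell indices (hypothetical layout)
    kernel:   (npart, 8) float32 interpolation weights
    n_chunks: number of chunks; must divide npart and be static under jit
    """
    n_neigh, dim = cell_idx.shape[1:]
    idx_chunks = cell_idx.reshape(n_chunks, -1, n_neigh, dim)
    kernel_chunks = kernel.reshape(n_chunks, -1, n_neigh)

    def loop_body(carry, xs):
        idx, k = xs
        # One (chunk, 8, 1) index array per mesh axis, like `split` above.
        split = [idx[..., d:d + 1] for d in range(dim)]
        values = carry[tuple(split)][..., 0]  # gather -> (chunk, 8)
        return carry, (values * k).sum(axis=-1)

    _, ys = jax.lax.scan(loop_body, mesh, (idx_chunks, kernel_chunks))
    return ys.reshape(-1)  # (n_chunks, chunk) -> (npart,)


mesh = jnp.arange(64, dtype=jnp.float32).reshape(4, 4, 4)
cell_idx = jnp.broadcast_to(jnp.array([1, 2, 3], dtype=jnp.int32), (8, 8, 3))
kernel = jnp.full((8, 8), 1 / 8, dtype=jnp.float32)
out = jax.jit(interpolate_chunked, static_argnums=3)(mesh, cell_idx, kernel, 2)
print(out)  # every particle reads mesh[1, 2, 3] = 27 with weights summing to 1
```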

Also, a slightly related question: is there any way to read the compiler output when an HLO verifier error occurs? It would be interesting to see the instructions before and after the one that fails. Even the output of XLA_FLAGS=--xla_dump_to= doesn't contain the instruction mentioned in the error message:

%compare.162 = pred[] compare(s64[] %select.102, s32[] %multiply.67)
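On the dump question, a minimal sketch (assuming a CPU run of a toy function; the temporary directory is an arbitrary choice) of two things that may help. `Lowered.as_text()` shows the module before XLA runs any pass, and adding `--xla_dump_hlo_pass_re=.*` next to `--xla_dump_to` makes XLA dump the module after every HLO pass instead of only before and after optimization. It is not guaranteed that the dump for the pass that trips the verifier is written before the error is raised, so this may only narrow the failure down to the last pass that succeeded.

```python
import os
import tempfile

# XLA reads XLA_FLAGS when the backend initialises, so set it before importing JAX.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = " ".join([
    os.environ.get("XLA_FLAGS", ""),  # keep any flags that are already set
    f"--xla_dump_to={dump_dir}",      # write HLO modules to this directory
    "--xla_dump_hlo_pass_re=.*",      # also dump the module after every HLO pass
]).strip()

import jax
import jax.numpy as jnp


def f(y):
    return y - jax.lax.map(lambda v: v, y)


lowered = jax.jit(f).lower(jnp.ones(2))
print(lowered.as_text())  # StableHLO as handed to XLA, before any XLA pass runs
lowered.compile()         # runs the XLA pipeline; per-pass dumps land in dump_dir

# File names encode the module name and the pass after which each dump was taken.
for name in sorted(os.listdir(dump_dir))[:20]:
    print(name)
```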

Sorry that this is a bit vague, but I feel like this might be an XLA compiler issue that only occurs under these specific circumstances, so I don't know how to break it down into reproducible code I can share. But maybe someone has an idea how to fix this.

What jax/jaxlib version are you using?

0.4.23 (cuda11_pip). Also tested with jax==0.4.24 and jaxlib==0.4.24.dev20240206+cuda11.cudnn86, with identical output.

Which accelerator(s) are you using?

GPU (8x NVIDIA A40 on 4 hosts)

Additional system info?

Python 3.11.3, Linux

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A40                     Off | 00000000:41:00.0 Off |                    0 |
|  0%   37C    P0              74W / 300W |      4MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A40                     Off | 00000000:A1:00.0 Off |                    0 |
|  0%   35C    P0              74W / 300W |      4MiB / 46068MiB |      4%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
jakevdp commented 9 months ago

Hi - sorry it's hard to say much here without a more complete reproduction. Given that the error comes from dynamic_update_slice, I'd like to see where that operation is coming from (the line you pasted, carry[tuple(split)][..., 0], will lower to gather, not to dynamic_update_slice).
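The point about gather can be checked by printing the jaxpr of the indexing expression. A minimal sketch, with small made-up shapes standing in for the ones in the report (the function name `body` is hypothetical):

```python
import jax
import jax.numpy as jnp

# Small stand-ins for the shapes in the report (hypothetical values).
carry = jnp.zeros((4, 4, 4), dtype=jnp.float32)
split = [jnp.zeros((5, 8, 1), dtype=jnp.int32)] * 3  # one integer index array per mesh axis
kernel = jnp.ones((5, 8), dtype=jnp.float32)


def body(carry, split, kernel):
    # Integer-array (advanced) indexing into the mesh, then a weighted sum over neighbours.
    return (carry[tuple(split)][..., 0] * kernel).sum(axis=-1)


jaxpr_text = str(jax.make_jaxpr(body)(carry, split, kernel))
print("gather" in jaxpr_text)                # True
print("dynamic_update_slice" in jaxpr_text)  # False
```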

jeffgortmaker commented 5 months ago

Hi @Findus23 and @jakevdp.

I ran into this same bug and was able to put together the following MRE. It's for CPUs because I don't have easy access to multiple GPUs, but maybe @Findus23 can check that the code errors out on GPUs too?

This is with jax and jaxlib version 0.4.26. I've confirmed the error on Linux, Windows, and WSL. Also of note: I only get the error when jax_enable_x64 is set to True. I'm not sure how to square that with the error above, where it wasn't enabled.

# use two devices
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

# use 64-bit on CPU
import jax
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')

# configure a simple sharding
from jax.sharding import Mesh, PartitionSpec, NamedSharding
mesh = Mesh(jax.devices(), ('a',))
sharding = NamedSharding(mesh, PartitionSpec('a'))

def f(y):
    return y - jax.lax.map(g, y)

def g(y):
    return y

x = jax.numpy.ones(2)
print(f(x))  # [0. 0.]
print(jax.jit(f)(x))  # [0. 0.]
print(f(jax.device_put(x, sharding)))  # [0. 0.]
print(jax.jit(f)(jax.device_put(x, sharding)))  # error

The error appears to be the same as the one above.

Traceback (most recent call last):
  File "sharding_bug_mre.py", line 25, in <module>
    print(jax.jit(f)(jax.device_put(x, sharding)))  # error
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: during context [hlo verifier]: Binary op compare with different element types: s64[] and s32[].
    , for instruction %compare.3 = pred[] compare(s64[] %select.1, s32[] %multiply), direction=GE, metadata={op_name="jit(f)/jit(main)/while/body/dynamic_update_slice" source_file="sharding_bug_mre.py" source_line=16}

Failed after spmd-partitioning
Findus23 commented 5 months ago

Hi,

Thanks to @jeffgortmaker for coming up with the reproducible example (I tested it on 2 NVIDIA A40s and can reproduce it there too). Originally I was unsure whether you had just found a different bug that happens to cause the same error. (After all, the example doesn't even contain integer data, while the part of my code where things break does some explicit integer dtype changes.) But (take everything from here with a grain of salt, as I don't have much time for testing right now) maybe I misunderstood the original issue and made some mistakes while debugging and explicitly setting dtypes everywhere.

It turns out that when I run my code with single precision everywhere and without jax_enable_x64, it succeeds without this issue. With jax_enable_x64, one would expect that explicitly making all integers np.int64 would also work, but it doesn't. So my theory is that my bug is mostly unrelated to my actual code and is exactly what @jeffgortmaker discovered: combining jax_enable_x64, jit, a map, and sharding breaks something inside the implicit resharding.

This would also explain the great comment by @jakevdp: my code doesn't do a dynamic_update_slice, but the sharding code generated under jit does. It may also explain why the error message points to the lax.scan/lax.map line instead of the actual code.
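A dynamic_update_slice does not need to be written explicitly in user code to appear here: lax.map is built on lax.scan, and when the scan is lowered it becomes a while loop whose body stores each iteration's output in the stacked result with dynamic_update_slice, which matches the `while/body/dynamic_update_slice` op_name in the error. A minimal sketch using the MRE's functions on a single device; it shows the module before XLA's own passes (including SPMD partitioning) run, so it cannot show where the s64 operand is introduced.

```python
import jax
import jax.numpy as jnp


def g(y):
    return y


def f(y):
    return y - jax.lax.map(g, y)


x = jnp.ones(2)

# At the jaxpr level the map is a single `scan` primitive.
print(jax.make_jaxpr(f)(x))

# Lowered StableHLO (before any XLA pass, including SPMD partitioning):
# the scan has become a while loop that stores each output with dynamic_update_slice.
hlo_text = jax.jit(f).lower(x).as_text()
print("while" in hlo_text, "dynamic_update_slice" in hlo_text)
```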

So I guess the workaround, until the compiler bug behind the example is tracked down, is to use single precision everywhere, or to explicitly set the dtype of every single np array created in the code to np.int64/np.float64 instead of enabling jax_enable_x64 (in the style of https://github.com/google/jax/issues/8178).

jeffgortmaker commented 5 months ago

Thanks @Findus23 for confirming on GPUs and checking on jax_enable_x64!

For what it's worth, in the above example adding lots of .astype(jnp.float64) calls didn't seem to eliminate the error for me. So for code that needs 64-bit precision, I'm not sure there is an unblocking workaround yet, other than not doing sharding this way.

Findus23 commented 5 months ago

@jeffgortmaker I mistakenly thought one could still use explicit 64-bit types without jax_enable_x64 and that the flag would just change the default. But as https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision explains, that doesn't work, so you are right: there doesn't seem to be any workaround for this issue (apart from not using sharding, not using 64-bit anywhere, or possibly not using scan/map this way).
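The behaviour described in that documentation section can be checked in a few lines. A minimal sketch with jax_enable_x64 off (the default): explicitly requesting a 64-bit dtype still produces a 32-bit array, and JAX emits a UserWarning about the truncation, which is silenced here.

```python
import warnings

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", False)  # the default; set explicitly for clarity

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # JAX warns that the requested dtype will be truncated
    a = jnp.zeros(3, dtype=jnp.int64)
    b = jnp.ones(3, dtype=jnp.float64)

print(a.dtype, b.dtype)  # int32 float32
```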