intel / intel-extension-for-openxla

mpi4jax API version mismatch #32

Open coreyjadams opened 6 months ago

coreyjadams commented 6 months ago

With release 0.3.0, I am unable to get mpi4jax to run. I am using this branch from an Intel-forked mpi4jax: https://github.com/jczaja/mpi4jax/tree/jczaja/xpu-support. This is running on Argonne's Sunspot cluster with Intel Max 1550 GPUs.

I installed intel-extension-for-openxla 0.3.0 via pip. I have oneAPI 2024.1 and Agama 803.29. Here is what I see when I import jax and then import mpi4jax:

>>> import jax
>>> jax.local_devices()
INFO: Intel Extension for OpenXLA version: 0.3.0, commit: 9a484818
Platform 'xpu' is experimental and not all JAX functionality may be correctly supported!
[xpu(id=0), xpu(id=1), xpu(id=2), xpu(id=3), xpu(id=4), xpu(id=5), xpu(id=6), xpu(id=7), xpu(id=8), xpu(id=9), xpu(id=10), xpu(id=11)]
>>> import mpi4jax
Registering b'mpi_allgather' and function <capsule object "xla._CUSTOM_CALL_TARGET" at 0x1458e2532ca0>
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/__init__.py", line 9, in <module>
    from ._src import (  # noqa: E402
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/_src/__init__.py", line 11, in <module>
    from . import xla_bridge  # noqa: E402
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/_src/xla_bridge/__init__.py", line 42, in <module>
    xla_client.register_custom_call_target(name, fn, platform="SYCL")
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/jaxlib/xla_client.py", line 588, in register_custom_call_target
    _custom_callback_handler[xla_platform_name](name, fn, xla_platform_name)
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: API version 1986225522 not supported for PJRT GPU plugin. Supported versions are 0 and 1.
>>> 

Do I need to target a specific api version in mpi4jax to make this work? Or, do I need to build JAX from source?

Thanks! Corey

Zantares commented 6 months ago

This is an API version mismatch that appeared after the JAX upgrade. @Dboyqiao has fixed it and will post the solution.
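
As a small sanity check on the mismatch diagnosis, the sketch below reinterprets the version number from the traceback as four raw bytes. It decodes to ASCII text rather than to a small integer like 0 or 1, which would be consistent with the plugin reading an argument from the wrong position. That reading is an inference from the number alone and is not confirmed anywhere in this thread.

```python
# The API version reported in the XlaRuntimeError above.
reported = 1986225522

# Reinterpret the 32-bit value as raw bytes in little-endian order,
# the byte order of the x86_64 host shown in the traceback paths.
raw = reported.to_bytes(4, byteorder="little")

print(hex(reported))  # 0x76636572
print(raw)            # b'recv'
```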

Dboyqiao commented 6 months ago

@coreyjadams We have fixed this on the main branch, which is compatible with JAX 0.4.25. Until the next release, you will need to build it from source; see install-from-source-build for details. Also, XPU support has been merged into mpi4jax, so you can now use the public repo directly: https://github.com/mpi4jax/mpi4jax.git.
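
Before and after rebuilding, it can help to confirm which versions are actually installed in the environment. The following is a minimal sketch using only the standard library; the distribution names in `PACKAGES` are assumptions and may need adjusting to match what `pip list` shows on your system.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names are assumptions; adjust them to match `pip list`.
PACKAGES = ("jax", "jaxlib", "intel-extension-for-openxla", "mpi4jax")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```

Comparing the printed jax and jaxlib versions against the JAX version that the extension was built for is a quick way to spot the kind of mismatch reported above.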

Dboyqiao commented 5 months ago

@coreyjadams As far as I know, scale-out with mpi4jax is still blocked by a JAX bug that will be fixed in v0.4.28. Since intel-extension-for-openxla will not be rebased onto JAX v0.4.28 soon, could you provide more detail about the JAX bug?
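
To check whether a given environment already has a JAX version at or past v0.4.28, a simple comparison like the one below may be enough. This is a hedged sketch: `parse_release` and `has_fix` are hypothetical helper names, and the parser only looks at the leading numeric components, so a pre-release string such as `0.4.28rc1` is conservatively treated as older than `0.4.28`.

```python
def parse_release(version_string):
    """Return the leading integer components of a version string as a tuple."""
    parts = []
    for piece in version_string.split("."):
        if not piece.isdigit():
            break  # stop at the first non-numeric component (e.g. "dev2024")
        parts.append(int(piece))
    return tuple(parts)


def has_fix(jax_version, fixed_in="0.4.28"):
    """True if jax_version is at least the release named in the comment above."""
    return parse_release(jax_version) >= parse_release(fixed_in)


# With JAX installed, one could call: has_fix(jax.__version__)
print(has_fix("0.4.25"))  # False
print(has_fix("0.4.28"))  # True
```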

coreyjadams commented 5 months ago

Yes, here you go. It is this issue:

https://github.com/google/jax/issues/21160