Status: Open. coreyjadams opened this issue 6 months ago.
This is an API mismatch issue that appeared after upgrading JAX. @Dboyqiao has fixed it and will post the solution.
@coreyjadams We have fixed this issue on the main branch (compatible with JAX 0.4.25). Until the next release, you need to build it from source; please refer to install-from-source-build for details. In addition, XPU support has been merged into mpi4jax, so you can now use the public repo directly: https://github.com/mpi4jax/mpi4jax.git
@coreyjadams As far as I know, scaling out with mpi4jax is still blocked by a JAX bug that will be fixed in v0.4.28. Since intel-extension-for-openxla will not be rebased onto JAX v0.4.28 soon, could you provide more detail about the JAX bug?
Yes, here you go, it is this issue:
With release 0.3.0, I am unable to get mpi4jax to run. I am using this branch from an Intel-forked mpi4jax: https://github.com/jczaja/mpi4jax/tree/jczaja/xpu-support. This is running on Argonne's Sunspot cluster with Intel Max 1550 GPUs.
I have installed intel_extension_for_open_xla version 0.3.0 via pip. I have oneAPI 2024.1 and Agama 803.29. Here is what I see when I import jax, then import mpi4jax:
Do I need to target a specific API version in mpi4jax to make this work? Or do I need to build JAX from source?
Thanks! Corey
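Since the question above turns on which versions of JAX, mpi4jax, and the Intel extension are installed together, a small script that prints those versions can make a report like this easier to triage. The following is only a sketch: the helper name `collect_versions` is hypothetical, and the distribution names in `PACKAGES` are assumptions based on the names used in this thread (the exact pip names may differ in your environment).

```python
from importlib import metadata

# Distribution names as they appear in this thread.
# ASSUMPTION: these match the names pip installed them under.
PACKAGES = ("jax", "jaxlib", "mpi4jax", "intel-extension-for-openxla")


def collect_versions(names=PACKAGES):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

It only reads installed package metadata and does not import jax or mpi4jax, so it should still run in an environment where the import itself fails.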