google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unable to initialize backend 'METAL' #21383

Open drbenvincent opened 1 month ago

drbenvincent commented 1 month ago

Description

I ran the Get Started code on the Apple Accelerated JAX training on Mac page, namely:

python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel ml-dtypes==0.2.0

python -m pip install jax-metal

python -c 'import jax; print(jax.numpy.arange(10))'

On running that last line I get the following error:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
    return lax.iota(dtype, start)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
    return broadcasted_iota(dtype, (size,), 0)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
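
The line that matters in the traceback above is the PJRT API version mismatch: the plugin reports 0.47 and the framework reports 0.51. As a rough sketch (not part of JAX; `parse_pjrt_mismatch` and its regular expression are hypothetical helpers written only for this message format), the two versions can be pulled out of the error text like this:

```python
import re

# Pattern for the mismatch text seen in the traceback. The closing parenthesis
# after the framework version is unbalanced in the real message, so both
# parentheses around that version are treated as optional.
_MISMATCH = re.compile(
    r"plugin PJRT API version \((\d+)\.(\d+)\) and "
    r"framework PJRT API version \(?(\d+)\.(\d+)\)?"
)


def parse_pjrt_mismatch(message):
    """Return ((plugin_major, plugin_minor), (framework_major, framework_minor)),
    or None when the message is not a PJRT version mismatch."""
    match = _MISMATCH.search(message)
    if match is None:
        return None
    p_major, p_minor, f_major, f_minor = (int(g) for g in match.groups())
    return (p_major, p_minor), (f_major, f_minor)


error_text = (
    "Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT "
    "plugin PJRT API version (0.47) and framework PJRT API version 0.51)."
)
plugin_version, framework_version = parse_pjrt_mismatch(error_text)
print("plugin:", plugin_version, "framework:", framework_version)
```

A plugin version lower than the framework version, as here, indicates that the installed jax-metal was built against an older jaxlib than the one currently installed.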

System info (python version, jaxlib version, accelerator, etc.)

Running import jax; jax.print_environment_info() returns the following error:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:874, in backends()
    873 try:
--> 874   backend = _init_backend(platform)
    875   _backends[platform] = backend

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:965, in _init_backend(platform)
    964 logger.debug("Initializing backend '%s'", platform)
--> 965 backend = registration.factory()
    966 # TODO(skye): consider raising more descriptive errors directly from backend
    967 # factories instead of returning None.

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:657, in register_plugin.<locals>.factory()
    656 if not xla_client.pjrt_plugin_initialized(plugin_name):
--> 657   xla_client.initialize_pjrt_plugin(plugin_name)
    658 updated_options = {}

File ~/jax-metal/lib/python3.10/site-packages/jaxlib/xla_client.py:176, in initialize_pjrt_plugin(plugin_name)
    169 """Initializes a PJRT plugin.
    170
    171 The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or
   (...)
    174   plugin_name: the name of the PJRT plugin.
    175 """
--> 176 _xla.initialize_pjrt_plugin(plugin_name)

XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In [2], line 1
----> 1 jax.print_environment_info()

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/environment_info.py:45, in print_environment_info(return_string)
     43   python_version = sys.version.replace('\n', ' ')
     44   with np.printoptions(threshold=4, edgeitems=2):
---> 45     devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
     46   info = textwrap.dedent(
     47       f"""\
     48   jax:    {version.__version__}
   (...)
     55 """
     56   )
     57   nvidia_smi = try_nvidia_smi()

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1077, in devices(backend)
   1052 def devices(
   1053     backend: str | xla_client.Client | None = None
   1054 ) -> list[xla_client.Device]:
   1055   """Returns a list of all devices for a given backend.
   1056
   1057   .. currentmodule:: jaxlib.xla_extension
   (...)
   1075     List of Device subclasses.
   1076   """
-> 1077   return get_backend(backend).devices()

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1011, in get_backend(platform)
   1007 @lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.
   1008 def get_backend(
   1009     platform: None | str | xla_client.Client = None
   1010 ) -> xla_client.Client:
-> 1011   return _get_backend_uncached(platform)

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:990, in _get_backend_uncached(platform)
    986   return platform
    988 platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)
--> 990 bs = backends()
    991 if platform is not None:
    992   platform = canonicalize_platform(platform)

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:890, in backends()
    888       else:
    889         err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
--> 890       raise RuntimeError(err_msg)
    892 assert _default_backend is not None
    893 if not config.jax_platforms.value:

RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

Running the command a second time results in:

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
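
The error message suggests setting JAX_PLATFORMS=cpu to skip the failing backend. A minimal sketch of doing that from Python, assuming the variable is set before JAX is first imported in the process (`force_cpu_backend` is a hypothetical helper name, not a JAX API):

```python
import os


def force_cpu_backend(env=None):
    """Set JAX_PLATFORMS=cpu in `env` (defaults to os.environ) and return it."""
    target = os.environ if env is None else env
    target["JAX_PLATFORMS"] = "cpu"
    return target


# Demonstrated on a plain dict so the example has no side effects.
demo_env = force_cpu_backend({})
print(demo_env)

# In a real script, call force_cpu_backend() with no arguments and only then
# run `import jax`, so the setting is in place before backends are initialized.
```

This only sidesteps the Metal plugin; computations then run on the CPU device.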

More info

I get the same issue on both my M1 Max Mac Studio and my 2020 M1 MacBook Air, both running macOS Sonoma 14.5.

twiecki commented 1 month ago

Getting the same error on M1 Air.

twiecki commented 1 month ago

Duplicate of https://github.com/google/jax/issues/20148 ?

twiecki commented 1 month ago

pip install jax==0.4.26 jaxlib==0.4.26 gives:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716453894.367380  849760 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1716453894.452128  849760 service.cc:145] XLA service 0x600002350e00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716453894.452162  849760 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716453894.454461  849760 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716453894.454479  849760 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
loc("-":0:0): error: current mps dialect version is 1.0.0, can't parse version 1.1.0
/AppleInternal/Library/BuildRoots/1dd9a6a2-74cf-11ee-8ed5-2a65a1af8551/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1097: failed assertion `Error importing MLIR bytecode.
'
Abort trap: 6

which is the error in https://github.com/google/jax/issues/20338.

drbenvincent commented 1 month ago

When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success.

jax.print_environment_info() now gives:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

and running print(jax.numpy.arange(10)) in an ipython session gives

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716454190.044810 4298556 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0000 00:00:1716454190.058568 4298556 service.cc:145] XLA service 0x600000588a00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716454190.058577 4298556 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716454190.059879 4298556 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716454190.059894 4298556 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
[0 1 2 3 4 5 6 7 8 9]

zhibor commented 1 month ago

When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success.

This works for me on an M1 Max. Thanks for sharing!

BeeGass commented 1 month ago

I'm trying to install via Poetry, and I find that jax-metal installs a version of jax that cannot be overridden by specifying an additional jax dependency:

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.26", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }

[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Despite the explicit attempt to override the jax dependency, installing this still resolves jax to 0.4.28:

  • Installing jax (0.4.28)
  • Installing jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl)

When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success.

In short, this isn't working for me.

For additional information, if I try

$ python3
>>> import jax
jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.

However, if I don't try to override, that is, with jax and jaxlib at version 0.4.27:

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }

[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Also note that this only updates jaxlib and not the jax dependency:

  • Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)

and do what I tried before:

$ python3
>>> import jax
>>> jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/environment_info.py", line 45, in print_environment_info
    devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1077, in devices
    return get_backend(backend).devices()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
    return _get_backend_uncached(platform)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
    bs = backends()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

The same behavior occurs if I change jax and jaxlib to version 0.4.28, or remove them entirely and let jax-metal install its own jax and jaxlib versions.

Machine and environment info:

Chip: Apple M1 Pro
macOS: Sonoma 14.5
Python version: 3.10.13
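
The `jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27` failure above can be checked without importing jax at all. The following is a sketch only: `installed_version`, `parse_version`, and `jaxlib_is_new_enough` are hypothetical helper names, and the default minimum of 0.4.27 is taken from the error message in this comment rather than from any published compatibility table.

```python
import re
from importlib import metadata


def parse_version(text):
    """Turn '0.4.26' into (0, 4, 26), stopping at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def jaxlib_is_new_enough(jaxlib_version, minimum="0.4.27"):
    """Check jaxlib against the minimum quoted in jax's import-time error."""
    return parse_version(jaxlib_version) >= parse_version(minimum)


# Report what is installed in the current environment (None means missing).
for name in ("jax", "jaxlib", "jax-metal"):
    print(name, installed_version(name))

jaxlib_version = installed_version("jaxlib")
if jaxlib_version is not None:
    print("jaxlib new enough:", jaxlib_is_new_enough(jaxlib_version))
```

Reading versions from package metadata avoids triggering the import-time check, which is useful when `import jax` itself is what fails.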

shuhand0 commented 1 month ago

"RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)". This error comes from jaxlib, which strictly checks that the PJRT API versions are equal. jax-metal 0.0.7 adopts the PJRT API from jaxlib 0.4.26. We have communicated with the JAX team, and the solution is to set the environment variable ENABLE_PJRT_COMPATIBILITY=1 when running jax-metal with jaxlib>0.4.26. This information can also be found on the jax-metal PyPI page: https://pypi.org/project/jax-metal/.

yrahul3910 commented 1 month ago

@BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care for. Try

import platform

platform.machine()  # should give you 'arm64'

I got mine (M1 Air, Sonoma) working by doing

python -m pip install jax==0.4.26 jaxlib==0.4.26
python -m pip install jax-metal

after I used #19886 to set up my environment (but make sure that your shell is also arm64 before doing this, e.g. arch -arm64 zsh).
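
The architecture check above can be wrapped into a small guard. This is a sketch, with `is_native_arm64_mac` as a hypothetical helper name; it relies on `platform.machine()` reporting `x86_64` when the interpreter is an Intel build running under Rosetta.

```python
import platform


def is_native_arm64_mac(machine=None, system=None):
    """True when the running interpreter is an arm64 build on macOS.

    `machine` and `system` default to the live values from the platform
    module; they are parameters only so the check can be exercised anywhere.
    """
    machine = platform.machine() if machine is None else machine
    system = platform.system() if system is None else system
    return system == "Darwin" and machine == "arm64"


print("machine:", platform.machine())
print("native arm64 macOS interpreter:", is_native_arm64_mac())
```

Running this with the same interpreter that will import jax (for example, inside the activated virtual environment) checks the binary that matters, rather than whichever `python3` happens to be first on the PATH.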

BeeGass commented 1 month ago

@BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care for. Try

import platform

platform.machine()  # should give you 'arm64'

I cleared the cache, all virtual environments, and so forth, then did a fresh install of all the dependencies, ensuring that arm64 is the platform.

$ python
>>> import platform
>>> platform.machine()
'arm64'

for the sake of showing thoroughness:

$ python3
>>> import platform
>>> platform.machine()
'arm64'

Also performed right before install of all dependencies:

arch -arm64 zsh

Then I tried the original test again:

$ python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/__init__.py", line 37, in <module>
    import jax.core as _core
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/core.py", line 18, in <module>
    from jax._src.core import (
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 39, in <module>
    from jax._src import dtypes
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dtypes.py", line 33, in <module>
    from jax._src import config
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/config.py", line 27, in <module>
    from jax._src import lib
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 75, in <module>
    version = check_jaxlib_version(
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 64, in check_jaxlib_version
    raise RuntimeError(msg)
RuntimeError: jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.

For the sake of thoroughness:

I changed the jax and jaxlib dependency versions to 0.4.28 (I know the required PJRT version is the one in 0.4.26, but since that isn't working I hoped that 0.4.27 or 0.4.28 might provide the same PJRT version).

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }

[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

  • Downgrading jaxlib (0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)

again did the following:

$ python3
>>> import platform
>>> platform.machine()
'arm64'

arch -arm64 zsh

performed the test above

$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
    return lax.iota(dtype, start)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
    return broadcasted_iota(dtype, (size,), 0)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

also tried

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.28", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }

[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

  • Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl)

again did the following:

$ python3
>>> import platform
>>> platform.machine()
'arm64'

arch -arm64 zsh

performed the test above

$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
    return lax.iota(dtype, start)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
    return broadcasted_iota(dtype, (size,), 0)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

I have noticed that the people who have been able to get things working with this fix have been using the M3 chip. Perhaps the issue is that I'm using an M1 chip? Has anyone tried to replicate this on an M1?

Just to be clear: Chip: Apple M1 Pro, macOS: Sonoma 14.5

shinhookang commented 1 month ago

@BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care for. Try

import platform

platform.machine()  # should give you 'arm64'

I got mine (M1 Air, Sonoma) working by doing

python -m pip install jax==0.4.26 jaxlib==0.4.26
python -m pip install jax-metal

after I used #19886 to set up my environment (but make sure that your shell is also arm64 before doing this, e.g. arch -arm64 zsh).

Thanks, this works for me. Mine is M3 Pro, Sonoma 14.5.

shuhand0 commented 1 month ago

@BeeGass, have you tried setting env ENABLE_PJRT_COMPATIBILITY=1 to run jax-metal with jaxlib>0.4.26?
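
For reference, a minimal sketch of applying the ENABLE_PJRT_COMPATIBILITY setting from Python. It assumes, per the earlier comment on jax-metal 0.0.7, that the plugin was built against jaxlib 0.4.26 and that the variable needs to be in the environment before JAX loads the plugin; `enable_pjrt_compat_if_needed` and `version_tuple` are hypothetical helper names, not part of jax or jax-metal.

```python
import os
import re


def version_tuple(text):
    """Turn '0.4.28' into (0, 4, 28), stopping at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def enable_pjrt_compat_if_needed(jaxlib_version, env=None, built_against="0.4.26"):
    """Set ENABLE_PJRT_COMPATIBILITY=1 when jaxlib is newer than the plugin's.

    Returns True if the variable was set. `env` defaults to os.environ; a
    plain dict can be passed to try the logic without side effects.
    """
    target = os.environ if env is None else env
    if version_tuple(jaxlib_version) > version_tuple(built_against):
        target["ENABLE_PJRT_COMPATIBILITY"] = "1"
        return True
    return False


# Demonstrated on a throwaway dict; in a real script, pass no `env` and call
# this before the first `import jax`.
demo_env = {}
print(enable_pjrt_compat_if_needed("0.4.28", env=demo_env), demo_env)
```

The equivalent in a shell is to prefix the command, for example `ENABLE_PJRT_COMPATIBILITY=1 python -c 'import jax; print(jax.numpy.arange(10))'`.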

BeeGass commented 1 month ago

@BeeGass, have you tried setting env ENABLE_PJRT_COMPATIBILITY=1 to run jax-metal with jaxlib>0.4.26?

Yeah, still the same behavior. I am told that the jax version needs to be equal to or higher than 0.4.27. @shuhand0

limyeeun1 commented 1 month ago

When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success.

this works for me on M1 MAX. thanks for sharing!

This works for me on M3. Thank you!

Aiyaz3007 commented 1 month ago

Follow this link and try the steps for your macOS version; this works for me!