drbenvincent opened this issue 1 month ago (status: Open)
Getting the same error on M1 Air.
Duplicate of https://github.com/google/jax/issues/20148 ?
pip install jax==0.4.26 jaxlib==0.4.26
gives:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716453894.367380 849760 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1716453894.452128 849760 service.cc:145] XLA service 0x600002350e00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716453894.452162 849760 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716453894.454461 849760 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716453894.454479 849760 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
loc("-":0:0): error: current mps dialect version is 1.0.0, can't parse version 1.1.0
/AppleInternal/Library/BuildRoots/1dd9a6a2-74cf-11ee-8ed5-2a65a1af8551/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1097: failed assertion `Error importing MLIR bytecode.
'
Abort trap: 6
which is the error in https://github.com/google/jax/issues/20338.
When I ran pip install jax==0.4.26 jaxlib==0.4.26, I think it worked.
Running jax.print_environment_info() now gives:
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
and running print(jax.numpy.arange(10)) in an IPython session gives:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716454190.044810 4298556 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
I0000 00:00:1716454190.058568 4298556 service.cc:145] XLA service 0x600000588a00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716454190.058577 4298556 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716454190.059879 4298556 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716454190.059894 4298556 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
[0 1 2 3 4 5 6 7 8 9]
Installing with pip install jax==0.4.26 jaxlib==0.4.26 works for me on an M1 Max. Thanks for sharing!
I'm trying to install via Poetry, and I find that jax-metal installs a version of jax that cannot be overridden by specifying an additional jax dependency:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.26", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
When installing this, despite the explicit attempt to override the jax dependency, jax-metal sets it to 0.4.28:
• Installing jax (0.4.28)
• Installing jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl)
In short, pinning jax==0.4.26 and jaxlib==0.4.26 as suggested isn't working for me.
For additional information, if I try:
$ python3
>>> import jax
jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.
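The check that failed here runs during `import jax`, so a mismatch can be spotted earlier by reading package metadata instead of importing jax. A minimal pre-flight sketch, assuming the simplest safe rule of keeping jax and jaxlib on the same release (the 0.4.26/0.4.26 pair is what worked earlier in this thread); `installed_versions` and `jax_jaxlib_aligned` are hypothetical helper names:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("jax", "jaxlib", "jax-metal"),
) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is absent.

    Uses package metadata only, so it works even when `import jax` would fail.
    """
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def jax_jaxlib_aligned(versions: Dict[str, Optional[str]]) -> bool:
    """Assumed rule for this sketch: jax and jaxlib must be the identical release."""
    jax_v = versions.get("jax")
    return jax_v is not None and jax_v == versions.get("jaxlib")


if __name__ == "__main__":
    v = installed_versions()
    print(v)
    if not jax_jaxlib_aligned(v):
        print("jax and jaxlib differ; consider pinning both to the same release.")
```

The equality rule is a deliberate simplification chosen for this sketch, not the exact compatibility rule jax applies.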
However, if I don't try to override, that is, with jax and jaxlib at version 0.4.27:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Also note that it only updates the jaxlib dependency and not jax:
• Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)
Then, doing what I tried before:
$ python3
>>> import jax
>>> jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/environment_info.py", line 45, in print_environment_info
devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1077, in devices
return get_backend(backend).devices()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
return _get_backend_uncached(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
bs = backends()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
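The error message itself suggests setting JAX_PLATFORMS=cpu to skip the failing backend. A sketch of that suggestion, assuming you just want to confirm that jax works at all while the METAL plugin is broken: run this thread's smoke test in a fresh interpreter, and retry on CPU only if the first attempt exits non-zero (which also covers a crash like the earlier `Abort trap: 6`). `run_with_cpu_fallback` is a hypothetical name, and this does not fix the plugin itself:

```python
import os
import subprocess
import sys
from typing import Tuple

# The smoke test used throughout this thread.
SMOKE_TEST = "import jax; print(jax.numpy.arange(10))"


def run_with_cpu_fallback(code: str = SMOKE_TEST) -> Tuple[str, subprocess.CompletedProcess]:
    """Run `code` in a fresh interpreter; if it exits non-zero, retry with JAX_PLATFORMS=cpu.

    Returns a label ("default" or "cpu") naming the attempt whose result is returned.
    """
    # First attempt: let JAX pick its backends (METAL when jax-metal is installed).
    # Drop any inherited JAX_PLATFORMS so this really is the default selection.
    default_env = {k: v for k, v in os.environ.items() if k != "JAX_PLATFORMS"}
    first = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True, env=default_env
    )
    if first.returncode == 0:
        return "default", first

    # Second attempt: restrict JAX to the CPU backend, as the error message suggests.
    cpu_env = dict(default_env, JAX_PLATFORMS="cpu")
    second = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True, env=cpu_env
    )
    return "cpu", second


if __name__ == "__main__":
    label, result = run_with_cpu_fallback()
    print(label, result.returncode)
    print(result.stdout or result.stderr)
```

If only the "cpu" attempt succeeds, the jax/jaxlib install is usable and the problem is isolated to the METAL plugin.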
The same behavior occurs if I change jax and jaxlib to version 0.4.28, or remove them entirely and let jax-metal install the jax and jaxlib versions it requires.
Machine and environment info: Chip: Apple M1 Pro; macOS: Sonoma 14.5; Python version: 3.10.13.
"RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)" This error comes from jaxlib, which strictly checks that the PJRT API versions are equal. jax-metal 0.0.7 uses the PJRT API from jaxlib 0.4.26. We have communicated this to the JAX team, and the solution is to set the environment variable ENABLE_PJRT_COMPATIBILITY=1 when running jax-metal with jaxlib > 0.4.26. This information can also be found on the jax-metal PyPI page: https://pypi.org/project/jax-metal/.
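Applied in code, the workaround above amounts to exporting the variable before jax initializes the plugin. A sketch assuming the 0.4.26 baseline quoted for jax-metal 0.0.7; the helper names are hypothetical, and setting the variable before `import jax` is simply the safest ordering rather than a documented requirement:

```python
import os
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Per the reply above: jax-metal 0.0.7 uses the PJRT API of jaxlib 0.4.26.
PLUGIN_BASELINE = (0, 4, 26)


def parse_release(v: str) -> Tuple[int, int, int]:
    """Extract the leading MAJOR.MINOR.PATCH numbers, e.g. '0.4.28' -> (0, 4, 28)."""
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    major, minor, patch = (int(part) for part in m.groups())
    return (major, minor, patch)


def needs_pjrt_compat(jaxlib_version: str) -> bool:
    """True when jaxlib is newer than the release jax-metal was built against."""
    return parse_release(jaxlib_version) > PLUGIN_BASELINE


def enable_pjrt_compat_if_needed(jaxlib_version: Optional[str] = None) -> bool:
    """Set ENABLE_PJRT_COMPATIBILITY=1 if required; call this before `import jax`.

    Returns True if the variable was set by this call.
    """
    if jaxlib_version is None:
        try:
            jaxlib_version = version("jaxlib")  # read metadata, do not import jaxlib
        except PackageNotFoundError:
            return False
    if needs_pjrt_compat(jaxlib_version):
        os.environ["ENABLE_PJRT_COMPATIBILITY"] = "1"
        return True
    return False
```

Typical use is `enable_pjrt_compat_if_needed()` followed by `import jax`, or from the shell `ENABLE_PJRT_COMPATIBILITY=1 python -c 'import jax; print(jax.numpy.arange(10))'`. It only addresses the PJRT API mismatch; it cannot help when `import jax` fails its own jaxlib version check, as in the "requires version >= 0.4.27" error in this thread.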
@BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care to count. Try:
import platform
platform.machine() # should give you 'arm64'
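A slightly more thorough version of this check, sketched under the assumption that you also want to know whether the process is being translated by Rosetta 2. It reads the `sysctl.proc_translated` key in-process via ctypes (the approach Apple describes for detecting translation) and reports None off macOS or where the key is missing, e.g. on Intel Macs. The function names are hypothetical:

```python
import ctypes
import ctypes.util
import platform
import sys
from typing import Dict, Optional


def rosetta_translated() -> Optional[bool]:
    """On macOS, ask the kernel whether *this* process runs under Rosetta 2.

    Returns None on other systems, or if the sysctl key does not exist.
    """
    if platform.system() != "Darwin":
        return None
    libc = ctypes.CDLL(ctypes.util.find_library("c"))
    ret = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(ret))
    rc = libc.sysctlbyname(
        b"sysctl.proc_translated", ctypes.byref(ret), ctypes.byref(size), None, ctypes.c_size_t(0)
    )
    if rc != 0:
        return None  # key missing, e.g. on an Intel Mac
    return ret.value == 1


def interpreter_report() -> Dict[str, object]:
    return {
        "executable": sys.executable,
        "machine": platform.machine(),   # should be 'arm64' for jax-metal
        "rosetta": rosetta_translated(),  # True means an x86_64 process under Rosetta 2
    }


if __name__ == "__main__":
    print(interpreter_report())
```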
I got mine (M1 Air, Sonoma) working by doing
python -m pip install jax==0.4.26 jaxlib==0.4.26
python -m pip install jax-metal
after I used #19886 to set up my environment (but make sure that your shell is also arm64 before doing this, e.g. arch -arm64 zsh).
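The recipe above can be scripted so that both pip calls go through the exact interpreter you just checked with platform.machine(). A sketch using the version pins from that comment, with jax-metal left unpinned as in the original; the function names and the dry-run default are hypothetical additions:

```python
import platform
import subprocess
import sys
from typing import List

# Pins from the comment above; jax-metal is left unpinned, as in that comment.
INSTALL_STEPS = [
    ["jax==0.4.26", "jaxlib==0.4.26"],
    ["jax-metal"],
]


def install_commands(python: str = sys.executable) -> List[List[str]]:
    """Build `<python> -m pip install ...` commands bound to one specific interpreter."""
    return [[python, "-m", "pip", "install", *pkgs] for pkgs in INSTALL_STEPS]


def run_install(dry_run: bool = True) -> List[List[str]]:
    """Refuse to proceed unless this interpreter is arm64; then run (or just return) the steps."""
    if platform.machine() != "arm64":
        raise RuntimeError(f"interpreter reports {platform.machine()!r}, expected 'arm64'")
    cmds = install_commands()
    if not dry_run:
        for cmd in cmds:
            subprocess.run(cmd, check=True)  # stop at the first failing step
    return cmds
```

Calling `run_install(dry_run=False)` performs the installs; with the default it only returns the commands after the arm64 check passes.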
In reply to the question of whether my python3 is an arm64 binary: I cleared the cache and all virtual environments, then did a fresh install of all the dependencies, making sure that arm64 is the platform in use.
$ python
>>> import platform
>>> platform.machine()
'arm64'
For the sake of thoroughness:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
I also ran the following right before installing all dependencies:
arch -arm64 zsh
Then I tried the original test again:
$ python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/__init__.py", line 37, in <module>
import jax.core as _core
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/core.py", line 18, in <module>
from jax._src.core import (
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 39, in <module>
from jax._src import dtypes
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dtypes.py", line 33, in <module>
from jax._src import config
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/config.py", line 27, in <module>
from jax._src import lib
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 75, in <module>
version = check_jaxlib_version(
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 64, in check_jaxlib_version
raise RuntimeError(msg)
RuntimeError: jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.
For the sake of thoroughness, I changed the jax and jaxlib dependency versions to 0.4.27 (I know jax-metal's PJRT API version matches jaxlib 0.4.26, but since that isn't working, I hoped that the 0.4.27 or 0.4.28 releases might use the same PJRT version):
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
• Downgrading jaxlib (0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)
Again, I did the following:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
arch -arm64 zsh
and ran the test above:
$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
return lax.iota(dtype, start)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
return broadcasted_iota(dtype, (size,), 0)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
I also tried:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.28", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
• Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl)
Again, I did the following:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
arch -arm64 zsh
and ran the test above:
$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
return lax.iota(dtype, start)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
return broadcasted_iota(dtype, (size,), 0)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
I have noticed that the people who have been able to get things working with this fix are using the M3 chip. Could the fact that I'm using an M1 chip be the issue? Has anyone tried to replicate this on an M1?
Just to be clear: Chip: Apple M1 Pro; macOS: Sonoma 14.5
Thanks, the steps above (python -m pip install jax==0.4.26 jaxlib==0.4.26, then python -m pip install jax-metal, from an arm64 shell) work for me. Mine is an M3 Pro on Sonoma 14.5.
@BeeGass, have you tried setting env ENABLE_PJRT_COMPATIBILITY=1 to run jax-metal with jaxlib>0.4.26?
@shuhand0 Yes, still the same behavior. I'm told that jaxlib needs to be version 0.4.27 or higher.
Installing with pip install jax==0.4.26 jaxlib==0.4.26 works for me on M3. Thank you!
Description
I ran the Get Started code on the Apple Accelerated JAX training on Mac page, namely:
On running that last line I get the following error:
System info (python version, jaxlib version, accelerator, etc.)
Running
import jax; jax.print_environment_info()
returns the following error:
Running the command a second time results in:
More info
I get the same issue on both my M1 Max Mac Studio and my 2020 M1 MacBook Air, both running Sonoma 14.5.