google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax metal failed to install #19886

Open Kubiczek36 opened 4 months ago

Kubiczek36 commented 4 months ago

Description

Following the instructions on the pip website, jax_metal failed to install: pip reports success, but importing jax then raises a RuntimeError.

(base) jakub.dokulil@nbm-imp-134 jd_python_learning % conda create -n jax_metal python=3.10          
Channels:
 - defaults
Platform: osx-64
Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal

  added / updated specs:
    - python=3.10

The following NEW packages will be INSTALLED:

  bzip2              pkgs/main/osx-64::bzip2-1.0.8-h1de35cc_0 
  ca-certificates    pkgs/main/osx-64::ca-certificates-2023.12.12-hecd8cb5_0 
  libffi             pkgs/main/osx-64::libffi-3.4.4-hecd8cb5_0 
  ncurses            pkgs/main/osx-64::ncurses-6.4-hcec6c5f_0 
  openssl            pkgs/main/osx-64::openssl-3.0.13-hca72f7f_0 
  pip                pkgs/main/osx-64::pip-23.3.1-py310hecd8cb5_0 
  python             pkgs/main/osx-64::python-3.10.13-h5ee71fb_0 
  readline           pkgs/main/osx-64::readline-8.2-hca72f7f_0 
  setuptools         pkgs/main/osx-64::setuptools-68.2.2-py310hecd8cb5_0 
  sqlite             pkgs/main/osx-64::sqlite-3.41.2-h6c40b1e_0 
  tk                 pkgs/main/osx-64::tk-8.6.12-h5d9f67b_0 
  tzdata             pkgs/main/noarch::tzdata-2023d-h04d1e81_0 
  wheel              pkgs/main/osx-64::wheel-0.41.2-py310hecd8cb5_0 
  xz                 pkgs/main/osx-64::xz-5.4.5-h6c40b1e_0 
  zlib               pkgs/main/osx-64::zlib-1.2.13-h4dc903c_0 

Proceed ([y]/n)? 

Downloading and Extracting Packages:

Preparing transaction: done
Verifying transaction: done
Executing transaction: done
#
# To activate this environment, use
#
#     $ conda activate jax_metal
#
# To deactivate an active environment, use
#
#     $ conda deactivate

(base) jakub.dokulil@nbm-imp-134 jd_python_learning % conda activate jax_metal
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install -U pip                       
Requirement already satisfied: pip in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (23.3.1)
Collecting pip
  Using cached pip-24.0-py3-none-any.whl.metadata (3.6 kB)
Using cached pip-24.0-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.3.1
    Uninstalling pip-23.3.1:
      Successfully uninstalled pip-23.3.1
Successfully installed pip-24.0
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install numpy                        
Collecting numpy
  Using cached numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl (20.6 MB)
Installing collected packages: numpy
Successfully installed numpy-1.26.4
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install jax-metal                    
Collecting jax-metal
  Using cached jax_metal-0.0.5-py3-none-macosx_10_14_x86_64.whl.metadata (1.4 kB)
Requirement already satisfied: wheel~=0.35 in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (from jax-metal) (0.41.2)
Collecting six>=1.15.0 (from jax-metal)
  Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting jax==0.4.20 (from jax-metal)
  Using cached jax-0.4.20-py3-none-any.whl.metadata (23 kB)
Collecting jaxlib==0.4.20 (from jax-metal)
  Downloading jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl.metadata (2.1 kB)
Collecting ml-dtypes>=0.2.0 (from jax==0.4.20->jax-metal)
  Using cached ml_dtypes-0.3.2-cp310-cp310-macosx_10_9_universal2.whl.metadata (20 kB)
Requirement already satisfied: numpy>=1.22 in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (from jax==0.4.20->jax-metal) (1.26.4)
Collecting opt-einsum (from jax==0.4.20->jax-metal)
  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting scipy>=1.9 (from jax==0.4.20->jax-metal)
  Using cached scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl.metadata (60 kB)
Using cached jax_metal-0.0.5-py3-none-macosx_10_14_x86_64.whl (54.6 MB)
Using cached jax-0.4.20-py3-none-any.whl (1.7 MB)
Downloading jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl (82.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 82.6/82.6 MB 3.5 MB/s eta 0:00:00
Using cached ml_dtypes-0.3.2-cp310-cp310-macosx_10_9_universal2.whl (389 kB)
Using cached scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl (38.9 MB)
Installing collected packages: six, scipy, opt-einsum, ml-dtypes, jaxlib, jax, jax-metal
Successfully installed jax-0.4.20 jax-metal-0.0.5 jaxlib-0.4.20 ml-dtypes-0.3.2 opt-einsum-3.3.0 scipy-1.12.0 six-1.16.0
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/__init__.py", line 39, in <module>
    from jax import config as _config_module
  File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/config.py", line 15, in <module>
    from jax._src.config import config as _deprecated_config  # noqa: F401
  File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/_src/config.py", line 28, in <module>
    from jax._src import lib
  File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 83, in <module>
    cpu_feature_guard.check_cpu_features()
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

System info (python version, jaxlib version, accelerator, etc.)

MacBook Air M2, macOS Sonoma 14.3.1 (23D60), Python 3.10

shuhand0 commented 4 months ago

Based on the packages, is it an AMD GPU? Could you try a venv with python=3.9?

curlup commented 4 months ago

Reproduces on my M2 Mac with both Python 3.10.6 and 3.9.13.

curlup commented 4 months ago

Tried jax==0.4.11, jaxlib==0.4.11, jax-metal==0.0.4: same thing.

shuhand0 commented 3 months ago

I haven't been able to reproduce the issue. The configuration below shows an installation and its verification result:

ProductName: macOS
ProductVersion: 14.4

The following NEW packages will be INSTALLED:

  ca-certificates    pkgs/main/osx-64::ca-certificates-2023.12.12-hecd8cb5_0 
  libcxx             pkgs/main/osx-64::libcxx-14.0.6-h9765a3e_0 
  libffi             pkgs/main/osx-64::libffi-3.4.4-hecd8cb5_0 
  ncurses            pkgs/main/osx-64::ncurses-6.4-hcec6c5f_0 
  openssl            pkgs/main/osx-64::openssl-3.0.13-hca72f7f_0 
  pip                pkgs/main/osx-64::pip-23.3.1-py39hecd8cb5_0 
  python             pkgs/main/osx-64::python-3.9.18-h5ee71fb_0 
  readline           pkgs/main/osx-64::readline-8.2-hca72f7f_0 
  setuptools         pkgs/main/osx-64::setuptools-68.2.2-py39hecd8cb5_0 
  sqlite             pkgs/main/osx-64::sqlite-3.41.2-h6c40b1e_0 
  tk                 pkgs/main/osx-64::tk-8.6.12-h5d9f67b_0 
  tzdata             pkgs/main/noarch::tzdata-2024a-h04d1e81_0 
  wheel              pkgs/main/osx-64::wheel-0.41.2-py39hecd8cb5_0 
  xz                 pkgs/main/osx-64::xz-5.4.6-h6c40b1e_0 
  zlib               pkgs/main/osx-64::zlib-1.2.13-h4dc903c_0 

Package            Version
------------------ -------
importlib_metadata 7.0.2
jax                0.4.20
jax-metal          0.0.5
jaxlib             0.4.20
ml-dtypes          0.3.2
numpy              1.26.4
opt-einsum         3.3.0
pip                24.0
scipy              1.12.0
setuptools         68.2.2
six                1.16.0
wheel              0.41.2
zipp               3.17.0

python -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-08 17:33:36.946600: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: AMD Radeon Pro Vega 20

systemMemory: 32.00 GB
maxCacheSize: 1.99 GB

[0 1 2 3 4 5 6 7 8 9]

curlup commented 3 months ago

Right, I think I was able to figure it out: in my case it was due to Python being the i386 arch and not arm64. After switching arch and installing a native Python, it worked.
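
For anyone checking whether they are hitting the same thing: the wheel filenames that pip prints already hint at which architecture the interpreter reported. In the log at the top of this issue, every compiled wheel is tagged x86_64 even though the machine is an M2. Below is a minimal sketch of reading that tag from a filename. Note that wheel_arch is a hypothetical helper written for this illustration, not part of pip or jax, and it only handles the simple macOS tags seen in this thread.

```python
# Sketch only: wheel_arch is a hypothetical helper, not a pip or jax API.
# It handles the simple macOS platform tags that appear in this thread.

def wheel_arch(filename: str) -> str:
    """Return the architecture part of a wheel's platform tag."""
    # Wheel names end in "-<python tag>-<abi tag>-<platform tag>.whl".
    stem = filename.removesuffix(".whl")  # needs Python 3.9+
    platform_tag = stem.rsplit("-", 1)[-1]
    if platform_tag == "any":
        return "any"  # pure-Python wheel, not tied to an architecture
    for arch in ("x86_64", "arm64", "universal2"):
        if platform_tag.endswith("_" + arch):
            return arch
    return "unknown"


# Filenames copied from the pip output in the original report.
for name in (
    "jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl",
    "jax_metal-0.0.5-py3-none-macosx_10_14_x86_64.whl",
    "jax-0.4.20-py3-none-any.whl",
):
    print(name, "->", wheel_arch(name))
```

If the compiled wheels come back as x86_64 on an Apple silicon machine, that suggests the Python doing the install is not an arm64 build.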

phisanti commented 3 months ago

I have just tried to install following the instructions on the Apple website (https://developer.apple.com/metal/jax/) and it failed, with the same error as everyone here, on an M2. How did you switch to a native python3?

I have just run the following code:

import platform

# Check the machine architecture
machine = platform.machine()

if machine == 'arm64':
    print("Your Python version is ARM64")
elif machine == 'i386':
    print("Your Python version is i386 (32-bit)")
elif machine == 'x86_64':
    print("Your Python version is x86_64 (64-bit)")
else:
    print(f"Unknown machine architecture: {machine}")

and the printout is:

Your Python version is x86_64 (64-bit)

curlup commented 3 months ago

@phisanti you switch in your CLI with the arch command, then you install Python afresh (it will be a different Python) and follow the jax-metal install instructions from Apple.
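
To confirm that the switch took effect, it may help to check not only platform.machine() but also whether the process is being translated. The sketch below assumes the macOS sysctl key sysctl.proc_translated, which Apple describes as reporting 1 for a translated (Rosetta) process and 0 for a native one. The helper names describe and proc_translated are made up for this example.

```python
import platform
import subprocess


def describe(machine: str, translated: str) -> str:
    """Interpret platform.machine() together with sysctl.proc_translated."""
    if machine == "arm64":
        return "native arm64 Python"
    if machine == "x86_64" and translated == "1":
        return "x86_64 Python running under Rosetta on Apple silicon"
    if machine == "x86_64":
        return "x86_64 Python (no Rosetta translation reported)"
    return f"unrecognized architecture: {machine}"


def proc_translated() -> str:
    """Return '1' under Rosetta, '0' when native, '' if the key is unavailable."""
    try:
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError:
        return ""  # no sysctl binary available, e.g. not on macOS
    # A non-zero exit code means the key does not exist on this system.
    return result.stdout.strip() if result.returncode == 0 else ""


if __name__ == "__main__":
    print(describe(platform.machine(), proc_translated()))
```

Run it with the freshly installed Python before installing jax-metal; if it still reports x86_64, the new interpreter is likely not the native one.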

phisanti commented 2 months ago

@curlup thanks for the tip. It worked for me!