jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.31k stars 2.78k forks

Provide wheels for macOS ARM #5501

Closed ericmjl closed 2 years ago

ericmjl commented 3 years ago

Hi all,

I was digging around to see what might need to happen to allow JAX to work on Apple Silicon. Knowing that JAX gets compiled to XLA, my guess is that XLA would need to be made Apple Silicon-compatible before JAX could run on it. May I ask whether you know of any plans by the XLA team to make that happen, or whether it's being ignored completely? (Knowing the answer would mostly help me decide how to set up my development environment.)

Cheers, Eric

gerdm commented 3 years ago

@yashk2810,

I uninstalled jaxlib and installed the wheel you shared, but I get the following error:

pip install --force-reinstall ~/Downloads/jaxlib-0.1.72-cp39-none-macosx_11_0_arm64.whl
# ...
Successfully installed absl-py-0.13.0 flatbuffers-2.0 jaxlib-0.1.72 numpy-1.21.2 scipy-1.7.1 six-1.16.0
(miniforge3) ❯ pip install "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# ...
Successfully installed jax-0.2.20 jaxlib-0.1.71
(miniforge3) ❯ ipython                                                               
Python 3.9.6 | packaged by conda-forge | (default, Jul 11 2021, 03:35:11) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.26.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax
/Users/gerardoduran/miniforge3/lib/python3.9/site-packages/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-1-cb15c4215ef7> in <module>
----> 1 import jax

~/miniforge3/lib/python3.9/site-packages/jax/__init__.py in <module>
     35 # We want the exported object to be the class, so we first import the module
     36 # to make sure a later import doesn't overwrite the class.
---> 37 from . import config as _config_module
     38 del _config_module
     39 

~/miniforge3/lib/python3.9/site-packages/jax/config.py in <module>
     16 
     17 # flake8: noqa: F401
---> 18 from jax._src.config import config

~/miniforge3/lib/python3.9/site-packages/jax/_src/config.py in <module>
     25 import warnings
     26 
---> 27 from jax import lib
     28 from jax.lib import jax_jit
     29 

~/miniforge3/lib/python3.9/site-packages/jax/lib/__init__.py in <module>
     72 
     73 from jaxlib import xla_client
---> 74 from jaxlib import lapack
     75 from jaxlib import pocketfft
     76 

jaxlib/lapack.pyx in init lapack()

~/miniforge3/lib/python3.9/site-packages/scipy/linalg/__init__.py in <module>
    193 """  # noqa: E501
    194 
--> 195 from .misc import *
    196 from .basic import *
    197 from .decomp import *

~/miniforge3/lib/python3.9/site-packages/scipy/linalg/misc.py in <module>
      1 import numpy as np
      2 from numpy.linalg import LinAlgError
----> 3 from .blas import get_blas_funcs
      4 from .lapack import get_lapack_funcs
      5 

~/miniforge3/lib/python3.9/site-packages/scipy/linalg/blas.py in <module>
    211 import functools
    212 
--> 213 from scipy.linalg import _fblas
    214 try:
    215     from scipy.linalg import _cblas

ImportError: dlopen(/Users/gerardoduran/miniforge3/lib/python3.9/site-packages/scipy/linalg/_fblas.cpython-39-darwin.so, 2): no suitable image found.  Did find:
        /Users/gerardoduran/miniforge3/lib/python3.9/site-packages/scipy/linalg/_fblas.cpython-39-darwin.so: mach-o, but wrong architecture
        /Users/gerardoduran/miniforge3/lib/python3.9/site-packages/scipy/linalg/_fblas.cpython-39-darwin.so: mach-o, but wrong architecture
yashk2810 commented 3 years ago

This answer suggests that your Python installation went wrong somewhere: https://stackoverflow.com/questions/39477023/error-mach-o-but-wrong-architecture-after-installing-anaconda-on-mac

Can you check that and try again?

Thank you for trying it out. I (and the JAX team) appreciate it :)

yashk2810 commented 3 years ago

Lots of users have hit the "mach-o, but wrong architecture" error, so it doesn't look like a JAX issue to me.
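For anyone who hits this, one quick way to find the mismatched side is to compare the interpreter's architecture with the architecture a compiled extension was built for. The sketch below reads the Mach-O header directly, using only the standard library. `macho_archs` and `interpreter_arch` are hypothetical helper names. It only handles 64-bit thin binaries and classic fat headers, not `FAT_MAGIC_64`.

```python
import platform
import struct

# Header constants from Apple's <mach-o/loader.h>, <mach-o/fat.h> and <mach/machine.h>.
MH_MAGIC_64 = 0xFEEDFACF  # thin 64-bit Mach-O, written in host (little-endian) byte order
FAT_MAGIC = 0xCAFEBABE    # universal ("fat") header, always written big-endian
CPU_TYPE_NAMES = {
    0x01000007: "x86_64",  # CPU_TYPE_X86_64 = CPU_TYPE_X86 | CPU_ARCH_ABI64
    0x0100000C: "arm64",   # CPU_TYPE_ARM64 = CPU_TYPE_ARM | CPU_ARCH_ABI64
}


def macho_archs(path):
    """Return the CPU architectures a Mach-O file (e.g. an extension .so) contains."""
    with open(path, "rb") as f:
        header = f.read(8)
        if len(header) < 8:
            raise ValueError(f"{path}: too short to be a Mach-O file")
        magic_le, cputype = struct.unpack("<II", header)
        if magic_le == MH_MAGIC_64:
            # Thin binary: the cputype field immediately follows the magic number.
            return [CPU_TYPE_NAMES.get(cputype, hex(cputype))]
        magic_be, nfat_arch = struct.unpack(">II", header)
        if magic_be == FAT_MAGIC:
            # Universal binary: nfat_arch records of five big-endian uint32s
            # (cputype, cpusubtype, offset, size, align).
            archs = []
            for _ in range(nfat_arch):
                record = f.read(20)
                if len(record) < 20:
                    raise ValueError(f"{path}: truncated fat header")
                arch_cputype = struct.unpack(">5I", record)[0]
                archs.append(CPU_TYPE_NAMES.get(arch_cputype, hex(arch_cputype)))
            return archs
    raise ValueError(f"{path}: not a 64-bit Mach-O or fat binary")


def interpreter_arch():
    """'arm64' for a native Apple Silicon Python, 'x86_64' for Intel or Rosetta."""
    return platform.machine()
```

Point `macho_archs` at the path printed in the `dlopen` error, for example the `_fblas.cpython-39-darwin.so` above. Suppose it returns `["x86_64"]` while `interpreter_arch()` returns `"arm64"`. In that case the package most likely came from an x86_64 channel or wheel and needs reinstalling for arm64.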

dfm commented 3 years ago

@yashk2810: Thanks! Unfortunately I find that your wheel built at head reproduces the same LLVM error on my machine. I'm hoping to get a chance to try to fix it on my side, but haven't had a moment yet.

yashk2810 commented 3 years ago

Interesting. Can you paste your log, how you installed it, and your OS?
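When reporting back, a small script like the one below collects the usual details in one go. It is a sketch that uses only the standard library (`platform` and `importlib.metadata`, Python 3.8+). The function names are illustrative and not part of JAX.

```python
import platform
from importlib import metadata


def version_or_none(dist_name):
    """Installed version of a distribution, or None if it isn't installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def environment_report():
    """Collect OS, architecture, Python and package versions for a bug report."""
    mac_release = platform.mac_ver()[0]  # '' when not running on macOS
    report = {
        "os": f"macOS {mac_release}" if mac_release else platform.platform(),
        "machine": platform.machine(),  # 'arm64' native, 'x86_64' on Intel or Rosetta
        "python": platform.python_version(),
    }
    for name in ("jax", "jaxlib", "numpy", "scipy"):
        report[name] = version_or_none(name)
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```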

dfm commented 3 years ago

Sure! Here you go:

conda create -n jax-test python=3.9 numpy scipy
conda activate jax-test
python -m pip install https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.72-cp39-none-macosx_11_0_arm64.whl
python -m pip install jax

Then in Python:

Python 3.9.7 | packaged by conda-forge | (default, Sep  2 2021, 17:55:16)
[Clang 11.1.0 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jaxlib
>>> jaxlib.__version__
'0.1.72'
>>> import jax.numpy as jnp
/opt/homebrew/Caskroom/miniforge/base/envs/jax-test/lib/python3.9/site-packages/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
>>> jnp.sqrt(2)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
zsh: abort      python

(Every other jnp function I've tried fails with the same error; it's not limited to sqrt.)

Output from 'conda env export':
```
name: jax-test
channels:
  - conda-forge
dependencies:
  - ca-certificates=2021.5.30=h4653dfc_0
  - libblas=3.9.0=11_osxarm64_openblas
  - libcblas=3.9.0=11_osxarm64_openblas
  - libcxx=12.0.1=h168391b_0
  - libgfortran=5.0.0.dev0=11_0_1_hf114ba7_23
  - libgfortran5=11.0.1.dev0=hf114ba7_23
  - liblapack=3.9.0=11_osxarm64_openblas
  - libopenblas=0.3.17=openmp_h5dd58f0_1
  - llvm-openmp=12.0.1=hf3c4609_1
  - ncurses=6.2=h9aa5885_4
  - numpy=1.21.2=py39h1f3b974_0
  - openssl=1.1.1l=h3422bc3_0
  - pip=21.2.4=pyhd8ed1ab_0
  - python=3.9.7=h54d631c_0_cpython
  - python_abi=3.9=2_cp39
  - readline=8.1=hedafd6a_0
  - scipy=1.7.0=py39h5060c3b_0
  - setuptools=58.0.4=py39h2804cbe_0
  - sqlite=3.36.0=h72a2b83_1
  - tk=8.6.11=he1e0b03_1
  - tzdata=2021a=he74cb21_1
  - wheel=0.37.0=pyhd8ed1ab_1
  - xz=5.2.5=h642e427_1
  - zlib=1.2.11=h31e879b_1009
  - pip:
    - absl-py==0.13.0
    - flatbuffers==2.0
    - jax==0.2.20
    - jaxlib==0.1.72
    - opt-einsum==3.3.0
    - six==1.16.0
prefix: /opt/homebrew/Caskroom/miniforge/base/envs/jax-test
```
dfm commented 3 years ago

@yashk2810: I did a tiny bit of digging and I think that @hawkinsp's worries were probably right. Knowing very little about how the build infrastructure works, my best guess is that the relevant change is that TF used to configure llvm-project manually:

https://github.com/tensorflow/tensorflow/blob/4039feeb743bc42cd0a3d8146ce63fc05d23eb8d/third_party/llvm/llvm.bzl#L310-L317

But now this is delegated to the bazel support in llvm-project directly, which doesn't seem to correctly handle this target. In particular, when compiling any of the LLVM targets, the CMake variables are no longer set correctly. For example, for the dependencies of jaxlib v0.1.70, the build variable LLVM_NATIVE_ARCH=AArch64 was set correctly, but now it is set using -DLLVM_NATIVE_ARCH="X86".

Anyways, this is probably TMI here, but I'd say that it looks like the issue lives pretty high up the tree of dependencies!

hawkinsp commented 3 years ago

Yes, that seems right. This logic in particular looks wrong: https://github.com/llvm/llvm-project/blob/81d5412439efd0860c0a8dd51b831204f118d485/utils/bazel/llvm-project-overlay/llvm/config.bzl#L78

dfm commented 3 years ago

I can confirm that that logic is the culprit. If I swap out that line as follows (this also isn't the right logic, but it was a test):

-    "@bazel_tools//src/conditions:darwin": native_arch_defines("X86", "x86_64-unknown-darwin"),
+    "@bazel_tools//src/conditions:darwin": native_arch_defines("AArch64", "arm64-apple-darwin"),

Then jaxlib seems to work as expected. (For reference, here's the patch I applied to the jax source at v0.1.71 that seemed to propagate the change I wanted: https://gist.github.com/dfm/bc2cf413bb4ad0b1d6fb11a96a406ef4)

xhochy commented 3 years ago

@dfm Any idea how to trigger that error using only jaxlib? That would be nice as a test in the conda package of jaxlib to verify that everything is working (jax and jaxlib are built separately on conda-forge).
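Here is one possible shape for such a test. It is a sketch, not what conda-forge actually does. The LLVM fatal error calls `abort()` and takes the whole interpreter down (`zsh: abort python` above). So the sketch runs the check in a child process and inspects its exit status and stderr.

The jaxlib snippet is adapted from the `xla_client` session posted later in this thread. Those `xla_client` names are from jaxlib releases of that period and may not exist in other versions. `run_isolated` and `JAXLIB_SMOKE_TEST` are hypothetical names.

```python
import subprocess
import sys

# jaxlib-only computation, adapted from the xla_client session later in this
# thread. Assumption: these xla_client names (XlaBuilder, ops, Shape,
# get_local_backend, buffer_from_pyval, to_py) match jaxlib releases of that
# period and may differ in other versions.
JAXLIB_SMOKE_TEST = """
import numpy as np
from jaxlib import xla_client as xc

xops = xc.ops
c = xc.XlaBuilder("simple_scalar")
param_shape = xc.Shape.array_shape(np.dtype(np.float32), ())
x = xops.Parameter(c, 0, param_shape)
y = xops.Sin(x)
computation = c.Build()
cpu_backend = xc.get_local_backend("cpu")
compiled = cpu_backend.compile(computation)
device_input = cpu_backend.buffer_from_pyval(np.array(3.0, dtype=np.float32))
print(compiled.execute([device_input])[0].to_py())
"""


def run_isolated(code, timeout=300):
    """Run `code` in a fresh interpreter; return (ok, returncode, stderr).

    LLVM fatal errors call abort(), which would kill the calling process, so the
    snippet runs in a child. On POSIX a child killed by SIGABRT reports a
    negative returncode. Raises subprocess.TimeoutExpired if the child hangs.
    """
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    ok = proc.returncode == 0 and "LLVM ERROR" not in proc.stderr
    return ok, proc.returncode, proc.stderr
```

Call it as `ok, returncode, stderr = run_isolated(JAXLIB_SMOKE_TEST)`. On an affected wheel it should return `ok == False` with "LLVM ERROR" in `stderr`. On a working one it should return `ok == True`.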

hawkinsp commented 3 years ago

@dfm Can you try applying the following patch to LLVM and verifying the resulting wheel works for you?

diff --git a/utils/bazel/llvm-project-overlay/llvm/config.bzl b/utils/bazel/llvm-project-overlay/llvm/config.bzl
index 514f79bcf2b6..8a8e54e844a7 100644
--- a/utils/bazel/llvm-project-overlay/llvm/config.bzl
+++ b/utils/bazel/llvm-project-overlay/llvm/config.bzl
@@ -75,7 +75,10 @@ os_defines = select({
 # TODO: We should split out host vs. target here.
 llvm_config_defines = os_defines + select({
     "@bazel_tools//src/conditions:windows": native_arch_defines("X86", "x86_64-pc-win32"),
-    "@bazel_tools//src/conditions:darwin": native_arch_defines("X86", "x86_64-unknown-darwin"),
+    "@bazel_tools//src/conditions:darwin": select({
+         "@bazel_tools//platforms:arm": native_arch_defines("AArch64", "arm64-apple-darwin"),
+         "//conditions:default": native_arch_defines("X86", "x86_64-apple-darwin"),
+     }),
     "@bazel_tools//src/conditions:linux_aarch64": native_arch_defines("AArch64", "aarch64-unknown-linux-gnu"),
     "//conditions:default": native_arch_defines("X86", "x86_64-unknown-linux-gnu"),
 }) + [
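To make the intended logic easier to read, here is a sketch of the host-to-LLVM mapping from the diff above, written as plain Python branches instead of Bazel `select()` calls. It is an illustration only, and the function names are hypothetical. Like the TODO in config.bzl says, it only looks at the host, not the target.

```python
import platform


def native_arch_defines(system, machine):
    """Return (LLVM_NATIVE_ARCH, host triple) for a build host."""
    if system == "Windows":
        return "X86", "x86_64-pc-win32"
    if system == "Darwin":
        # The original rule returned X86 for every Darwin host, which is what
        # produced LLVM_NATIVE_ARCH=X86 on M1 machines.
        if machine == "arm64":
            return "AArch64", "arm64-apple-darwin"
        return "X86", "x86_64-apple-darwin"
    if system == "Linux" and machine == "aarch64":
        return "AArch64", "aarch64-unknown-linux-gnu"
    return "X86", "x86_64-unknown-linux-gnu"


def llvm_native_arch_flag(system, machine):
    """Format the arch as the -D flag mentioned earlier in the thread."""
    arch, _triple = native_arch_defines(system, machine)
    return f"-DLLVM_NATIVE_ARCH={arch}"


def host_defines():
    """Mapping for the machine this runs on (x86_64 under Rosetta)."""
    return native_arch_defines(platform.system(), platform.machine())
```

Nesting is not needed here because each branch tests the (OS, CPU) pair directly. This roughly corresponds to a single-level select over combined conditions such as `darwin_arm64`.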
dfm commented 3 years ago

@hawkinsp: I gave that a shot and it fails because, unless I'm misunderstanding something, bazel doesn't seem to support nested selects like this (I can pull up the exact error, but I didn't save it originally). Unfortunately it also looks like there is a bug in the released versions of bazel which means that /src/conditions:darwin_arm64 doesn't work. The simplest diff that I could find to implement what I think is the correct logic (mainly copied from here) was: https://gist.github.com/dfm/845dfbd3dc1c17f75e7cb0cba7b0febb

hawkinsp commented 3 years ago

@dfm Yeah I worried that might not work because of the nested selects. Your version looks reasonable to me. Do you want to send that to upstream LLVM (I can, if you don't want to, but you did the work!)? That's all we need to do here to get JAX fixed!

dfm commented 3 years ago

@hawkinsp: Sure, thanks! I'm happy to see if I can figure out the llvm review system and report back :D

saminehbagheri commented 3 years ago

Confirmed that it works 🎉. Built with Bazel master and https://github.com/freedomtan/tensorflow/tree/bazel_native_build_on_m1

>>> import platform
>>> platform.uname()
uname_result(system='Darwin', node='Josefs-MBP-2.lan', release='20.3.0', version='Darwin Kernel Version 20.3.0: Thu Jan 21 00:06:51 PST 2021; root:xnu-7195.81.3~1/RELEASE_ARM64_T8101', machine='arm64')
>>> from jax.lib import xla_client as xc
>>> xops = xc.ops
>>> c = xc.XlaBuilder("simple_scalar")
>>> param_shape = xc.Shape.array_shape(np.dtype(np.float32), ())
>>> x = xops.Parameter(c, 0, param_shape)
>>> y = xops.Sin(x)
>>> computation = c.Build()
>>> cpu_backend = xc.get_local_backend("cpu")
2021-03-07 09:26:16.684549: W external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
>>> compiled_computation = cpu_backend.compile(computation)
>>> host_input = np.array(3.0, dtype=np.float32)
>>> device_input = cpu_backend.buffer_from_pyval(host_input)
>>> device_out = compiled_computation.execute([device_input ,])
>>> device_out[0].to_py()
array(0.14112, dtype=float32)

@mattjj @hawkinsp If you want a PR I would be happy to create one, but maybe it makes more sense to wait until Bazel has released a working native arm64 build and TensorFlow has the necessary code in master.

Hi there,

Would you mind documenting the steps you took to resolve the problem? I have a similar issue. I managed to install numpyro, JAX, and jaxlib, but when I import the packages I get the following warning:

/opt/homebrew/Caskroom/miniforge/base/envs/bhm-at-scale/lib/python3.9/site-packages/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

and an ImportError:

ImportError: dlopen(/opt/homebrew/Caskroom/miniforge/base/envs/bhm-at-scale/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: __ZN3jax12_GLOBAL__N_121CompiledFunctionCache16kDefaultCapacityE
  Referenced from: /opt/homebrew/Caskroom/miniforge/base/envs/bhm-at-scale/lib/python3.9/site-packages/jaxlib/xla_extension.so

hawkinsp commented 3 years ago

@saminehbagheri That's a warning, not an error. JAX should work as normal, but we just want you to be aware it's pretty minimally tested at this point on ARM.

saminehbagheri commented 3 years ago

> @saminehbagheri That's a warning, not an error. JAX should work as normal, but we just want you to be aware it's pretty minimally tested at this point on ARM.

Thanks for the reply. I reformulated my post. True, that's a warning, but I also get an ImportError because a symbol can't be found when loading xla_extension.so.

annakoop commented 3 years ago

I'm getting the same import error (and had to set global variables as described here, for GRPCIO to install).

ImportError: dlopen(/Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: __ZN3jax12_GLOBAL__N_121CompiledFunctionCache16kDefaultCapacityE
E     Referenced from: /Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so
E     Expected in: flat namespace
E    in /Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so

jotsif commented 3 years ago

@annakoop @saminehbagheri: That import error is most likely caused by mismatched jax and jaxlib versions. Since the LLVM problem discussed above has not been fixed, make sure you build jaxlib 0.1.70 and use jax 0.2.19.

johnjmolina commented 3 years ago

Thank you all for the detailed information, this is very helpful.

I was running jax/jaxlib under emulation on the M1 and started seeing this same import error with recent versions (including the 0.1.70 and 0.2.19 combo). However, it does work if I install jaxlib 0.1.61 and jax 0.2.10 (...but then I can't use jaxopt). Does anyone know why this is happening? Should we not expect to be able to run JAX under emulation going forward?

Thanks in advance!
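Before digging into emulation, it may help to confirm that the interpreter really is running under Rosetta 2. Apple documents an in-process check that reads the `sysctl.proc_translated` sysctl. The sketch below calls `sysctlbyname` through `ctypes`.

The function name is hypothetical. The sketch only detects translation; it says nothing about why newer jaxlib builds fail under it.

```python
import ctypes
import errno
import platform


def running_under_rosetta():
    """True if this process is translated by Rosetta 2, False if native, None if unknown.

    Reads sysctl.proc_translated in-process: 1 means translated, 0 means native,
    and ENOENT (key doesn't exist) also means the process runs natively.
    """
    if platform.system() != "Darwin":
        return None
    # libSystem lives in the dyld shared cache on macOS 11+, but dlopen by path still works.
    libc = ctypes.CDLL("/usr/lib/libSystem.B.dylib", use_errno=True)
    value = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(value))
    rc = libc.sysctlbyname(
        b"sysctl.proc_translated",
        ctypes.byref(value),
        ctypes.byref(size),
        None,
        ctypes.c_size_t(0),
    )
    if rc == -1:
        return False if ctypes.get_errno() == errno.ENOENT else None
    return value.value == 1
```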

dfm commented 3 years ago

Quick update on the LLVM issue. I did get a tiny patch merged that should start getting us there. They sensibly did not want to merge the full patch I shared above, because it shouldn't be necessary. But I'm now finding that something in the TensorFlow Bazel configuration means LLVM still can't figure out the correct platform, even after updating to this commit and using the most recent version of Bazel. I've gone down a real rabbit hole with this one and I'm still coming up empty, unfortunately!

annakoop commented 3 years ago

Still having the xla issue with jaxlib 0.1.70 and jax 0.2.19, testing some different configurations...

Michael-tehc commented 3 years ago

> I'm getting the same import error (and had to set global variables as described here, for GRPCIO to install).
>
> ImportError: dlopen(/Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: __ZN3jax12_GLOBAL__N_121CompiledFunctionCache16kDefaultCapacityE
> E     Referenced from: /Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so
> E     Expected in: flat namespace
> E    in /Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so

Not sure if this is the right issue, but I'm getting the same error on x86 macOS with Python 3.9.7.

I've just installed JAX:

  absl-py            conda-forge/noarch::absl-py-0.14.0-pyhd8ed1ab_0
  jax                conda-forge/noarch::jax-0.2.21-pyhd8ed1ab_0
  jaxlib             conda-forge/osx-64::jaxlib-0.1.71-py39h757cd7f_0
  opt_einsum         conda-forge/noarch::opt_einsum-3.3.0-pyhd8ed1ab_1
  python-flatbuffers conda-forge/noarch::python-flatbuffers-2.0-pyhd8ed1ab_0
quattro commented 3 years ago

Any updates or progress on this front? I patched LLVM according to the diffs above, without success.

8bitmp3 commented 3 years ago

M1 Pro and M1 Max 🥲

yashk2810 commented 3 years ago

jaxlib wheels for 0.1.73 are live: https://pypi.org/project/jaxlib/0.1.73/#files

Can you try it out and see if the issue is fixed?

jotsif commented 3 years ago

I get the "'cyclone' is not a recognized processor for this target (ignoring processor)" message with that wheel.

quattro commented 3 years ago

Same error here, as well.

Python 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:24:55) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.24.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
/Users/nicholas/miniforge3/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

In [2]: a = jnp.arange(5)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
ngold5 commented 3 years ago

I get the same error as above:

>>> rng_key = random.PRNGKey(0)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
hawkinsp commented 3 years ago

I think the LLVM fix may not have landed, or it was reverted.

dfm commented 3 years ago

Yeah - TensorFlow patches out the LLVM fix because it's incompatible with their macos x86_64 builds for reasons that I don't totally understand (these discussions are happening somewhere that I don't have access to). I didn't have much luck working around this in my experiments so I'm just hoping that it gets sorted out upstream eventually :D

For now, the jaxlib==0.1.70 wheels that I built are working just fine on my M1, so I've just been using those:

python -m pip install jax jaxlib==0.1.70 -f "https://dfm.io/custom-wheels/jaxlib/index.html"

Hope this helps!

ngold5 commented 3 years ago

> Yeah - TensorFlow patches out the LLVM fix because it's incompatible with their macos x86_64 builds for reasons that I don't totally understand (these discussions are happening somewhere that I don't have access to). I didn't have much luck working around this in my experiments so I'm just hoping that it gets sorted out upstream eventually :D
>
> For now, the jaxlib==0.1.70 wheels that I built are working just fine on my M1, so I've just been using those:
>
> python -m pip install jax jaxlib==0.1.70 -f "https://dfm.io/custom-wheels/jaxlib/index.html"
>
> Hope this helps!

This is the solution to use until the fix! Thank you very much

yashk2810 commented 3 years ago

I just submitted a fix upstream to TF: https://github.com/tensorflow/tensorflow/commit/cd76ed3114f5d3e5f387dbc04de63891da958861

I'll build jaxlib again to see if the fix works.

Or if someone can build it before I do and confirm it works, it would be very much appreciated 😃

yashk2810 commented 3 years ago

Looks like it doesn't work.

erwincoumans commented 3 years ago

Same error. I think the uploaded 0.1.73 wheel was never tested.

'cyclone' is not a recognized processor for this target (ignoring processor)

I'm hoping that MacBook ARM M1/M1X will become a properly supported platform. Same for Windows.

hawkinsp commented 3 years ago

@erwincoumans As we mentioned above, we don't actually have any M1 hardware ourselves on the team. So we can't test anything. The M1 build is community supported at the moment.

(I would imagine we will support the M1 build ourselves in the future, but we can't yet.)

yashk2810 commented 3 years ago

Sorry about the breakages, but as Peter said, we don't have the capability to test.

But thank you to everyone here who is doing the testing for us. We really appreciate it.

I am working on a fix right now to get this resolved.

yashk2810 commented 3 years ago

So a new jaxlib wheel is ready: https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Can someone please try it out on their Mac M1 and see if it works?

pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Thank you!

ngold5 commented 3 years ago

Still getting the same cyclone error on a fresh install. Note I use the nightly build of scipy.

quattro commented 3 years ago

Thanks for your effort, @yashk2810. Looks like the updated wheel won't install on my M1.

(base) nicholas@atlanta ~ % pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

ERROR: jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl is not a supported wheel on this platform.
yashk2810 commented 3 years ago

Are you running python 3.9 wherever you are installing it?

Also can you try pip install -U pip and then run the install command?
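"Not a supported wheel on this platform" means none of the wheel's tags matched what pip accepts for the running interpreter. pip's real check uses the full list of compatible tags (see `packaging.tags.sys_tags()`). The stdlib-only sketch below only catches the obvious mismatches: wrong CPython version, wrong CPU architecture (e.g. an x86_64 Python under Rosetta), or a macOS release older than the tag requires. The helper names are hypothetical.

```python
import platform
import sys


def parse_wheel_tags(filename):
    """Split a wheel filename into its (python, abi, platform) tags, per PEP 427."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    # name-version[-build]-python-abi-platform
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):
        raise ValueError(f"unexpected wheel filename: {filename}")
    return tuple(parts[-3:])


def rough_mismatches(filename, python_tag, machine, mac_version):
    """List obvious reasons a macOS wheel can't install here (not pip's full check)."""
    wheel_python, _abi, wheel_platform = parse_wheel_tags(filename)
    problems = []
    if wheel_python != python_tag:
        problems.append(f"wheel is for {wheel_python}, interpreter is {python_tag}")
    if wheel_platform.startswith("macosx_"):
        # e.g. "macosx_11_0_arm64" -> "11", "0", "arm64"
        _, major, minor, arch = wheel_platform.split("_", 3)
        if arch not in (machine, "universal2"):
            problems.append(f"wheel is for {arch}, interpreter reports {machine}")
        if not mac_version:
            problems.append("interpreter is not running on macOS")
        else:
            release = [int(p) for p in mac_version.split(".")[:2]]
            release += [0] * (2 - len(release))  # treat "11" as (11, 0)
            if tuple(release) < (int(major), int(minor)):
                problems.append(
                    f"wheel needs macOS {major}.{minor}+, interpreter reports {mac_version}"
                )
    return problems


def current_python_tag():
    """CPython-style tag such as 'cp39' (assumes a CPython interpreter)."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def check_here(filename):
    return rough_mismatches(
        filename, current_python_tag(), platform.machine(), platform.mac_ver()[0]
    )
```

For example, `check_here("jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl")` returns an empty list on a native arm64 Python 3.9. Under Rosetta it should flag the arm64/x86_64 mismatch. An empty list does not guarantee pip will accept the wheel; a very old pip may simply not recognize newer macOS tags.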

quattro commented 3 years ago

Hi @yashk2810, yes, please see below:

(base) nicholas@atlanta ~ % pip install -U pip
Requirement already satisfied: pip in ./miniforge3/lib/python3.9/site-packages (21.3)
(base) nicholas@atlanta ~ % python -V
Python 3.9.5
quattro commented 3 years ago

Oh, weird. Something must have been broken in my terminal. In a fresh zsh session the install worked, but the wheel still throws the same error as before.

(base) nicholas@atlanta ~ % pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Collecting jaxlib==0.1.74
  Downloading https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl (36.9 MB)
     |████████████████████████████████| 36.9 MB 5.8 MB/s             
Requirement already satisfied: numpy>=1.18 in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (1.21.2)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (2.0)
Requirement already satisfied: absl-py in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (0.14.1)
Requirement already satisfied: scipy in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (1.7.1)
Requirement already satisfied: six in ./miniforge3/lib/python3.9/site-packages (from absl-py->jaxlib==0.1.74) (1.16.0)
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.1.70
    Uninstalling jaxlib-0.1.70:
      Successfully uninstalled jaxlib-0.1.70
Successfully installed jaxlib-0.1.74
(base) nicholas@atlanta ~ % ipython
Python 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:24:55) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.24.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
/Users/nicholas/miniforge3/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

In [2]: a = jnp.arange(5)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
zsh: abort      ipython
ngold5 commented 3 years ago

> Are you running python 3.9 wherever you are installing it?
>
> Also can you try pip install -U pip and then run the install command?

Also on 3.9

jotsif commented 3 years ago

That 0.1.74 wheel won't run here either.

Thread 0 Crashed:: Dispatch queue: com.apple.main-thread
0   libsystem_kernel.dylib          0x0000000183cc0e68 __pthread_kill + 8
1   libsystem_pthread.dylib         0x0000000183cf343c pthread_kill + 292
2   libsystem_c.dylib               0x0000000183c3b454 abort + 124
3   xla_extension.so                0x00000001054001cc llvm::report_fatal_error(llvm::Twine const&, bool) + 452
4   xla_extension.so                0x0000000105400008 llvm::report_fatal_error(char const*, bool) + 56
5   xla_extension.so                0x0000000103cba204 llvm::X86Subtarget::initSubtargetFeatures(llvm::StringRef, llvm::StringRef, llvm::StringRef) + 480
6   xla_extension.so                0x0000000103cba3bc llvm::X86Subtarget::X86Subtarget(llvm::Triple const&, llvm::StringRef, llvm::StringRef, llvm::StringRef, llvm::X86TargetMachine const&, llvm::MaybeAlign, unsigned int, unsigned int) + 356
7   xla_extension.so                0x0000000103cbb8c0 llvm::X86TargetMachine::getSubtargetImpl(llvm::Function const&) const + 1184
8   xla_extension.so                0x0000000103cbba74 llvm::X86TargetMachine::getTargetTransformInfo(llvm::Function const&) + 92
9   xla_extension.so                0x0000000104ee9c70 llvm::TargetTransformInfoWrapperPass::getTTI(llvm::Function const&) + 60
yashk2810 commented 2 years ago

I updated the 0.1.74 wheel with a different patch in upstream TF: https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Can someone see if this works?

pip install -U pip
pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Thank you!

quattro commented 2 years ago

Hi @yashk2810, I really appreciate you looking into this. Unfortunately, it seems like the same error crops up:

(base) nicholas@atlanta ~ % pip install -U pip
Requirement already satisfied: pip in ./miniforge3/lib/python3.9/site-packages (21.3)
Collecting pip
  Downloading pip-21.3.1-py3-none-any.whl (1.7 MB)
     |████████████████████████████████| 1.7 MB 2.9 MB/s            
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.3
    Uninstalling pip-21.3:
      Successfully uninstalled pip-21.3
Successfully installed pip-21.3.1
(base) nicholas@atlanta ~ % pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl
Collecting jaxlib==0.1.74
  Downloading https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl (37.0 MB)
     |████████████████████████████████| 37.0 MB 5.7 MB/s             
Requirement already satisfied: scipy in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (1.7.1)
Requirement already satisfied: absl-py in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (0.14.1)
Requirement already satisfied: numpy>=1.18 in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (1.21.2)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (2.0)
Requirement already satisfied: six in ./miniforge3/lib/python3.9/site-packages (from absl-py->jaxlib==0.1.74) (1.16.0)
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.1.70
    Uninstalling jaxlib-0.1.70:
      Successfully uninstalled jaxlib-0.1.70
Successfully installed jaxlib-0.1.74
(base) nicholas@atlanta ~ %  
(base) nicholas@atlanta ~ % ipython
Python 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:24:55) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.24.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
/Users/nicholas/miniforge3/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

In [2]: jnp.arange(5)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
zsh: abort      ipython
dfm commented 2 years ago

I've now gotten the main branch version of jaxlib building on my M1 and cross-compiling on GitHub Actions. It seems to be running fine for me, and you can try it using:

python -m pip install jax jaxlib==0.1.74 -f "https://dfm.io/custom-wheels/jaxlib/index.html"

I'm still horrifyingly patching LLVM, via TensorFlow (here's the diff). It's a bit tricky from the outside to synchronize all the moving parts, but @yashk2810 if you want to chat offline, I might be able to give some tips for getting this to work without spamming this thread that's already pretty noisy. My email is on my GitHub profile and website if you're interested!

quattro commented 2 years ago

@dfm can confirm it works here--fantastic!

ngold5 commented 2 years ago

@dfm it works - fantastic job!