jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Provide wheels for macOS ARM #5501

Closed ericmjl closed 2 years ago

ericmjl commented 3 years ago

Hi all,

I was digging around to see what might need to happen to allow JAX to work on Apple Silicon. Knowing that JAX gets compiled to XLA, my guess here is that XLA would need to be made Apple Silicon-compatible first before JAX could run on it. May I ask, do you all know if there are plans on the XLA team to make that happen, or is it being ignored completely? (Knowing the answer can help me make some decisions on how I should set up my development environment mostly.)

Cheers, Eric

smao-astro commented 2 years ago

Hello @dfm, this is really cool! I am a JAX user considering switching from an Intel Mac to the new M1 Pro machine. I have tried to follow the thread here, but I still cannot decide because I do not know much about compiling JAX or about the hardware. Would you kindly explain a bit more about:

  1. Can you build and run JAX on M1 using only the CPU, or both the CPU and GPU?
  2. If it only runs on the CPU, do you think we will have to wait long for GPU support? (I won't run heavy experiments on a laptop; it's just for testing and debugging.)

Thanks!

dfm commented 2 years ago

@smao-astro: This is a CPU-only build, and I don't know of plans to get XLA to support Apple GPUs, but I don't follow this very closely.
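To check what an installed JAX actually runs on, you can ask it for its default backend and devices. A minimal sketch, assuming a JAX version where `jax.default_backend()` and `jax.devices()` are available; the helper name `jax_backend_summary` is hypothetical. On a CPU-only build like this one, expect only `cpu` devices.

```python
import importlib.util


def jax_backend_summary():
    """Report JAX's default backend and device platforms, or None if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return {
        # 'cpu' for a CPU-only jaxlib such as the macOS arm64 build discussed here.
        "default_backend": jax.default_backend(),
        # One entry per distinct platform, e.g. ['cpu'].
        "device_platforms": sorted({d.platform for d in jax.devices()}),
    }


print(jax_backend_summary())
```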

yashk2810 commented 2 years ago

@dfm upstreamed his patch to TF, and I built the macOS arm64 wheel using his patch: https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

pip install -U pip
pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Can someone see if this wheel works?

Thank you again!
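The wheel filename itself encodes which interpreters and platforms it targets. A minimal sketch of splitting it into its tags, following the PEP 427 naming convention `{distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl`; this is an illustration rather than pip's own parser, and `parse_wheel_filename` is a hypothetical name.

```python
from typing import NamedTuple, Optional


class WheelName(NamedTuple):
    """Components of a wheel filename, following the PEP 427 naming convention."""
    distribution: str
    version: str
    build: Optional[str]
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelName:
    """Split '{dist}-{version}(-{build})?-{python}-{abi}-{platform}.whl' into parts."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename!r}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        dist, version, py, abi, plat = parts
        build = None
    elif len(parts) == 6:
        dist, version, build, py, abi, plat = parts
    else:
        raise ValueError(f"unexpected number of components in {filename!r}")
    return WheelName(dist, version, build, py, abi, plat)


name = parse_wheel_filename("jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl")
print(name.python_tag, name.abi_tag, name.platform_tag)
# cp39 none macosx_11_0_arm64
```

For this file the tags are `cp39` (CPython 3.9), `none` (no specific ABI) and `macosx_11_0_arm64` (macOS 11 or later on ARM).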

dfm commented 2 years ago

Looks to be working over here. Thanks @yashk2810!

Madder commented 2 years ago

@yashk2810 @dfm GPU acceleration seems to be officially supported on M1 macOS through the TensorFlow Metal plugin (https://developer.apple.com/metal/tensorflow-plugin/). Does that mean that the JAX macOS ARM wheels support the GPU (either currently or in the future)?

yashk2810 commented 2 years ago

It doesn't support GPU right now. There are no plans currently to support the GPU in the future.

yashk2810 commented 2 years ago

Looks to be working over here

Awesome. Thank you @dfm for upstreaming your patch so that we could build jaxlib!

I am going to close this issue as it has been fixed, and in the next release we will have official jaxlib macOS arm64 wheels on PyPI.

For now, you can install via:

pip install -U pip
pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

8bitmp3 commented 2 years ago

nice work @dfm @yashk2810

facero commented 2 years ago

Hi all,

On an M1 chip with Python 3.9.5, I'm trying to install JAX. I followed the two options presented here (by @dfm and @yashk2810), but for each I get this error:

(py395) mymachine:~ lg$ pip install -U https://dfm.io/custom-wheels/jaxlib/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

ERROR: jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl is not a supported wheel on this platform.

which I don't understand, as I have:

Python 3.9.5 (default, May 18 2021, 12:31:01)
[Clang 10.0.0 ] :: Anaconda, Inc. on darwin

and macOS Big Sur 11.2.3.

@quattro, did you install Python 3.9 via conda? Mine says "[Clang 10.0.0 ] :: Anaconda, Inc. on darwin". Could it be that I don't have an ARM Python build (if one exists), and that's why it doesn't find the wheel?

Thx for your help

hawkinsp commented 2 years ago

What does

import platform
print(platform.machine())

print?
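The architecture of the interpreter, not of the machine, decides which wheels pip accepts: an Intel Python running under Rosetta 2 reports `x86_64` even on an M1. A minimal sketch that gathers the relevant facts with only the standard library (`platform`, `sysconfig`, `sys`); the helper names are hypothetical.

```python
import platform
import sys
import sysconfig


def interpreter_arch_report() -> dict:
    """Collect the facts that decide whether an arm64 wheel can be installed."""
    return {
        # 'arm64' for a native Apple Silicon process, 'x86_64' for an Intel
        # build (including one running under Rosetta 2).
        "machine": platform.machine(),
        # Platform the interpreter was built for, e.g. 'macosx-11.0-arm64'
        # or 'macosx-10.9-universal2'.
        "build_platform": sysconfig.get_platform(),
        "python": f"{sys.version_info.major}.{sys.version_info.minor}",
        "implementation": platform.python_implementation(),
    }


def looks_native_apple_silicon(report: dict) -> bool:
    """True if the interpreter is a macOS build running natively on arm64."""
    return report["build_platform"].startswith("macosx-") and report["machine"] == "arm64"


report = interpreter_arch_report()
print(report)
print("native arm64:", looks_native_apple_silicon(report))
```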

xhochy commented 2 years ago

Python 3.9.5 (default, May 18 2021, 12:31:01) [Clang 10.0.0 ] :: Anaconda, Inc. on darwin

Anaconda doesn't yet provide packages for M1/ARM, so this is definitely an x86 build.

Only conda-forge provides ARM Python builds through conda/mamba, and there the prompt looks like this:

Python 3.9.2 | packaged by conda-forge | (default, Feb 21 2021, 05:00:30)
[Clang 11.0.1 ] on darwin

The prompt looks the same for ARM and x86 builds, but as you can see at https://anaconda.org/main/python, Anaconda does not provide an osx-arm64 build yet.

facero commented 2 years ago

What does

import platform
print(platform.machine())

Indeed, this seems to be the issue: it prints x86_64, so conda installed an x86 Python.

So what are my options for installing an ARM Python: conda-forge? Do I need to specify a version or architecture? Also, does that mean all packages in that environment (e.g. numpy, scipy, matplotlib) will need to be ARM-compiled versions?

ngold5 commented 2 years ago

@facero I had success using miniforge: https://github.com/conda-forge/miniforge

You can install it via Homebrew if that is your package manager of choice. After that, @yashk2810's last reply should do the trick.

NightMachinery commented 2 years ago

It doesn't support GPU right now. There are no plans currently to support the GPU in the future.

Can you open an issue for GPU/NE support on Apple ARM so that we can receive the updates easily? (I know there are no plans to support it, but it might still happen one day. And I am sure a lot of people would like to receive some notification if that happened.)

hawkinsp commented 2 years ago

@NightMachinery I think you're looking for https://github.com/google/jax/issues/8074 for Apple GPU support via Metal.

We would need assistance from Apple to support the Neural Engine. At least the last time I checked, it did not provide APIs usable by JAX.

nashmathur commented 2 years ago

I am trying to pip install jaxlib on the latest M1 MacBook Pro (2021) with macOS Monterey 12.0.1. However, I get this error:

ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
ERROR: No matching distribution found for jaxlib

(Please forgive me if I'm doing something stupid since I have migrated to macOS recently)

yashk2810 commented 2 years ago

Try updating your pip version? (pip install -U pip)

nashmathur commented 2 years ago

Yes, it is already up-to-date (21.3.1).

MikeInnes commented 2 years ago

You might be on Python 3.8 (the version that currently comes with Anaconda). If you switch to 3.9 (e.g. via Homebrew's Python), it should work.

nashmathur commented 2 years ago

Amazing! Yes, it was Python 3.8 which was the problem. I upgraded to 3.10 and it works! Thanks a lot.

yiyaz commented 2 years ago

I have tried the solutions proposed by dfm and yashk2810 with no success. dfm's approach yields:

ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.74 (from versions: 0.1.75)
ERROR: No matching distribution found for jaxlib==0.1.74

while yashk2810's approach gives me:

ERROR: jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl is not a supported wheel on this platform.

Updating pip to the latest version and Python to 3.10.2 did not help either.
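The second error has a mechanical explanation: that wheel is tagged `cp39`, so it targets CPython 3.9 only, and a Python 3.10.2 interpreter (tag `cp310`) rejects it no matter how recent pip is. A simplified sketch of that comparison; it is not pip's actual implementation (which is built on the `packaging.tags` module), and the helper names are hypothetical.

```python
import platform
import sys


def cpython_tag(major: int, minor: int) -> str:
    """CPython interpreter tag, e.g. (3, 10) -> 'cp310'."""
    return f"cp{major}{minor}"


def simple_wheel_check(python_tag: str, platform_tag: str, version=None, machine=None) -> bool:
    """Simplified check for a CPython-specific wheel with ABI tag 'none'.

    Compares only the exact interpreter tag and the architecture suffix of the
    platform tag; pip's real logic also considers abi3, universal2 and the
    minimum macOS version.
    """
    if version is None:
        version = sys.version_info[:2]
    if machine is None:
        machine = platform.machine()
    return python_tag == cpython_tag(*version) and platform_tag.endswith("_" + machine)


# Tags taken from jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl
print(simple_wheel_check("cp39", "macosx_11_0_arm64", version=(3, 10), machine="arm64"))  # False
print(simple_wheel_check("cp39", "macosx_11_0_arm64", version=(3, 9), machine="arm64"))   # True
print(simple_wheel_check("cp39", "macosx_11_0_arm64", version=(3, 9), machine="x86_64"))  # False
```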

phinate commented 2 years ago

You should be able to just pip install jax jaxlib now. As the first error states, jaxlib==0.1.75 is now provided on PyPI for ARM.

From yashk:

I am going to close this issue as it has been fixed, and in the next release we will have official jaxlib macOS arm64 wheels on PyPI.
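After `pip install jax jaxlib`, a quick way to confirm which versions pip actually resolved (for example, that you got the PyPI jaxlib rather than an older custom wheel) is to read the installed distribution metadata. A small sketch using the standard library's `importlib.metadata` (Python 3.8+); the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


for dist in ("jax", "jaxlib"):
    print(dist, installed_version(dist) or "not installed")
```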

nlp4whp commented 2 years ago

You should be able to just pip install jax jaxlib now. As the first error states, jaxlib==0.1.75 is now provided on PyPI for ARM.

From yashk:

I am going to close this issue as it has been fixed, and in the next release we will have official jaxlib macOS arm64 wheels on PyPI.

Thanks! Will it be supported for Python 3.8 on macOS arm64?

hawkinsp commented 2 years ago

We don't provide 3.8 wheels for mac arm64, only 3.9 and 3.10. (I think originally we were under the impression that 3.8 was never released for Mac ARM, although I guess that's not true.) I guess we could, though.
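This answer can be turned into a quick pre-install check that explains the earlier "No matching distribution found for jaxlib" error on Python 3.8. A minimal sketch: the version set below only encodes the 3.9/3.10 situation described in this comment and is an assumption for any later jaxlib release; `MAC_ARM64_WHEEL_PYTHONS` and `expect_mac_arm64_wheel` are hypothetical names.

```python
import platform
import sys

# Snapshot of the comment above: macOS arm64 jaxlib wheels existed only for
# CPython 3.9 and 3.10. This is an assumption; later releases differ.
MAC_ARM64_WHEEL_PYTHONS = {(3, 9), (3, 10)}


def expect_mac_arm64_wheel(version=None, system=None, machine=None) -> bool:
    """Guess whether pip would find a jaxlib macOS arm64 wheel for this interpreter.

    Assumes CPython and ignores the minimum macOS version.
    """
    version = tuple(version) if version is not None else tuple(sys.version_info[:2])
    system = system if system is not None else platform.system()
    machine = machine if machine is not None else platform.machine()
    return system == "Darwin" and machine == "arm64" and version in MAC_ARM64_WHEEL_PYTHONS


print(expect_mac_arm64_wheel((3, 8), "Darwin", "arm64"))   # False: no 3.8 wheel
print(expect_mac_arm64_wheel((3, 10), "Darwin", "arm64"))  # True
print(expect_mac_arm64_wheel())                            # depends on this machine
```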

dwyatte commented 2 years ago

@hawkinsp please see https://github.com/google/jax/issues/9065 for a specific request for Python 3.8 macOS arm64 wheels, if it's not a ton of effort for maintainers.

nickreich commented 2 years ago

I had the same problem. Since I already had Anaconda installed and didn't want to clutter my machine with Anaconda + Miniconda + Homebrew and whatever else, what worked for me was installing jax and jaxlib directly from conda-forge:

conda install -c conda-forge jaxlib
conda install -c conda-forge jax

gabrieldernbach commented 2 years ago

The conda packages from -c conda-forge did not work for me, and pip fails when installing the scipy dependency. In the end, this was the shortest way I could find:

conda create --name venv python=3.10
conda activate venv
conda install -y scipy
pip install jax jaxlib

xhochy commented 2 years ago

@gabrieldernbach Can you tell me what didn't work with the conda packages? New versions that should work have been up since yesterday.

gabrielhuang commented 1 year ago

By far the easiest solution for me was to reinstall Python using the latest arm64 version of Miniforge (not Miniconda), then pip install jaxlib jax.

ddrous commented 1 year ago

By far the easiest solution for me was to re-install Python using the latest arm64 version of Miniforge (not Miniconda) then pip install jaxlib jax

I've tried this, and it works!! This video might be useful for people already using miniconda, that want miniforge on the side or as default.

timhdesilva commented 1 year ago

Hi all,

I just upgraded from an Intel Mac to an M2 Mac and read this thread. What is the best way for me to install JAX on the M2? Is it possible to build JAX to work with the Apple GPU in addition to the CPU?

Thanks and apologies in advance if I missed this in the discussion above!

kechan commented 1 year ago

I had the same problem. Since I already had Anaconda installed and didn't want to clutter my machine with Anaconda + Miniconda + Homebrew and whatever else, what worked for me was installing jax and jaxlib directly from conda-forge:

conda install -c conda-forge jaxlib
conda install -c conda-forge jax

Is this CPU only? How about Apple GPU cores, or maybe even the Neural Engines?