talmo closed this 1 year ago.
Does this work with M1 CPUs? I had originally suggested a conda-based installation because installing jax directly via conda-forge had been posted as one solution, but this PR seems to use pip install for Mac.
Yes, I tested it on an M1. The pip wheel works fine for CPU, but Metal/Neural Engine acceleration on the same chip is blocked by the lack of JAX support (https://github.com/google/jax/issues/8074).
I haven't tested the conda-forge version, but for all intents and purposes they should both work here for CPU-only mode.
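To confirm the CPU-only setup described above, a small check like the following reports the host architecture and, if JAX is importable, its default backend. It relies only on `platform.machine()`, `jax.default_backend()` and `jax.devices()`; the helper name `jax_runtime_summary` is hypothetical. On a native Apple Silicon Python the machine is expected to be `arm64`, and with the CPU-only wheel the backend should be `cpu`.

```python
import platform


def jax_runtime_summary() -> dict:
    """Report host architecture and, if JAX imports, its default backend (hypothetical helper)."""
    summary = {
        "system": platform.system(),    # e.g. "Darwin" on macOS
        "machine": platform.machine(),  # "arm64" for native Apple Silicon Python, "x86_64" otherwise
        "jax_backend": None,
        "jax_devices": [],
    }
    try:
        import jax
    except ImportError:
        return summary  # JAX not installed in this environment
    summary["jax_backend"] = jax.default_backend()  # "cpu" when only the CPU backend is available
    summary["jax_devices"] = [str(d) for d in jax.devices()]
    return summary


if __name__ == "__main__":
    for key, value in jax_runtime_summary().items():
        print(f"{key}: {value}")
```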
Just tested this with conda on my M1: installing with `conda env create -f conda_envs/environment.mac_cpu.yml` gave the same AVX error as before.
I had to move the jax and jaxlib lines out of the pip section and up into the conda section of the YAML, like so:
```yaml
dependencies:
  - python=3.9
  - pytables
  - openblas
  - jax==0.3.22
  - jaxlib==0.3.22
  - numpy=1.23.5
  - pip
  - pip:
    - "jax-moseq==0.0.1"
    - "--editable=../"
    - jupyterlab
```
With that change, installing with the same command works fine.
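The thread doesn't say what triggered the AVX error. One possibility, offered here only as an assumption, is that the environment was built from x86_64 packages running under Rosetta 2, since x86_64 jaxlib builds expect AVX instructions. A minimal check uses the macOS `sysctl.proc_translated` key, which reads 1 for a translated process and 0 for a native one; `running_under_rosetta` is a hypothetical helper.

```python
import platform
import subprocess


def running_under_rosetta() -> bool:
    """Return True if this Python process is an x86_64 binary translated by Rosetta 2 (macOS only)."""
    if platform.system() != "Darwin":
        return False  # Rosetta 2 only exists on macOS
    try:
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return False  # sysctl missing or key absent (e.g. on an Intel Mac)
    return result.stdout.strip() == "1"


if __name__ == "__main__":
    print("arch:", platform.machine(), "| translated by Rosetta:", running_under_rosetta())
```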
I can also confirm that a pip install of the CPU version, entirely independent of conda (in a venv), also works (option 1 in the README).
Note, however, that option 1 ends up with the most recent versions of jax and jaxlib (0.4.8), among other package version differences. Maybe consider either pinning the packages in the pip install path or unpinning the conda ones, so people don't end up with two different versions of the underlying libraries?
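To spot the kind of drift described above, a short sketch can compare installed versions against the pins from the conda YAML using `importlib.metadata.version`. The pin dictionary is copied from the YAML snippet; the helper `find_pin_mismatches` is hypothetical, and its `lookup` parameter exists only so the check can be exercised without the packages installed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional, Tuple

# Pins copied from the conda environment snippet in this thread.
CONDA_PINS = {"jax": "0.3.22", "jaxlib": "0.3.22", "numpy": "1.23.5"}


def find_pin_mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], str] = version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (pinned, installed)} for every package whose installed
    version differs from its pin; installed is None if the package is missing."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, pinned in pins.items():
        try:
            installed: Optional[str] = lookup(name)
        except PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    for name, (pinned, installed) in find_pin_mismatches(CONDA_PINS).items():
        print(f"{name}: pinned {pinned}, installed {installed or 'not installed'}")
```

Run in each environment (conda and venv), an empty result means both ended up with the same pinned jax, jaxlib and numpy.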
Adds a conda environment for installing CPU-only JAX, along with associated instructions.