patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Difficulty installing diffrax on apple silicon #429

Open · cgiovanetti opened this issue 1 month ago

cgiovanetti commented 1 month ago

I recently ran a hack session for code that uses diffrax, but had some trouble getting participants set up with JAX/diffrax, especially those using M1 or M3 Macs (M2 seemed to work fine).

I have found it safest to install JAX using conda on Apple silicon, so that was the first step I suggested. However, there doesn't seem to be a conda package for diffrax available, so we had to pip install diffrax. Installing diffrax with pip then uninstalled the conda-installed JAX and reinstalled a different version with pip. Participants with this hardware then got JAX-related errors when trying to run JAX code: "This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source." (This is the error I usually work around by installing JAX with conda.) Some participants were then able to uninstall the pip-installed JAX and reinstall it with conda. Others tried this and got errors in diffrax afterwards (partial stack trace included below).
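One common cause of that AVX error on Apple silicon (an assumption here; the thread doesn't confirm it) is an x86_64 build of Python, such as one from an Intel installer, running under Rosetta 2, which makes pip select x86_64 jaxlib wheels. A minimal standard-library sketch of a check; the function names are hypothetical, and `sysctl.proc_translated` is the macOS sysctl key that reports whether a process is being translated by Rosetta:

```python
import platform
import subprocess


def running_under_rosetta() -> bool:
    """Return True if this Python process is being translated by Rosetta 2 (macOS only)."""
    if platform.system() != "Darwin":
        return False
    try:
        # Prints "1" for a translated (Rosetta) process and "0" for a native one;
        # the key does not exist on Intel Macs, so stdout is empty there.
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError:
        return False
    return result.stdout.strip() == "1"


def python_arch_report() -> str:
    """Summarise which CPU architecture this interpreter was built for."""
    # "arm64" for a native Apple silicon Python, "x86_64" for an Intel build.
    machine = platform.machine()
    if running_under_rosetta():
        return (
            f"{machine} Python under Rosetta 2: pip will fetch x86_64 wheels; "
            "consider a native arm64 Python instead"
        )
    return f"{platform.system()}/{machine} Python, not translated by Rosetta 2"


if __name__ == "__main__":
    print(python_arch_report())
```

Running this inside each participant's environment before installing anything would show whether they are on a native interpreter.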

I only have access to an M2 Mac, so it's difficult for me to replicate the issue. In fact, it's difficult to replicate at all, because participants had varying levels of success even on similar hardware. We tried many different JAX and Python versions for these participants; on M1, downgrading to Python 3.11 seemed to help, but it did not consistently help on M3. I think any of the following would help with some of these headaches down the line: (a) not having diffrax install its own JAX with pip, (b) making diffrax installable with conda, and/or (c) documenting installation best practices for diffrax on Apple silicon.


With Python 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ] on darwin, one participant on an M3 Mac got the following error after reinstalling JAX with conda:

```
drive-download-20240523T232101Z-001 python3 test_SBBN.py
Traceback (most recent call last):
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/diffrax/_root_finder/_verychord.py", line 83, in init
    init_state = options["init_state"]
                 ~~~~~~~^^^^^^^^^^^^^^
KeyError: 'init_state'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/annikapeter/Dropbox/LINX/drive-download-20240523T232101Z-001/test_SBBN.py", line 51, in <module>
    vJax_res_raw = abundance_model_PRIMAT_2022(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/equinox/_module.py", line 1189, in __call__
    return self.func(*self.args, *args, **kwargs, **self.keywords)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/equinox/_jit.py", line 200, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/annikapeter/Dropbox/LINX/drive-download-20240523T232101Z-001/LINX/abundances.py", line 280, in __call__
    sol = diffeqsolve(
          ^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/diffrax/_integrate.py", line 916, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
                             ^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/diffrax/_adjoint.py", line 288, in loop
    final_state = self._loop(
                  ^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/diffrax/_integrate.py", line 439, in loop
    _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/diffrax/_integrate.py", line 240, in body_fun_aux
    (y, y_error, dense_info, solver_state, solver_result) = solver.step(
                                                            ^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/diffrax/_solver/runge_kutta.py", line 1099, in step
    jac_f = self.root_finder.init(  # pyright: ignore
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/diffrax/_root_finder/_verychord.py", line 87, in init
    init_later_state = self.linear_solver.init(jac, options={})
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/lineax/_solve.py", line 627, in init
    return token, _lookup(token).init(operator, options)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/lineax/_solver/lu.py", line 53, in init
    lu = jsp.linalg.lu_factor(operator.as_matrix())
         ^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/jax/_src/lazy_loader.py", line 39, in __getattr__
    return importlib.import_module(f"{package_name}.{name}")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/linx/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/jax/scipy/linalg.py", line 18, in <module>
    from jax._src.scipy.linalg import (
  File "/opt/anaconda3/envs/linx/lib/python3.11/site-packages/jax/_src/scipy/linalg.py", line 403, in <module>
    @_wraps(scipy.linalg.tril)
            ^^^^^^^^^^^^^^^^^
AttributeError: module 'scipy.linalg' has no attribute 'tril'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```

AbhinavMir commented 1 month ago

Does `pip install diffrax --no-deps` also not help?

cgiovanetti commented 1 month ago

No. We had to go through the other dependencies one by one, but still ended up with the same `scipy.linalg` error at the end of it.

patrick-kidger commented 1 month ago

It looks like the error you're getting here ultimately comes from combining an older version of JAX (which expects `scipy.linalg.tril` to exist) with a newer version of SciPy (from which `scipy.linalg.tril` has been removed). A quick search turns up https://github.com/google/jax/discussions/18995, for example.
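To see which side of that mismatch an environment is on, a small sketch like the following (standard library plus an optional SciPy import; the helper names are hypothetical) reports the installed versions and whether `scipy.linalg.tril` still exists:

```python
from importlib import metadata
from typing import Optional


def installed_versions(names=("jax", "jaxlib", "scipy")) -> dict:
    """Map each distribution name to its installed version, or None if not installed."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def scipy_has_tril() -> Optional[bool]:
    """True/False depending on whether scipy.linalg.tril exists; None if SciPy is missing."""
    try:
        import scipy.linalg
    except ImportError:
        return None
    return hasattr(scipy.linalg, "tril")


if __name__ == "__main__":
    print(installed_versions())
    # False here, together with an older jax, is consistent with the AttributeError in this issue.
    print("scipy.linalg.tril present:", scipy_has_tril())
```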

As such, this isn't a Diffrax issue at all. I think you first need to figure out how to install JAX on this hardware in whatever way you prefer. That's not something I have any experience with, though perhaps these instructions may help.

Once you have done that and verified that `import jax.scipy.linalg` works correctly, then think about installing Diffrax. If need be, you can do that via `pip install diffrax --no-deps`, to be sure that it won't modify your existing JAX installation.
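That verification step can be scripted. A minimal sketch, assuming JAX is already installed: it imports `jax.scipy.linalg` and calls `lu_factor` (the function lineax reached in the traceback above) on a tiny matrix. It catches any exception, because the failure in this issue surfaced as an `AttributeError` rather than an `ImportError`. The function name is hypothetical:

```python
def check_jax_linalg() -> tuple:
    """Return (ok, message) after trying to import and use jax.scipy.linalg."""
    try:
        import jax.numpy as jnp
        import jax.scipy.linalg as jsl

        # A 2x2 LU factorisation exercises the same module that failed to import above.
        lu, piv = jsl.lu_factor(jnp.array([[4.0, 3.0], [6.0, 3.0]]))
    except Exception as exc:  # e.g. AttributeError: module 'scipy.linalg' has no attribute 'tril'
        return False, f"{type(exc).__name__}: {exc}"
    return True, f"jax.scipy.linalg OK (LU shape {lu.shape}, pivots {piv.tolist()})"


if __name__ == "__main__":
    ok, message = check_jax_linalg()
    print("OK:" if ok else "FAILED:", message)
```

Only once this reports OK would it make sense to move on to installing Diffrax.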

(For what it's worth, I'm on an M2 and just `pip install jax[cpu]` without issue.)

cgiovanetti commented 1 month ago

Okay. Is diffrax uninstalling and reinstalling JAX because of a version issue, then, i.e. is the conda-installed JAX too new or too old? Or is there some other reason it might not see the conda installation of JAX?

patrick-kidger commented 1 month ago

Diffrax doesn't touch your JAX installation at all.

pip or conda might, depending on how you use them.
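One way to see which tool last installed the JAX that Python actually imports is the `INSTALLER` file that pip and other installers write into a package's `.dist-info` metadata. A minimal sketch using `importlib.metadata`; the function name is hypothetical, and conda-built packages usually, but not always, record `conda` there:

```python
from importlib import metadata
from typing import Optional


def jax_install_info() -> Optional[dict]:
    """Describe the installed `jax` distribution, or return None if it isn't installed."""
    try:
        dist = metadata.distribution("jax")
    except metadata.PackageNotFoundError:
        return None
    # INSTALLER is optional metadata; fall back to "unknown" when it is absent.
    installer = (dist.read_text("INSTALLER") or "unknown").strip()
    return {
        "version": dist.version,
        "installer": installer,  # e.g. "pip"
        "location": str(dist.locate_file("")),  # the site-packages directory
    }


if __name__ == "__main__":
    print(jax_install_info())
```

Comparing the output before and after `pip install diffrax` would show whether pip replaced the conda-installed JAX.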

cgiovanetti commented 1 month ago

Maybe the sharper question is: if I don't pass the `--no-deps` flag when pip-installing diffrax, which version(s) of JAX must I already have for JAX not to be reinstalled during the diffrax installation?

With the installation instructions I gave, one participant successfully ran `import jax` after installing JAX and before installing diffrax. Unfortunately, I can't check back to see whether they could specifically `import jax.scipy.linalg`, though that might be a troubleshooting tip we'll distribute to users.

patrick-kidger commented 1 month ago

Anything compatible with the JAX version constraints listed in Diffrax's pyproject.toml and in the pyproject.toml files of its dependencies.
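Rather than reading each pyproject.toml by hand, the same constraints can be read from the installed packages' metadata. A sketch using `importlib.metadata.requires`; the package list (diffrax plus equinox and lineax, which appear in the traceback above) is an assumption and may not cover every dependency, and the helper names are hypothetical:

```python
import re
from importlib import metadata

# Assumed subset of packages whose JAX pins matter; extend as needed.
PACKAGES = ("diffrax", "equinox", "lineax")


def is_jax_requirement(requirement: str) -> bool:
    """True if a requirement string such as 'jax>=0.4.23' names jax or jaxlib."""
    # The distribution name is the leading run of name characters.
    match = re.match(r"[A-Za-z0-9._-]+", requirement.strip())
    return bool(match) and match.group(0).lower() in ("jax", "jaxlib")


def jax_requirements(packages=PACKAGES) -> dict:
    """Map each installed package to the jax/jaxlib requirements it declares."""
    found = {}
    for pkg in packages:
        try:
            requirements = metadata.requires(pkg) or []
        except metadata.PackageNotFoundError:
            continue  # not installed, so it imposes no constraint here
        found[pkg] = [r for r in requirements if is_jax_requirement(r)]
    return found


if __name__ == "__main__":
    for pkg, reqs in jax_requirements().items():
        print(pkg, reqs)
```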

As a practical matter, right now I think that means `>=0.4.23,!=0.4.27`.
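By default, pip leaves an already-installed dependency alone when its version satisfies the requirement, so the question reduces to a version check. A simplified, standard-library-only sketch for this particular constraint: it handles plain release versions such as `0.4.28` only (real tools use the `packaging` library, which also understands pre-releases and other PEP 440 forms), and the constraint is the one quoted here, which may have changed since. The helper names are hypothetical:

```python
def parse_release(version: str) -> tuple:
    """Parse a plain release version like '0.4.23' into a comparable tuple of ints.

    Trailing zeros are dropped so that '0.4.27.0' compares equal to '0.4.27',
    as PEP 440 requires. Raises ValueError for pre-release or local versions.
    """
    parts = [int(part) for part in version.split(".")]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


# The constraint quoted in this thread: >=0.4.23,!=0.4.27
MINIMUM = parse_release("0.4.23")
EXCLUDED = {parse_release("0.4.27")}


def satisfies_jax_constraint(version: str) -> bool:
    """True if `version` satisfies >=0.4.23,!=0.4.27."""
    release = parse_release(version)
    return release >= MINIMUM and release not in EXCLUDED


if __name__ == "__main__":
    from importlib import metadata

    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        print("jax is not installed")
    else:
        ok = satisfies_jax_constraint(installed)
        verdict = "should be kept" if ok else "pip would try to install a different version"
        print(f"jax {installed}: {verdict}")
```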

johannahaffner commented 3 weeks ago

Make sure people have a reasonably recent version of Python and use virtual environments. Just do something like:

```shell
python3 -m venv env
source env/bin/activate
pip3 install jax jaxlib diffrax
```

and you'll be fine. You also won't have to worry about which versions of other packages people already have installed on their machines.
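To confirm that the environment created by those commands is actually the one in use (for example, in a new terminal where `source env/bin/activate` was forgotten), a small check like this may help. It relies on `sys.prefix` differing from `sys.base_prefix` inside a venv; note that it does not detect conda environments, where the two are normally equal. The function name is hypothetical:

```python
import sys


def in_virtualenv() -> bool:
    """True when running inside a venv created with `python3 -m venv`."""
    # Inside a venv, sys.prefix points at the environment directory, while
    # sys.base_prefix still points at the Python installation it was created from.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("interpreter:", sys.executable)
    print("inside a venv:", in_virtualenv())
```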

FWIW, `pip install` does seem to be more reliable than conda; that has also been my experience helping people use diffrax for coursework.