Open cgiovanetti opened 1 month ago
Does pip install diffrax --no-deps also not help?
No--we had to go through the other dependencies one-by-one, but still ended up with the same scipy.linalg error at the end of it.
It looks like the error you're getting here is ultimately coming from having an older version of JAX (which expects scipy.linalg.tril to exist) with a newer version of SciPy (from which scipy.linalg.tril has been removed). A short search turns up https://github.com/google/jax/discussions/18995 for example.
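To see which combination a given machine actually has, it may help to print the installed versions and check whether the attribute is still present. This is only a minimal sketch using the standard library; the helper names installed_version and has_attribute are made up for illustration and are not part of JAX, SciPy, or Diffrax.

```python
import importlib
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def has_attribute(module_name, attr):
    """Return True if the module imports cleanly and exposes the attribute."""
    try:
        module = importlib.import_module(module_name)
    except Exception:
        # Any import-time failure counts as "not available" for this check.
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    for name in ("jax", "jaxlib", "scipy"):
        print(name, installed_version(name))
    print("scipy.linalg.tril present:", has_attribute("scipy.linalg", "tril"))
```

If the last line reports False while the traceback complains about scipy.linalg.tril, that points at the JAX/SciPy pairing rather than at Diffrax.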
As such this isn't a Diffrax thing at all. I think you first need to figure out how to install JAX in whatever is your preferred way for this hardware. That's not something I have any experience with, though perhaps these instructions may help.
Once you have done that, and verified that import jax.scipy.linalg works correctly, then think about installing Diffrax. If need be, that may be via pip install diffrax --no-deps to be sure that it won't adjust your existing JAX installation.
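The verification step could be wrapped in a small script that participants run and paste back. A rough sketch, assuming nothing beyond the standard library; try_import is a hypothetical helper name, and the two module names in the loop are the ones mentioned in this thread.

```python
import importlib


def try_import(module_name):
    """Try to import a module; return (ok, message) instead of raising."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:
        # Covers a missing module as well as errors raised during import.
        return False, f"{type(exc).__name__}: {exc}"
    return True, "ok"


if __name__ == "__main__":
    for name in ("jax", "jax.scipy.linalg"):
        ok, message = try_import(name)
        print(f"{name}: {message}")
```

Running it once after installing JAX and again after installing Diffrax should show whether the second install changed anything.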
(For what it's worth I'm on an M2 and I just pip install jax[cpu] without issue.)
Okay--is diffrax uninstalling/reinstalling JAX because of a version issue then? i.e., the conda installed JAX is too new/old? Or is there some other reason it might not see the conda installation of JAX?
Diffrax doesn't touch your JAX installation at all. pip or conda might, depending on how you use them.
Maybe the sharper question to ask is: if I do not provide the --no-deps flag when pip installing diffrax, what version(s) of JAX must I have for JAX not to be reinstalled during diffrax installation?
One participant was able to import jax successfully between installing JAX and installing diffrax, following the installation instructions I gave them. Unfortunately I can't check back and see if they could specifically import jax.scipy.linalg, though this might be a troubleshooting tip we'll distribute to users.
Anything compatible with the JAX version listed in the pyproject.toml of Diffrax, or the pyproject.toml of its dependencies.
As a practical matter, right now I think that means >=0.4.23,!=0.4.27.
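For a quick sanity check of an installed JAX version against that range, something like the following may do. It is a sketch only: it assumes plain dotted-integer versions (no pre-release or dev suffixes), the bounds are simply the ones quoted above and may go stale, and parse_version and satisfies are illustrative names. The packaging library handles the full version grammar if that matters.

```python
def parse_version(text):
    """Turn '0.4.23' into (0, 4, 23). Assumes plain dotted integers only."""
    return tuple(int(part) for part in text.split("."))


def satisfies(version, minimum="0.4.23", excluded=("0.4.27",)):
    """Check that version >= minimum and version is not one of the excluded ones."""
    parsed = parse_version(version)
    if parsed < parse_version(minimum):
        return False
    return all(parsed != parse_version(bad) for bad in excluded)


if __name__ == "__main__":
    for candidate in ("0.4.22", "0.4.23", "0.4.27", "0.4.28"):
        print(candidate, satisfies(candidate))
```

Comparing tuples rather than strings matters here, since as strings "0.4.9" would sort after "0.4.23".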
Make sure people have a new-ish version of Python and use virtual environments. Just do something like
python3 -m venv env
source env/bin/activate
pip3 install jax jaxlib diffrax
and you'll be fine. You will also not have to worry about which version of whatever else people have on their PCs.
FWIW, pip install does seem to be more reliable than conda - this is also my experience from helping people use diffrax for coursework.
I recently ran a hack session for a code that uses diffrax, but had some trouble getting some participants set up with JAX/diffrax, especially those using M1 or M3 macs (M2 seemed to work fine).
I have found it safest to install JAX using conda when using apple silicon, and so that's the first step I suggested. However, there doesn't seem to be a conda installation of diffrax available, so we needed to pip install diffrax. But diffrax would then uninstall the version of JAX we installed with conda, and reinstall a different version with pip. This caused participants with this hardware to get JAX-related errors when trying to run JAX code ("This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source."--this is usually the issue I work around by installing JAX with conda). Some participants were able to then uninstall the version of JAX installed by diffrax, and reinstall it again with conda. Others tried this and got errors in diffrax after doing so (partial stack trace included below).
I only have access to an M2 mac and so it's difficult for me to replicate the issue. In fact, it's difficult to replicate the issue at all, because participants had varying levels of success even on similar hardware (we tried many different JAX/python versions for these participants--on M1, downgrading to python 3.11 seemed to help, but did not consistently help on M3).
If it's possible to a) not have diffrax install its own JAX with pip, and/or b) install diffrax with conda, and/or c) provide installation best practices for diffrax with apple silicon, I'd hope that would help with some of these headaches down the line.
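Since results varied across machines that look similar on paper, it might be worth collecting a few facts about each participant's interpreter alongside the error. The sketch below is an assumption about what is useful to record, not something established in this report. In particular, a Python on apple silicon that reports x86_64 is running under Rosetta translation, which is one possible (unconfirmed here) reason for the AVX message. describe_interpreter is a hypothetical helper name.

```python
import platform
import sys


def describe_interpreter():
    """Collect facts about the running interpreter that affect which wheels pip picks."""
    return {
        "python": platform.python_version(),
        "system": platform.system(),    # e.g. "Darwin" on macOS
        "machine": platform.machine(),  # e.g. "arm64" or "x86_64"
        "executable": sys.executable,   # shows whether this is the conda or system Python
    }


if __name__ == "__main__":
    for key, value in describe_interpreter().items():
        print(f"{key}: {value}")
```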
With Python 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ] on darwin, one participant with an M3 got the following error after reinstalling JAX with conda: