jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

ImportError: cannot import name 'index' from 'jax.ops' #10293

Closed: Tusay closed this issue 2 years ago.

Tusay commented 2 years ago

I'm trying to work through the tutorial here: http://secondearths.sakura.ne.jp/exojax/tutorials/optimize_spectrum_JAXopt.html

When I get to the following block of code, I get an error. I'm using a GPU on Google Colab and have confirmed that it's running.

```python
from exojax.spec.lpf import xsmatrix
from exojax.spec.exomol import gamma_exomol
from exojax.spec.hitran import SijT, doppler_sigma, gamma_natural, gamma_hitran
from exojax.spec.hitrancia import read_cia, logacia
from exojax.spec.rtransfer import rtrun, dtauM, dtauCIA, nugrid
from exojax.spec import planck, response
from exojax.spec.lpf import xsvector
from exojax.spec import molinfo
from exojax.utils.constants import RJ, pc, Rs, c
```

`pip show jax` outputs:

```
Name: jax
Version: 0.3.4
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.7/dist-packages
Requires: typing-extensions, scipy, absl-py, opt-einsum, numpy
Required-by: numpyro, jaxopt, exojax
```

Full error messages/tracebacks:

```
ImportError                               Traceback (most recent call last)
in <module>()
      4 from exojax.spec.hitrancia import read_cia, logacia
      5 from exojax.spec.rtransfer import rtrun, dtauM, dtauCIA, nugrid
----> 6 from exojax.spec import planck, response
      7 from exojax.spec.lpf import xsvector
      8 from exojax.spec import molinfo

/usr/local/lib/python3.7/dist-packages/exojax/spec/__init__.py in <module>()
     15 )
     16 
---> 17 from exojax.spec.autospec import (
     18     AutoXS,
     19     AutoRT,

/usr/local/lib/python3.7/dist-packages/exojax/spec/autospec.py in <module>()
      1 """Automatic Opacity and Spectrum Generator."""
      2 import time
----> 3 from exojax.spec import defmol, defcia, moldb, contdb, planck, molinfo, lpf, dit, modit, initspec, response
      4 from exojax.spec.opacity import xsection
      5 from exojax.spec.hitran import SijT, doppler_sigma, gamma_natural, gamma_hitran, normalized_doppler_sigma

/usr/local/lib/python3.7/dist-packages/exojax/spec/dit.py in <module>()
     10 from jax.lax import scan
     11 from exojax.spec.ditkernel import fold_voigt_kernel
---> 12 from jax.ops import index as joi
     13 from exojax.spec.atomll import padding_2Darray_for_each_atom
     14 from exojax.spec.rtransfer import dtauM

ImportError: cannot import name 'index' from 'jax.ops' (/usr/local/lib/python3.7/dist-packages/jax/ops/__init__.py)
```

I'm not sure how to resolve this issue.

jakevdp commented 2 years ago

Thanks for the report! This was deprecated in JAX version 0.2.22 and removed in version 0.3.2 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-032-march-16-2022)

Instead of jax.ops.index, we recommend jnp.index_exp (which is essentially identical).
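A minimal sketch of the swap, assuming a recent JAX (`jnp.index_exp` is a re-export of NumPy's `index_exp` and also exists on older releases). The traceback does not show how exojax's `dit.py` actually uses `joi`, so pairing the index with the `.at[...]` update property, which replaced the removed `jax.ops.index_update` family, is an assumption about typical usage rather than exojax's code.

```python
import jax.numpy as jnp

# Old import, removed in JAX 0.3.2:
#   from jax.ops import index as joi
#   idx = joi[1:3, 0]
# Replacement: jnp.index_exp builds the same tuple of slices/indices.
idx = jnp.index_exp[1:3, 0]
print(idx)  # (slice(1, 3, None), 0)

# Assumed typical use: a functional update through the .at property.
x = jnp.zeros((4, 2))
y = x.at[idx].set(1.0)  # rows 1 and 2 of column 0 become 1.0; x is unchanged
print(y)
```

For a downstream package, changing the import to `from jax.numpy import index_exp as joi` should leave the rest of the module untouched, since `index_exp[...]` produces the same kind of index tuple.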

If you're depending on another project that is attempting to import this, you'll have to downgrade to JAX 0.3.1 or older until the package using it can be updated.
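A small, hypothetical helper (not part of JAX or exojax) that checks whether the installed JAX is new enough to have lost `jax.ops.index`, that is, whether a downgrade to 0.3.1 or a patch to the importing package would be needed. It assumes `jax.__version__` is a dotted string such as `0.3.4`; a dedicated library such as `packaging` would parse versions more robustly.

```python
import jax


def parse_version(version: str) -> tuple:
    # Keep the leading digits of the first three dot-separated components,
    # e.g. "0.3.4" -> (0, 3, 4) and "0.4.0rc1" -> (0, 4, 0).
    parts = []
    for component in version.split(".")[:3]:
        digits = ""
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def jax_ops_index_removed(version: str) -> bool:
    # jax.ops.index was removed in JAX 0.3.2, so any release at or above
    # that version breaks `from jax.ops import index`.
    return parse_version(version) >= (0, 3, 2)


if __name__ == "__main__":
    v = jax.__version__
    if jax_ops_index_removed(v):
        print(f"JAX {v}: jax.ops.index is gone; pin jax<=0.3.1 or patch the importing package")
    else:
        print(f"JAX {v}: jax.ops.index is still available")
```

With the JAX 0.3.4 reported by `pip show` in this issue, `jax_ops_index_removed("0.3.4")` returns `True`, which matches the `ImportError` in the traceback.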