kazewong / flowMC

Normalizing-flow enhanced sampling package for probabilistic inference in Jax
https://flowmc.readthedocs.io/en/main/
MIT License

`jax.interpreters.pxla` has no attribute `ShardedDeviceArray` #138

Closed: Qazalbash closed this issue 11 months ago

Qazalbash commented 11 months ago

I was simply experimenting with flowMC, and it was running fine until this evening. Now it raises the error below. I checked the indicated file in the official JAX GitHub repository, and there is no `ShardedDeviceArray` in it. Could you please look into this issue, and let me know what I should do in the meantime to get it working again?

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-6-b834d15ac679> in <cell line: 16>()
     14 from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
     15 from flowMC.sampler.MALA import MALA
---> 16 from flowMC.sampler.Sampler import Sampler
     17 from flowMC.utils.PRNG_keys import initialize_rng_keys
     18 from jax import Array, jit

8 frames
/usr/local/lib/python3.10/dist-packages/chex/_src/pytypes.py in <module>
     23 ArrayBatched = jax.interpreters.batching.BatchTracer
     24 ArrayNumpy = np.ndarray
---> 25 ArraySharded = jax.interpreters.pxla.ShardedDeviceArray
     26 # For instance checking, use `isinstance(x, jax.Array)`.
     27 if hasattr(jax, 'Array'):

AttributeError: module 'jax.interpreters.pxla' has no attribute 'ShardedDeviceArray'
```

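The failing line is inside chex, which looks up `jax.interpreters.pxla.ShardedDeviceArray` at import time, so the error surfaces while importing `flowMC.sampler.Sampler` rather than in flowMC's own logic. Below is a minimal sketch, assuming only that JAX may or may not be installed, that probes for that attribute and shows the `isinstance(x, jax.Array)` check recommended by the comment in chex's own `pytypes.py`. The helper names are hypothetical, not part of flowMC or chex.

```python
import importlib


def pxla_has_sharded_device_array() -> bool:
    """Return True if the installed JAX still exposes
    jax.interpreters.pxla.ShardedDeviceArray (the attribute the older chex
    pytypes.py looks up at import time). Returns False if JAX is missing
    or the attribute has been removed."""
    try:
        pxla = importlib.import_module("jax.interpreters.pxla")
    except ImportError:
        return False
    return hasattr(pxla, "ShardedDeviceArray")


def is_jax_array(x) -> bool:
    """Version-robust array check, as suggested by the comment in chex's
    pytypes.py: use isinstance(x, jax.Array) instead of ShardedDeviceArray."""
    try:
        import jax
    except ImportError:
        return False
    # Guard for very old JAX releases that predate jax.Array.
    return hasattr(jax, "Array") and isinstance(x, jax.Array)


if __name__ == "__main__":
    print("ShardedDeviceArray available:", pxla_has_sharded_device_array())
```

On a JAX release that has dropped `ShardedDeviceArray`, the first helper returns `False`, which is exactly the situation in which an older chex fails to import.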
Qazalbash commented 11 months ago

I should clarify that I was anxious and in a hurry when I wrote this issue. After some investigation, I found that the problem was caused by an older version of chex in Colab, whose `pytypes.py` still references `jax.interpreters.pxla.ShardedDeviceArray`. Thank you for your patience and understanding while I worked through this.
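If you run into the same error, a reasonable first step is to check which versions of JAX and chex your environment (for example, a Colab runtime) actually has. The following is a minimal sketch using only the standard library. The package list is an assumption, and this thread does not state which chex version fixed the problem. Upgrading chex (for example with `pip install --upgrade chex`, then restarting the runtime so the new version is imported) is the likely remedy, but treat that as a suggestion rather than a confirmed fix.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "chex", "flowMC")) -> dict:
    """Map each distribution name to its installed version string,
    or None if it is not installed in the current environment."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Comparing the reported chex version against the JAX version is usually enough to spot this kind of mismatch.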

kazewong commented 11 months ago

No worries, @Qazalbash! Thanks for noting this; now we can point people to this issue in the future.