google-deepmind / distrax

Apache License 2.0

Incompatibility with JAX 0.4.14 #249

Closed: RodrigoAVargasHdz closed this issue 10 months ago

RodrigoAVargasHdz commented 1 year ago

Hi,

I had to downgrade jax to 0.4.13 because the latest version (0.4.14) is incompatible with the latest release of distrax.

I get the following error.


```
  File "//lib/python3.10/site-packages/distrax/_src/utils/jittable.py", line 36, in tree_flatten
    switch = list(map(_is_jax_data, leaves))
  File "//python3.10/site-packages/distrax/_src/utils/jittable.py", line 66, in _is_jax_data
    jax.xla.abstractify(x)
  File "//python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'xla'
```

Cheers,
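The last frame of the traceback above comes from jax's `deprecations.py`, which appears to use a module-level `__getattr__` (PEP 562) to handle deprecated names. Once `jax.xla` was removed, that hook raised `AttributeError` instead of returning a value. The minimal sketch below reproduces that pattern with a hypothetical stand-in module `fake_jax`. It is not jax's actual code, and the helpers `make_module_getattr` and `call_if_available` are illustrative only. It also shows how caller code could probe for an attribute without crashing.

```python
import types
import warnings

_MISSING = object()  # sentinel so a legitimately None attribute is not mistaken for "missing"


def make_module_getattr(module_name, deprecated):
    """Build a PEP 562 module-level __getattr__ (hypothetical helper).

    `deprecated` maps still-available-but-deprecated names to (value, message).
    Any other unknown name, such as one whose deprecation period has ended,
    raises AttributeError with the message format seen in the traceback.
    """
    def __getattr__(name):
        if name in deprecated:
            value, message = deprecated[name]
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return value
        raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
    return __getattr__


# Hypothetical stand-in module: `old_helper` is deprecated but still works,
# while `xla` has already been removed (it is simply not listed).
fake_jax = types.ModuleType("fake_jax")
fake_jax.__getattr__ = make_module_getattr(
    "fake_jax",
    {"old_helper": (len, "fake_jax.old_helper is deprecated; use len")},
)


def call_if_available(module, attr_path, *args):
    """Follow a dotted attribute path and call the result.

    Returns None instead of raising if any step of the path is missing.
    getattr() with a default swallows the AttributeError raised by the
    module-level __getattr__ hook.
    """
    obj = module
    for part in attr_path.split("."):
        obj = getattr(obj, part, _MISSING)
        if obj is _MISSING:
            return None
    return obj(*args)
```

Probing like this only turns a crash into a detectable condition. In this thread, the actual remedy was upgrading to a distrax release that no longer calls the removed name.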

rdaems commented 12 months ago

This is already fixed in https://github.com/google-deepmind/distrax/pull/243, but the fix has not been released yet. I don't get the error when I use the master branch.

GaetanLepage commented 10 months ago

Hello!

Do you have any idea when the next version of distrax will be released? Several fixes related to the recent changes in the JAX API are not included in 0.1.4.

Thank you very much in advance.

sash-a commented 10 months ago

Bumping this thread: any idea of a release date for this fix?

GaetanLepage commented 10 months ago

I think we can close this, as distrax release 0.1.5 includes the fixes for these incompatibilities!

Note: the release has been published on GitHub but not yet on PyPI.

hbq1 commented 10 months ago

I've just released the new version on PyPI (https://pypi.org/project/distrax/0.1.5/). Thanks for bumping this thread!
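A small version guard can check whether an environment still hits this particular incompatibility. The sketch below is minimal and relies only on assumptions drawn from this thread: jax 0.4.14 and later no longer expose `jax.xla`, and distrax 0.1.5 is the first release with the fix. Earlier distrax releases are assumed to be affected. `parse_release` and `hits_jax_xla_removal` are hypothetical helpers, not part of jax or distrax, and the guard checks nothing beyond this one issue.

```python
import re

# Thresholds taken from this thread (assumption: no other versions matter).
JAX_BREAKING = (0, 4, 14)   # first jax version reported without `jax.xla`
DISTRAX_FIXED = (0, 1, 5)   # first distrax release reported to include the fix


def parse_release(version):
    """Parse the leading numeric release part of a version string into ints.

    Simplification: suffixes are ignored, so '0.4.14rc1' is treated as 0.4.14.
    Integer tuples compare correctly ('0.1.10' > '0.1.5'), unlike raw strings.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def hits_jax_xla_removal(jax_version, distrax_version):
    """Return True for the jax/distrax combination reported broken in this issue."""
    return (parse_release(jax_version) >= JAX_BREAKING
            and parse_release(distrax_version) < DISTRAX_FIXED)


if __name__ == "__main__":
    # Usage: check the installed packages, if both are present.
    from importlib.metadata import PackageNotFoundError, version
    try:
        print(hits_jax_xla_removal(version("jax"), version("distrax")))
    except PackageNotFoundError as err:
        print(f"package not installed: {err}")
```

The guard compares integer tuples rather than strings so that multi-digit components such as 0.1.10 order correctly.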