bmorris3 / shone

Radiative transfer in JAX
https://shone.readthedocs.io/
MIT License

Remove maxpin on jax #28

Open bmorris3 opened 1 month ago

bmorris3 commented 1 month ago

jax v0.4.30 breaks something in shone, and it's not yet clear what. The docs build fails with this traceback:

reading sources... [100%] shone/installation

/home/docs/checkouts/readthedocs.org/user_builds/shone/checkouts/latest/docs/shone/examples/transmission.rst:258: WARNING: Exception occurred in plotting transmission-3
 from /home/docs/checkouts/readthedocs.org/user_builds/shone/checkouts/latest/docs/shone/examples/transmission.rst:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/shone/envs/latest/lib/python3.11/site-packages/jax/_src/api_util.py", line 287, in _argnums_partial
    args = [next(fixed_args_).val if x is sentinel else x for x in args]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/shone/envs/latest/lib/python3.11/site-packages/jax/_src/api_util.py", line 287, in <listcomp>
    args = [next(fixed_args_).val if x is sentinel else x for x in args]
            ^^^^^^^^^^^^^^^^^
StopIteration

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/shone/envs/latest/lib/python3.11/site-packages/matplotlib/sphinxext/plot_directive.py", line 552, in _run_code
    exec(code, ns)
  File "<string>", line 38, in <module>
RuntimeError: generator raised StopIteration
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. [docutils]
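The two chained exceptions in the traceback can be reproduced in plain Python. Since PEP 479 (Python 3.7+), a `StopIteration` that escapes a generator body is re-raised as `RuntimeError: generator raised StopIteration`, with the original exception as its `__cause__`. The sketch below makes two assumptions that have not been checked against the JAX source: that the comprehension in `_argnums_partial` runs inside a generator, and that the failure comes from `args` holding more placeholders than there are fixed values. `sentinel` and `partial_like` are hypothetical stand-ins, not JAX internals.

```python
# Plain-Python stand-in for the failing line in the traceback; no JAX involved.
# `sentinel` and `partial_like` are hypothetical names, not JAX internals.
sentinel = object()


def partial_like(args, fixed_args):
    """Generator that swaps each placeholder in `args` for the next fixed value."""
    fixed_args_ = iter(fixed_args)
    # If `args` holds more placeholders than `fixed_args` has values, next()
    # raises StopIteration here, inside the generator body.
    args = [next(fixed_args_) if x is sentinel else x for x in args]
    yield args


# One placeholder, one fixed value: works.
print(next(partial_like([1, sentinel, 3], [2])))  # [1, 2, 3]

# Two placeholders, one fixed value: PEP 479 turns the StopIteration into a
# RuntimeError, chained via __cause__ as in the docs-build traceback.
try:
    next(partial_like([sentinel, sentinel], [2]))
except RuntimeError as err:
    caught = err
    print(f"{type(err.__cause__).__name__} -> {err}")  # StopIteration -> generator raised StopIteration
```

This only explains how the error is shaped. It does not identify which change in jax v0.4.30 leads shone's call into that state.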

This needs to be fixed for compatibility with newer versions of jax. In the meantime, I've pinned the maximum version of jax and jaxlib to v0.4.29 in https://github.com/bmorris3/shone/commit/cfd9047f34c75b2e035a581b618af99e2f8ff8f2.
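The pin itself lives in the linked commit. As a complementary, hypothetical sketch that is not part of shone, a check like the following could warn when an environment has jax or jaxlib newer than 0.4.29. `MAX_TESTED`, `release_tuple`, `exceeds_pin`, and `warn_if_too_new` are illustrative names. The parser only compares the leading numeric release segment and does not implement full PEP 440 ordering.

```python
import re
import warnings
from importlib.metadata import PackageNotFoundError, version

# Hypothetical constant mirroring the jax/jaxlib <= 0.4.29 pin; not part of shone.
MAX_TESTED = (0, 4, 29)


def release_tuple(v):
    """Return the leading numeric release segment, e.g. '0.4.30.dev1' -> (0, 4, 30)."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def exceeds_pin(v, max_tested=MAX_TESTED):
    """True if the release segment of `v` is newer than `max_tested`."""
    return release_tuple(v) > max_tested


def warn_if_too_new(packages=("jax", "jaxlib")):
    """Warn for each installed package newer than the pin; return their names."""
    too_new = []
    for pkg in packages:
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            continue  # not installed, so nothing to compare
        if exceeds_pin(installed):
            too_new.append(pkg)
            pin = ".".join(map(str, MAX_TESTED))
            warnings.warn(f"{pkg} {installed} is newer than the tested maximum {pin}")
    return too_new


print(warn_if_too_new())
```

Parsing with the standard library alone avoids adding a dependency on `packaging` just for this check.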