Just a heads up: with `jax==0.4.26` and `flax==0.7.*`, I was getting errors when importing `sax` (recently installed with `pip install --upgrade sax`):
```
>>> import sax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/lib/python3.11/site-packages/sax/__init__.py", line 15, in <module>
    from flax.core.frozen_dict import FrozenDict as FrozenDict
  File "/opt/homebrew/lib/python3.11/site-packages/flax/__init__.py", line 19, in <module>
    from .configurations import (
  File "/opt/homebrew/lib/python3.11/site-packages/flax/configurations.py", line 93, in <module>
    flax_filter_frames = define_bool_state(
                         ^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Config' object has no attribute 'define_bool_state'
```
It was fixed when `pip install --upgrade flax` installed `0.8.*`.

Not sure if this is something you need to know for setting requirements, but I thought I'd let you know and create a paper trail in case others see this. Feel free to close.
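For anyone else who hits this, here is a minimal sketch of a pre-import sanity check. It assumes, based only on what I observed above with `jax==0.4.26`, that flax 0.8 or newer is what resolves the error. The helper names `major_minor` and `flax_is_new_enough` are hypothetical and not part of sax or flax. The check uses only the standard library's `importlib.metadata`.

```python
from importlib.metadata import PackageNotFoundError, version


def major_minor(ver: str) -> tuple[int, int]:
    """Parse the leading 'X.Y' of a version string such as '0.7.5' or '0.8.2rc1'."""
    parts = ver.split(".")
    minor_part = parts[1] if len(parts) > 1 else "0"
    # Keep only the leading digits of the minor component (e.g. '2rc1' -> '2').
    digits = ""
    for ch in minor_part:
        if not ch.isdigit():
            break
        digits += ch
    return int(parts[0]), int(digits or "0")


def flax_is_new_enough(min_version: tuple[int, int] = (0, 8)) -> bool:
    """Return True if the installed flax is at least `min_version` (assumed 0.8)."""
    try:
        installed = version("flax")
    except PackageNotFoundError:
        # flax is not installed at all.
        return False
    return major_minor(installed) >= min_version


if __name__ == "__main__":
    if not flax_is_new_enough():
        print("flax is older than 0.8 (or missing); try: pip install --upgrade flax")
```

Tuples compare element-wise, so `(0, 10) >= (0, 8)` holds. A naive string comparison would get this wrong, because `"0.10" < "0.8"` as strings.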