google-research / big_vision

Official codebase used to develop Vision Transformer, SigLIP, MLP-Mixer, LiT and more.
Apache License 2.0

Optax update breaks due to jax.Array migration #35

Closed: colbybanbury closed this issue 10 months ago

colbybanbury commented 1 year ago

A recent change in Optax causes an import failure because of the jax.Array migration.

Optax commit

Error:

File "bv_venv/lib/python3.8/site-packages/optax/_src/clipping.py" in <module>
    ) -> Tuple[chex.Array, jax.Array]:

AttributeError: module 'jax' has no attribute 'Array'

I believe the fix would be either to pin optax to a release, e.g. git+https://github.com/deepmind/optax.git@v0.1.4 here,

or to raise the minimum JAX version, e.g. pip install "jax[tpu]>=0.4.1" here.
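A minimal sketch of the second option, assuming a project wants to fail early with a readable message instead of the AttributeError: compare the installed JAX version against the 0.4.1 minimum suggested in this issue. This is not what big_vision or optax actually does; the names parse_version, jax_is_new_enough, and check_jax are hypothetical.

```python
# Hypothetical helpers, not part of big_vision or optax.
import re
from typing import Tuple

MIN_JAX = (0, 4, 1)  # minimum version suggested in this issue


def parse_version(v: str) -> Tuple[int, ...]:
    """Turn '0.4.1', '0.4', or '0.4.1.dev20221212' into a 3-tuple of ints."""
    parts = []
    for piece in v.split(".")[:3]:
        m = re.match(r"\d+", piece)  # keep only the leading digits, e.g. '1rc1' -> 1
        if not m:
            break
        parts.append(int(m.group()))
    while len(parts) < 3:  # pad so '0.4' compares as (0, 4, 0)
        parts.append(0)
    return tuple(parts)


def jax_is_new_enough(version: str, minimum: Tuple[int, ...] = MIN_JAX) -> bool:
    """True if `version` is at least `minimum` (tuple comparison)."""
    return parse_version(version) >= minimum


def check_jax() -> None:
    """Raise a readable error if the installed jax predates jax.Array."""
    import jax  # imported lazily so the helpers above work without jax installed

    if not hasattr(jax, "Array") or not jax_is_new_enough(jax.__version__):
        raise RuntimeError(
            f"jax {jax.__version__} is too old for current optax; "
            'install "jax[tpu]>=0.4.1" or pin optax to v0.1.4.'
        )
```

Calling check_jax() before importing optax would surface the version mismatch directly. Pinning optax instead avoids touching JAX, at the cost of missing later optax fixes.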

colbybanbury commented 1 year ago

I did some basic tests with the v0.1.4 optax release and confirmed that it works, but someone with more visibility into the whole repo would know the preferred fix.

lucasb-eyer commented 11 months ago

Very late answer, but we'll push some larger updates to the codebase within the next 1-2 months, which will also fix this.

akolesnikoff commented 10 months ago

Should be fixed at head now.