Closed by cottrell.
ml_dtypes is an explicit requirement: https://github.com/jax-ml/jax/blob/587832f2950ef241e0fe1af837e22f52b8717ad9/setup.py#L57
However, with NumPy 2.0 you need ml_dtypes version 0.4 or newer.
Perhaps we can bump the minimum now.
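To make the NumPy 2.0 rule above concrete, here is a minimal sketch of a compatibility check, assuming the only rule is "NumPy 2.x or newer needs ml_dtypes 0.4 or newer." The function names (`parse_version`, `ml_dtypes_ok`, `check_installed`) are hypothetical helpers, not part of jax or ml_dtypes, and the version parser is deliberately simplified (it ignores pre-release suffixes instead of ordering them per PEP 440).

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def parse_version(v: str) -> tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. '2.0.0rc1' -> (2, 0, 0).

    Simplified on purpose: stops at the first non-numeric character, so
    pre-release/dev suffixes are ignored rather than ordered per PEP 440.
    """
    parts: list[int] = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # a suffix such as 'rc1' ends the release segment
    return tuple(parts)


def ml_dtypes_ok(numpy_version: str, ml_dtypes_version: str) -> bool:
    """Hypothetical rule from the thread: NumPy >= 2 requires ml_dtypes >= 0.4."""
    if parse_version(numpy_version) >= (2,):
        return parse_version(ml_dtypes_version) >= (0, 4)
    return True


def check_installed() -> Optional[bool]:
    """Apply ml_dtypes_ok to the installed packages; None if either is missing."""
    try:
        return ml_dtypes_ok(version("numpy"), version("ml_dtypes"))
    except PackageNotFoundError:
        return None
```

For example, `ml_dtypes_ok("2.0.0", "0.3.2")` is `False`, while the same ml_dtypes with NumPy `1.26.4` passes, which matches the situation described in the issue: the old ml_dtypes only breaks once NumPy 2.0 is installed.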
We can bump it to 0.4.0, but not 0.5.0, because tensorflow still pins ml_dtypes<0.5.0.
That gives ml_dtypes>=0.4.0.
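The reason 0.4.0 works as a minimum but 0.5.0 does not can be checked as a simple range intersection. This sketch assumes only the two bounds quoted above (a jax minimum of 0.4.0 and tensorflow's `<0.5.0` pin) and plain `X.Y.Z` version strings; `satisfies_both` and `range_is_satisfiable` are hypothetical names used for illustration.

```python
from __future__ import annotations

# Bounds quoted in the thread; everything else here is illustrative.
JAX_MIN = (0, 4, 0)   # proposed jax requirement: ml_dtypes>=0.4.0
TF_UPPER = (0, 5, 0)  # tensorflow pin: ml_dtypes<0.5.0


def parse_version(v: str) -> tuple[int, ...]:
    """Parse a plain 'X.Y.Z' release string into a comparable tuple."""
    return tuple(int(p) for p in v.split("."))


def satisfies_both(v: str,
                   lower: tuple[int, ...] = JAX_MIN,
                   upper: tuple[int, ...] = TF_UPPER) -> bool:
    """True if ml_dtypes version `v` meets lower <= v < upper."""
    return lower <= parse_version(v) < upper


def range_is_satisfiable(lower: tuple[int, ...], upper: tuple[int, ...]) -> bool:
    """The half-open range [lower, upper) is non-empty exactly when lower < upper.

    Assumes both tuples have the same length (three components here).
    """
    return lower < upper
```

With a 0.4.0 floor, any 0.4.x release satisfies both jax and tensorflow; raising the floor to 0.5.0 would make `range_is_satisfiable((0, 5, 0), TF_UPPER)` false, meaning no ml_dtypes version could be installed alongside that tensorflow.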
Description
I upgraded jax and hit a NumPy 2.0-related error.
The cause turned out to be that ml-dtypes had not been upgraded.
If this is now a requirement, shouldn't it be listed in the TOML file or the requirements? Then it would get upgraded automatically.
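One way to see what an installed jax actually declares is `importlib.metadata.requires`, which returns the distribution's requirement strings (or `None`). Declaring the dependency is not enough on its own: pip's default upgrade strategy is "only-if-needed", so `pip install --upgrade jax` leaves an installed ml_dtypes alone as long as it still satisfies the declared constraint. That is why bumping the minimum matters. The sketch below is an assumption-laden helper (`find_requirement` and `declared_requirement` are hypothetical names) that filters those strings by PEP 503-normalized name.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, requires


def _normalize(name: str) -> str:
    """PEP 503 name normalization: runs of -, _ and . become '-', lowercased."""
    return re.sub(r"[-_.]+", "-", name).lower()


def find_requirement(reqs: list[str] | None, name: str) -> list[str]:
    """Return the requirement strings in `reqs` that refer to distribution `name`."""
    target = _normalize(name)
    found = []
    for req in reqs or []:
        # The project name is the leading run of name characters, before any
        # version specifier, extras bracket, or environment marker.
        m = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", req)
        if m and _normalize(m.group(1)) == target:
            found.append(req)
    return found


def declared_requirement(package: str, dependency: str) -> list[str]:
    """Requirement strings for `dependency` declared by installed `package`."""
    try:
        return find_requirement(requires(package), dependency)
    except PackageNotFoundError:
        return []  # `package` is not installed
```

For example, `declared_requirement("jax", "ml_dtypes")` shows the ml_dtypes constraint that the installed jax wheel advertises; if the installed ml_dtypes already satisfies it, pip's default strategy will not upgrade it.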
System info (python version, jaxlib version, accelerator, etc.)