jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Should ml-dtypes be in the requirements? #24461

Closed: cottrell closed this issue 41 minutes ago

cottrell commented 2 hours ago

Description

I upgraded jax and hit an error related to NumPy 2.0.

The cause turned out to be that ml-dtypes had not been upgraded along with it.

If ml-dtypes is now a requirement, shouldn't it be declared in the toml file or the requirements? Then it would be upgraded automatically.

System info (python version, jaxlib version, accelerator, etc.)

import jax; jax.print_environment_info()
jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.0.2
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='bleepblop', release='6.5.0-44-generic', version='#44-Ubuntu SMP PREEMPT_DYNAMIC Fri Jun  7 15:10:09 UTC 2024', machine='x86_64')

$ nvidia-smi
Tue Oct 22 16:53:01 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 2070        Off |   00000000:01:00.0  On |                  N/A |
| 32%   47C    P2             45W /  175W |    2013MiB /   8192MiB |     29%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      4822      G   /usr/lib/xorg/Xorg                           1027MiB |
|    0   N/A  N/A      5029      G   /usr/bin/gnome-shell                          162MiB |
|    0   N/A  N/A      5707      G   /usr/libexec/xdg-desktop-portal-gnome          35MiB |
|    0   N/A  N/A     71564      G   ...seed-version=20241020-180137.275000        471MiB |
|    0   N/A  N/A     71943      G   ...erProcess --variations-seed-version        142MiB |
|    0   N/A  N/A   3064268      C   .../anaconda3/envs/3.12/bin/python3.12         96MiB |
+-----------------------------------------------------------------------------------------+
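The `jax.print_environment_info()` output shown above lists jax, jaxlib and numpy but not ml_dtypes, which is the package that turned out to matter here. A minimal sketch, using only the standard library's `importlib.metadata`, that reports the installed versions of these distributions side by side; the helper name `installed_versions` is hypothetical and not part of JAX:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names=("jax", "jaxlib", "numpy", "ml_dtypes")):
    """Map each distribution name to its installed version string, or None if absent.

    Hypothetical helper for illustration; it only reads installed package metadata.
    """
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Running this in the affected environment would have shown the stale ml_dtypes version directly, without having to trace it from the NumPy 2.0 error.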
jakevdp commented 2 hours ago

ml_dtypes is an explicit requirement: https://github.com/jax-ml/jax/blob/587832f2950ef241e0fe1af837e22f52b8717ad9/setup.py#L57

However, with NumPy 2.0 you need ml_dtypes version 0.4 or newer.
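The rule stated above can be written as a small compatibility check. This is a sketch that assumes plain numeric `X.Y.Z` version strings (pre-release and local suffixes are simply truncated; a real tool would use `packaging.version`), and the function names are hypothetical. It encodes only the one constraint from this comment, nothing else about ml_dtypes compatibility:

```python
def release_tuple(ver: str, width: int = 3) -> tuple[int, ...]:
    """Return the leading numeric components of a version, zero-padded, e.g. '0.4' -> (0, 4, 0)."""
    parts = []
    for piece in ver.split("."):
        if not piece.isdigit():
            break  # stop at the first component that is not purely numeric, e.g. '0rc1'
        parts.append(int(piece))
    parts += [0] * (width - len(parts))  # pad so (0, 4) compares equal to (0, 4, 0)
    return tuple(parts)


def ml_dtypes_compatible(numpy_version: str, ml_dtypes_version: str) -> bool:
    """True unless NumPy is 2.x or newer and ml_dtypes is older than 0.4.0.

    Only the rule from this thread is checked; other constraints are ignored.
    """
    if release_tuple(numpy_version) >= (2, 0, 0):
        return release_tuple(ml_dtypes_version) >= (0, 4, 0)
    return True


print(ml_dtypes_compatible("2.0.2", "0.3.2"))   # False: too old for NumPy 2.x
print(ml_dtypes_compatible("2.0.2", "0.4.0"))   # True
print(ml_dtypes_compatible("1.26.4", "0.3.2"))  # True: rule applies only to NumPy 2.x
```

The version pairs in the usage lines are illustrative inputs, not the reporter's actual ml_dtypes version, which the issue does not state.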

hawkinsp commented 2 hours ago

Perhaps we can bump the minimum ml_dtypes version now.

jakevdp commented 2 hours ago

We can bump it to 0.4.0, but not 0.5.0, because tensorflow still pins ml_dtypes<0.5.0.
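Put differently, an environment that installs both JAX and TensorFlow needs an ml_dtypes version satisfying JAX's proposed lower bound and TensorFlow's upper bound at once, i.e. `>=0.4.0,<0.5.0`. A minimal sketch of that intersection, taking the two bounds from the comments above and again assuming simple numeric version strings; the names are hypothetical:

```python
JAX_MIN = (0, 4, 0)           # proposed JAX lower bound: ml_dtypes >= 0.4.0
TF_MAX_EXCLUSIVE = (0, 5, 0)  # TensorFlow's pin as described above: ml_dtypes < 0.5.0


def release_tuple(ver: str, width: int = 3) -> tuple[int, ...]:
    """Return the leading numeric components of a version, zero-padded to `width`."""
    parts = []
    for piece in ver.split("."):
        if not piece.isdigit():
            break  # ignore pre-release or local suffixes such as 'rc1'
        parts.append(int(piece))
    parts += [0] * (width - len(parts))
    return tuple(parts)


def satisfies_both(ml_dtypes_version: str) -> bool:
    """True if the version meets JAX's minimum and stays below TensorFlow's pin."""
    return JAX_MIN <= release_tuple(ml_dtypes_version) < TF_MAX_EXCLUSIVE


candidates = ["0.3.2", "0.4.0", "0.4.1", "0.5.0"]
print([c for c in candidates if satisfies_both(c)])  # → ['0.4.0', '0.4.1']
```

This is why raising JAX's minimum to 0.5.0 would leave no version that satisfies both packages.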

jakevdp commented 2 hours ago

#24463 bumps the ml_dtypes requirement to >= 0.4.0.