neel04 opened this issue 6 months ago
Hi @neel04
I tested the provided reproducible code on my MacBook Pro with an M1 Pro chip using jax versions 0.4.26 and 0.4.27.dev20240503 with the corresponding jaxlib versions 0.4.26 and 0.4.27.dev20240503, respectively. In both cases, a folder named 'jax-cache' was created. Please find the environment info below for reference.
Output of jax.print_environment_info():
jax: 0.4.27.dev20240503
jaxlib: 0.4.27.dev20240503
numpy: 1.26.4
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct 2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')
Could you please verify with jaxlib 0.4.26 together with jax 0.4.26, or with the JAX nightly versions, and let us know?
Thank you.
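The reply above asks for matching jax and jaxlib versions. A small, hypothetical helper for spotting a mismatch like the jax 0.4.26 / jaxlib 0.4.25 pair in the original report might look like the sketch below. It reads the installed versions with `importlib.metadata` and treats only identical version strings (ignoring a PEP 440 `+local` suffix) as a match. That is stricter than JAX itself, which only checks for a minimum jaxlib version at import. The names `strip_local`, `versions_match` and `check_installed` are illustrative, not part of JAX.

```python
from importlib.metadata import PackageNotFoundError, version


def strip_local(v: str) -> str:
    """Drop a PEP 440 local suffix such as '+cuda12.cudnn89'."""
    return v.split("+", 1)[0]


def versions_match(jax_version: str, jaxlib_version: str) -> bool:
    """Hypothetical check: True only if jax and jaxlib report the same version."""
    return strip_local(jax_version) == strip_local(jaxlib_version)


def check_installed() -> str:
    """Report whether the installed jax and jaxlib versions match."""
    try:
        jv = version("jax")
        jlv = version("jaxlib")
    except PackageNotFoundError as exc:
        return f"not installed: {exc}"
    if versions_match(jv, jlv):
        return f"ok: jax {jv} / jaxlib {jlv}"
    return f"mismatch: jax {jv} / jaxlib {jlv}"


if __name__ == "__main__":
    print(check_installed())
```

With the versions from the original report, `versions_match("0.4.26", "0.4.25")` returns `False`.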
Yep, upgrading jaxlib from 0.4.25 to 0.4.26 works locally, so this is definitely a version issue.
However, it still doesn't work in my Docker image on a TPU v3-8, which has these libraries:
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.11.9 (main, Apr 24 2024, 11:58:32) [GCC 12.2.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='f996f75a635a', release='5.13.0-1027-gcp', version='#32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022', machine='x86_64')
I can rebuild the image (it hasn't been updated in a few weeks), but I'm not sure where exactly the problem lies.
Description
The persistent compilation cache simply doesn't work. It used to work well with older versions of JAX, but some breaking change seems to have occurred in the past few weeks.
The problem is that the compilation_cache folder is never created, and I can confirm from the lack of speedup that JAX is definitely not using the persistent cache.
Reproduce:
I don't think it's an environment issue; I also have a Docker image that reproduces it.
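For checking whether the persistent cache writes anything to disk, a minimal sketch (not the reporter's original reproduction code) is shown below. It assumes a JAX version that provides the `jax_compilation_cache_dir`, `jax_persistent_cache_min_entry_size_bytes` and `jax_persistent_cache_min_compile_time_secs` config options. By default JAX may skip caching programs that compile quickly or are small, so the sketch lowers both thresholds; the temporary cache path and the function `f` are hypothetical. The options are set before the first `jax.jit` compilation, since changing them afterwards may have no effect.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical cache location; the report above used a folder named 'jax-cache'.
CACHE_DIR = os.path.join(tempfile.mkdtemp(), "jax-cache")

# Configure the persistent cache before anything is compiled.
jax.config.update("jax_compilation_cache_dir", CACHE_DIR)
# Remove the size and compile-time thresholds so even a tiny program is cached.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


@jax.jit
def f(x):
    # Trivial test program: sin^2 + cos^2 should be 1 everywhere.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


# The first call triggers compilation (and, if supported, a cache write).
result = f(jnp.arange(8.0)).block_until_ready()


def cache_entries(path):
    """Return the files in the cache folder, or [] if the folder is missing."""
    if not os.path.isdir(path):
        return []
    return sorted(os.listdir(path))


entries = cache_entries(CACHE_DIR)
print(f"cache dir exists: {os.path.isdir(CACHE_DIR)}, entries: {len(entries)}")
```

If the folder stays missing or empty even with both thresholds lowered, that points at the cache setup itself rather than the size/time heuristics.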
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.25
numpy: 1.26.2
python: 3.10.12 (main, Jul 5 2023, 15:02:25) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Neels-MacBook-Air.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:59:33 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T8112', machine='arm64')