jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Persistent compilation cache does not work #21067

Open neel04 opened 6 months ago

neel04 commented 6 months ago

Description

The persistent compilation cache doesn't work. It worked well with older versions of JAX, but a breaking change seems to have landed in the past few weeks.

The cache folder is never created, and the lack of any speedup confirms that JAX is not using the persistent cache.

Reproduce:

import os

import jax

# Point the persistent compilation cache at a local folder.
jax.config.update("jax_compilation_cache_dir", "./jax-cache")

@jax.jit
def some_op(A, B):
    return (A @ B) * (A + B)

A = jax.numpy.ones((128,))
B = jax.numpy.zeros((128,))

some_op(A, B)

print(os.listdir("./"))  # no 'jax-cache' folder gets created
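A variant of the reproduction that may help narrow things down is sketched below. It rests on one assumption not stated in this thread: that the installed JAX version supports the `jax_persistent_cache_min_compile_time_secs` option, which the JAX documentation describes as the minimum compile time before an executable is written to the cache. A function as small as `some_op` may compile too quickly to be cached under the default threshold. The helper names `list_cache_entries` and `run_repro` are hypothetical.

```python
import os


def list_cache_entries(cache_dir):
    """Return the sorted file names in cache_dir, or [] if it does not exist."""
    if not os.path.isdir(cache_dir):
        return []
    return sorted(os.listdir(cache_dir))


def run_repro(cache_dir="./jax-cache"):
    """Compile a small jitted function and report what landed in cache_dir."""
    # Imported here so list_cache_entries can be used without JAX installed.
    import jax
    import jax.numpy as jnp

    jax.config.update("jax_compilation_cache_dir", cache_dir)
    # Assumption: this option exists in the installed JAX version. Setting it
    # to 0 asks JAX to cache executables regardless of how fast they compiled.
    jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

    @jax.jit
    def some_op(A, B):
        return (A @ B) * (A + B)

    A = jnp.ones((128,))
    B = jnp.zeros((128,))
    some_op(A, B).block_until_ready()  # make sure compilation has finished

    return list_cache_entries(cache_dir)
```

Calling `run_repro()` in a fresh Python process and printing the result shows whether any cache entries were written. An empty list with the threshold lowered would point away from the threshold as the cause.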

I don't think it's an environment issue: I also have a Docker image that reproduces it.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.25
numpy:  1.26.2
python: 3.10.12 (main, Jul 5 2023, 15:02:25) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Neels-MacBook-Air.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:59:33 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T8112', machine='arm64')

rajasekharporeddy commented 6 months ago

Hi @neel04

I tested the provided reproduction code on my MacBook Pro (M1 Pro chip) with jax 0.4.26 and 0.4.27.dev20240503, each paired with the matching jaxlib version (0.4.26 and 0.4.27.dev20240503, respectively). In both cases, a folder named 'jax-cache' was created. See the screenshot below for reference.

[Screenshot: the 'jax-cache' folder was created]

jax.print_environment_info():

jax:    0.4.27.dev20240503
jaxlib: 0.4.27.dev20240503
numpy:  1.26.4
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct 2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')

Could you please retest with jaxlib 0.4.26 alongside jax 0.4.26, or with the JAX nightly build, and let us know the result?

Thank you.

neel04 commented 6 months ago

Yep, upgrading jaxlib from 0.4.25 -> 0.4.26 works locally. So this is definitely a version issue.
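Since the local fix came down to a jax/jaxlib version mismatch, a quick check along the following lines can flag that situation. This is a minimal sketch: the helper names are hypothetical, and requiring identical release numbers is stricter than what JAX itself enforces. It is used here only because the matching pair is what worked in this thread.

```python
import re
from importlib import metadata


def installed_version(package):
    """Return the installed version string for package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def release_tuple(version):
    """Parse the leading 'X.Y.Z' of a version string into a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def versions_match(jax_version, jaxlib_version):
    """True when both packages share the same X.Y.Z release number."""
    return release_tuple(jax_version) == release_tuple(jaxlib_version)
```

For example, `versions_match("0.4.26", "0.4.25")` returns `False` for the pair from the original report. In a live environment, pass `installed_version("jax")` and `installed_version("jaxlib")` after checking that neither is `None`.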

However, on my Docker image on a TPU v3-8 I have:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.11.9 (main, Apr 24 2024, 11:58:32) [GCC 12.2.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='f996f75a635a', release='5.13.0-1027-gcp', version='#32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022', machine='x86_64')

and it doesn't work. I can rebuild the image (it hasn't been updated in a few weeks), but I'm not sure where exactly the problem lies.
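One thing that can be ruled out independently of JAX is whether the process inside the container can create and write to the cache directory at all. Nothing in this thread confirms that as the cause, so the sketch below is only a sanity check. The function name is hypothetical.

```python
import os
import tempfile


def cache_dir_is_writable(cache_dir):
    """Try to create cache_dir and write a temporary file inside it."""
    try:
        os.makedirs(cache_dir, exist_ok=True)
        # The temporary file is deleted automatically when the block exits.
        with tempfile.NamedTemporaryFile(dir=cache_dir):
            pass
    except OSError:
        return False
    return True
```

Running `cache_dir_is_writable("./jax-cache")` from the same working directory and user as the JAX program shows whether filesystem permissions or a read-only mount could be involved.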