jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

FileNotFoundError when using cache in parallel #22718

Open cool-RR opened 1 month ago

cool-RR commented 1 month ago

Description

I've got a pytest test suite and I've recently started running it with -n 3, using pytest-xdist, so it'll run on 3 processes in parallel. I sometimes get a warning like this one:

test_polina/test_golden_runs.py::test_ipd_three_agent_golden_run
  C:\Program Files\Python312\Lib\site-packages\jax\_src\compiler.py:688: UserWarning: Error writing persistent compilation cache entry for 'jit_scan_outer_bigs': FileNotFoundError: [WinError 2] The system cannot find the file specified: 'J:\\jaxxy\\_temp_jit_scan_outer_bigs-43cfe1699669be4ee80585eba899a893fea7eb948b6e707293e0235d7fdbffe0' -> 'J:\\jaxxy\\jit_scan_outer_bigs-43cfe1699669be4ee80585eba899a893fea7eb948b6e707293e0235d7fdbffe0'
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html

When it does happen, the tests run a lot slower. I'm guessing sometimes the tests attempt to manipulate the cache at the same time and clash with each other, causing the cache to not be used and triggering compilation. I can't reliably reproduce this problem.
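The warning above shows a rename from `_temp_<key>` to `<key>`, so one plausible way for the clash to happen is that two writers share the same temporary path. The following is only a sketch of that hypothesis, replayed in a single process so the ordering is deterministic; the function name and the key are made up for illustration, and this is not JAX's code.

```python
import os
import tempfile


def demo_shared_temp_race(cache_dir: str, key: str) -> str:
    """Replay two writers sharing one temp path; return the second writer's outcome."""
    tmp_path = os.path.join(cache_dir, f"_temp_{key}")  # identical for every writer
    final_path = os.path.join(cache_dir, key)

    # Writer A and writer B both write to the same temporary file.
    for payload in (b"from A", b"from B"):
        with open(tmp_path, "wb") as f:
            f.write(payload)

    # Writer A publishes its entry; the temporary file no longer exists.
    os.replace(tmp_path, final_path)

    # Writer B now tries to publish, but its source file is gone.
    try:
        os.replace(tmp_path, final_path)
    except FileNotFoundError as exc:
        return type(exc).__name__
    return "no error"


with tempfile.TemporaryDirectory() as d:
    result = demo_shared_temp_race(d, "jit_example-abc123")
    print(result)
```

In a real run the interleaving depends on timing between the pytest-xdist workers, which would fit with the problem not reproducing reliably.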

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.12.1 (tags/v3.12.1:2305ca5, Dec  7 2023, 22:03:25) [MSC v.1937 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Turing', release='10', version='10.0.19045', machine='AMD64')

$ nvidia-smi
Mon Jul 29 15:48:53 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 456.71       Driver Version: 456.71       CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GT 710     WDDM  | 00000000:01:00.0 N/A |                  N/A |
| 50%   52C    P0    N/A /  N/A |   1627MiB /  2048MiB |     N/A      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

ayaka14732 commented 1 month ago

Can you try jax 0.4.31, which was just released?

cool-RR commented 1 month ago

I upgraded to 0.4.31 and now I'm getting different warnings:

  C:\Users\Administrator\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:663: UserWarning: Error reading persistent compilation cache entry for 'jit_run_evaluation_rollout': PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_run_evaluation_rollout-ae9e0a0f7ccd64b22e7749a155df1b10cd93374b9af8ff3154d62baafe614222-atime'
    warnings.warn(

test_polina/test_dumb_on_dumb.py::test_dumb_on_dumb
  C:\Users\Administrator\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:663: UserWarning: Error reading persistent compilation cache entry for 'jit_evaluate_population_vs_fixed_policies': PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_evaluate_population_vs_fixed_policies-f84b39d478dddc89b532239150355f5d699747470f65562e377b663be8e93bcc-atime'
    warnings.warn(

There were about 10 more. All the filenames end with -atime; none end with -cache.

I enabled tracebacks and got this:

polina\cling.py:70: in train
    for population, evaluation, outer_analytica_rows, outer_fuzzy_rows in iterator:
polina\pola_algorithm.py:149: in train
    yield population, self.env.evaluate(seed, population, i_epoch=0), (), ()
polina\enving\evaluating_enving.py:98: in evaluate
    step_aux = self.run_evaluation_rollout(seed, population)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:332: in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:190: in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:2739: in bind
    return self.bind_with_trace(top_trace, args, params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:433: in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:939: in process_primitive
    return primitive.impl(*tracers, **params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1730: in _pjit_call_impl
    return xc._xla.pjit(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1712: in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1642: in _pjit_call_impl_python
    ).compile(compile_options)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2295: in compile
    executable = UnloadedMeshExecutable.from_hlo(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2807: in from_hlo
    xla_executable = _cached_compilation(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2621: in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:345: in compile_or_get_cached
    retrieved_executable, retrieved_compile_time = _cache_read(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:658: in _cache_read
    return compilation_cache.get_executable_and_time(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compilation_cache.py:215: in get_executable_and_time
    executable_and_time = cache.get(cache_key)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\lru_cache.py:112: in get
    atime_path.write_bytes(timestamp)
..\..\..\.venvs\polina_env\Lib\site-packages\etils\epath\abstract_path.py:187: in write_bytes
    with self.open('wb') as f:
..\..\..\.venvs\polina_env\Lib\site-packages\etils\epath\gpath.py:255: in open
    gfile = self._backend.open(self._path_str, mode)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <etils.epath.backend._OsPathBackend object at 0x0000016BA4805220>
path = 'J:\\jaxxy\\jit_run_evaluation_rollout-ff90452405b81e57fa45c6153d6b90b6fd7afae25904fd1330a052f35df7f531-atime'
mode = 'wb'

    def open(
        self,
        path: PathLike,
        mode: str,
    ) -> typing.IO[Union[str, bytes]]:
      if 'b' in mode:
        encoding = None
      else:
        encoding = 'utf-8'
>     return open(path, mode, encoding=encoding)
E     PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_run_evaluation_rollout-ff90452405b81e57fa45c6153d6b90b6fd7afae25904fd1330a052f35df7f531-atime'

LunarLanding commented 6 days ago

Hi, I hit a similar problem and I think I found the cause and a solution. The key point is that the temporary file the cache writes to on normal filesystems does not have a per-process unique name: https://github.com/google/jax/blob/fed7efd73003988282744b2f649df493b808c781/jax/_src/gfile_cache.py#L37-L55 All processes will attempt to open the same tmp_path = self._path / f"_temp_{key}", and any but the first will error. A possible fix is to suffix the temporary file name with the machine's hostname and the process ID.

I include the hostname because the use case that brought me here is sharing one cache between processes that may be on the same node but in different MPI groups, or on different nodes, in which case it is also important to differentiate by hostname.
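A minimal sketch of that suggestion, assuming a plain directory-backed cache: the function name `put_with_unique_temp` and the exact suffix format are hypothetical, and this is not the code in `gfile_cache.py`.

```python
import os
import socket
import tempfile


def put_with_unique_temp(cache_dir: str, key: str, value: bytes) -> str:
    """Write value to cache_dir/key via a temp file unique to this host and process."""
    # Hypothetical suffix: hostname distinguishes nodes, pid distinguishes processes.
    suffix = f"{socket.gethostname()}-{os.getpid()}"
    tmp_path = os.path.join(cache_dir, f"_temp_{key}-{suffix}")
    final_path = os.path.join(cache_dir, key)

    with open(tmp_path, "wb") as f:
        f.write(value)
    # Publish the finished file under its final name.
    os.replace(tmp_path, final_path)
    return final_path


with tempfile.TemporaryDirectory() as d:
    path = put_with_unique_temp(d, "jit_example-abc123", b"compiled bytes")
    with open(path, "rb") as f:
        stored = f.read()
    remaining = os.listdir(d)
```

Two threads inside the same process would still share a temp name with this scheme, so it only addresses the multi-process case described here.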

On NFS, a more complex (and possibly unnecessary) option is NFS-safe locking, as described here (see D) or as done in flufl.lock.

ayaka14732 commented 6 days ago

@LunarLanding The file you mentioned has been removed in recent versions of JAX. Can you try again with the latest version?

LunarLanding commented 6 days ago

@ayaka14732 I see, I mistakenly navigated to old code. Here's the code in version 0.4.32, which works on a single machine if you have eviction enabled, but I'm not sure it will work on NFS (SoftFileLock does not seem to follow all the steps described in the link above regarding NFS; compare with the flufl.lock implementation): https://github.com/google/jax/blob/1594d2f30fdbfebf693aba4a2b264e4a3e52acc6/jax/_src/lru_cache.py#L120-L160
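For readers unfamiliar with the idea, here is a rough standard-library sketch of an existence-based ("soft") lock, where whoever manages to create the lock file holds the lock. The name `soft_lock` and its parameters are made up for illustration; this is neither JAX's nor filelock's implementation, and it does not add the extra NFS steps discussed above.

```python
import contextlib
import os
import tempfile
import time


@contextlib.contextmanager
def soft_lock(lock_path: str, timeout: float = 5.0, poll: float = 0.01):
    """Hold a lock that is represented by the existence of lock_path."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            # O_EXCL makes creation fail if the lock file already exists.
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            break
        except FileExistsError:
            if time.monotonic() >= deadline:
                raise TimeoutError(f"could not acquire {lock_path}")
            time.sleep(poll)
    try:
        yield
    finally:
        # Release: close our handle and remove the lock file.
        os.close(fd)
        os.unlink(lock_path)


with tempfile.TemporaryDirectory() as d:
    lock_file = os.path.join(d, "cache.lock")
    with soft_lock(lock_file):
        held = os.path.exists(lock_file)
    released = not os.path.exists(lock_file)
```

A lock of this kind is advisory: it only helps if every process touching the cache directory takes the same lock before reading or writing.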

LunarLanding commented 20 hours ago

Numba had to deal with the same issue. Instead of making a key that incorporates hostname/pid, they just open a temp file with a name built from a uuid: https://github.com/numba/numba/blob/301ba23116864ef66a2caccfab7c7123e7bcbcfc/numba/core/caching.py#L599-L616
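
A rough sketch of the UUID variant, written from the description above rather than copied from Numba; `put_with_uuid_temp` and the `.tmp.<uuid>` naming are assumptions for illustration.

```python
import os
import tempfile
import uuid


def put_with_uuid_temp(cache_dir: str, key: str, value: bytes) -> str:
    """Write value to cache_dir/key via a temp file with a random, unique name."""
    final_path = os.path.join(cache_dir, key)
    # A fresh UUID per write, so no two writers (or threads) share a temp name.
    tmp_path = f"{final_path}.tmp.{uuid.uuid4().hex}"
    try:
        with open(tmp_path, "wb") as f:
            f.write(value)
        os.replace(tmp_path, final_path)
    except BaseException:
        # Do not leave a partial temp file behind if anything failed.
        try:
            os.unlink(tmp_path)
        except FileNotFoundError:
            pass
        raise
    return final_path


with tempfile.TemporaryDirectory() as d:
    put_with_uuid_temp(d, "jit_example-abc123", b"first")
    final = put_with_uuid_temp(d, "jit_example-abc123", b"second")
    with open(final, "rb") as f:
        content = f.read()
    leftovers = [name for name in os.listdir(d) if ".tmp." in name]
```

Unlike a hostname/pid suffix, the name needs no knowledge of where the writer runs. On Windows the final `os.replace` may still fail if another process has the destination file open, so callers would probably want to treat that as a cache miss rather than an error.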