google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

FileNotFoundError when using cache in parallel #22718

Open cool-RR opened 1 month ago

cool-RR commented 1 month ago

Description

I have a pytest test suite that I recently started running with -n 3 via pytest-xdist, so it runs on 3 processes in parallel. I sometimes get a warning like this one:

test_polina/test_golden_runs.py::test_ipd_three_agent_golden_run
  C:\Program Files\Python312\Lib\site-packages\jax\_src\compiler.py:688: UserWarning: Error writing persistent compilation cache entry for 'jit_scan_outer_bigs': FileNotFoundError: [WinError 2] The system cannot find the file specified: 'J:\\jaxxy\\_temp_jit_scan_outer_bigs-43cfe1699669be4ee80585eba899a893fea7eb948b6e707293e0235d7fdbffe0' -> 'J:\\jaxxy\\jit_scan_outer_bigs-43cfe1699669be4ee80585eba899a893fea7eb948b6e707293e0235d7fdbffe0'
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html

When this happens, the tests run a lot slower. My guess is that the worker processes sometimes try to manipulate the same cache entry at the same time and clash with each other, so the cache isn't used and the function gets recompiled. I can't reliably reproduce this problem.
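
The suspected clash can be replayed without JAX. The sketch below is not JAX's actual cache code; it only assumes what the warning shows, namely that a writer stages the entry in a file named `_temp_<key>` and then renames it to `<key>`. The function names `publish_with_shared_temp` and `publish_with_unique_temp` are hypothetical. If two writers share that temp name, the second rename finds the temp file already gone. Staging in a uniquely named temp file is one possible way to avoid that.

```python
import os
import tempfile
from pathlib import Path


def publish_with_shared_temp(cache_dir: Path, key: str, payloads: list[bytes]) -> list[str]:
    """Replay the suspected race in a fixed order: every writer stages its
    data in the same `_temp_<key>` file, then each one renames it to `<key>`."""
    temp_path = cache_dir / f"_temp_{key}"
    final_path = cache_dir / key
    for payload in payloads:
        # All writers stage their bytes before anyone publishes.
        temp_path.write_bytes(payload)
    outcomes = []
    for _ in payloads:
        try:
            os.replace(temp_path, final_path)
            outcomes.append("ok")
        except FileNotFoundError:
            # An earlier writer already moved the shared temp file away.
            outcomes.append("FileNotFoundError")
    return outcomes


def publish_with_unique_temp(cache_dir: Path, key: str, payload: bytes) -> None:
    """Stage in a temp file that no other writer shares, then rename it."""
    fd, temp_name = tempfile.mkstemp(prefix=f"_temp_{key}-", dir=cache_dir)
    with os.fdopen(fd, "wb") as f:
        f.write(payload)
    os.replace(temp_name, cache_dir / key)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        cache_dir = Path(d)
        print(publish_with_shared_temp(cache_dir, "entry", [b"from A", b"from B"]))
        publish_with_unique_temp(cache_dir, "entry2", b"from A")
        publish_with_unique_temp(cache_dir, "entry2", b"from B")
        print((cache_dir / "entry2").read_bytes())
```

With unique temp names, the last writer to rename wins and no writer loses its staged file. On Windows, the rename can still fail if another process has the destination open, so this sketch would not by itself rule out every clash.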

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.12.1 (tags/v3.12.1:2305ca5, Dec  7 2023, 22:03:25) [MSC v.1937 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Turing', release='10', version='10.0.19045', machine='AMD64')

$ nvidia-smi
Mon Jul 29 15:48:53 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 456.71       Driver Version: 456.71       CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GT 710     WDDM  | 00000000:01:00.0 N/A |                  N/A |
| 50%   52C    P0    N/A /  N/A |   1627MiB /  2048MiB |     N/A      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

ayaka14732 commented 1 month ago

Can you try jax 0.4.31, which was just released?

cool-RR commented 1 month ago

I upgraded to 0.4.31 and now I'm getting different warnings:

  C:\Users\Administrator\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:663: UserWarning: Error reading persistent compilation cache entry for 'jit_run_evaluation_rollout': PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_run_evaluation_rollout-ae9e0a0f7ccd64b22e7749a155df1b10cd93374b9af8ff3154d62baafe614222-atime'
    warnings.warn(

test_polina/test_dumb_on_dumb.py::test_dumb_on_dumb
  C:\Users\Administrator\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:663: UserWarning: Error reading persistent compilation cache entry for 'jit_evaluate_population_vs_fixed_policies': PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_evaluate_population_vs_fixed_policies-f84b39d478dddc89b532239150355f5d699747470f65562e377b663be8e93bcc-atime'
    warnings.warn(

There were about 10 more. All of the filenames end with -atime; none end with -cache.

I enabled tracebacks and got this:

polina\cling.py:70: in train
    for population, evaluation, outer_analytica_rows, outer_fuzzy_rows in iterator:
polina\pola_algorithm.py:149: in train
    yield population, self.env.evaluate(seed, population, i_epoch=0), (), ()
polina\enving\evaluating_enving.py:98: in evaluate
    step_aux = self.run_evaluation_rollout(seed, population)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:332: in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:190: in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:2739: in bind
    return self.bind_with_trace(top_trace, args, params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:433: in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:939: in process_primitive
    return primitive.impl(*tracers, **params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1730: in _pjit_call_impl
    return xc._xla.pjit(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1712: in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1642: in _pjit_call_impl_python
    ).compile(compile_options)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2295: in compile
    executable = UnloadedMeshExecutable.from_hlo(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2807: in from_hlo
    xla_executable = _cached_compilation(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2621: in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:345: in compile_or_get_cached
    retrieved_executable, retrieved_compile_time = _cache_read(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:658: in _cache_read
    return compilation_cache.get_executable_and_time(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compilation_cache.py:215: in get_executable_and_time
    executable_and_time = cache.get(cache_key)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\lru_cache.py:112: in get
    atime_path.write_bytes(timestamp)
..\..\..\.venvs\polina_env\Lib\site-packages\etils\epath\abstract_path.py:187: in write_bytes
    with self.open('wb') as f:
..\..\..\.venvs\polina_env\Lib\site-packages\etils\epath\gpath.py:255: in open
    gfile = self._backend.open(self._path_str, mode)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <etils.epath.backend._OsPathBackend object at 0x0000016BA4805220>
path = 'J:\\jaxxy\\jit_run_evaluation_rollout-ff90452405b81e57fa45c6153d6b90b6fd7afae25904fd1330a052f35df7f531-atime'
mode = 'wb'

    def open(
        self,
        path: PathLike,
        mode: str,
    ) -> typing.IO[Union[str, bytes]]:
      if 'b' in mode:
        encoding = None
      else:
        encoding = 'utf-8'
>     return open(path, mode, encoding=encoding)
E     PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_run_evaluation_rollout-ff90452405b81e57fa45c6153d6b90b6fd7afae25904fd1330a052f35df7f531-atime'
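
The traceback shows the cache read failing while it rewrites the entry's -atime file, which is only bookkeeping for eviction. One possible way to tolerate that, sketched below, is to treat the timestamp update as best effort: retry briefly on PermissionError and give up without failing the read. This is not JAX's code. The function name `touch_atime_best_effort`, the retry count, the delay, and the 8-byte timestamp encoding are all assumptions made for illustration.

```python
import time
from pathlib import Path


def touch_atime_best_effort(atime_path: Path, retries: int = 3, delay_s: float = 0.01) -> bool:
    """Write the current time to `atime_path`; return False instead of raising
    if the file stays locked (for example, held open by another process)."""
    # Hypothetical encoding: nanoseconds since the epoch as 8 little-endian bytes.
    timestamp = time.time_ns().to_bytes(8, "little")
    for attempt in range(retries):
        try:
            atime_path.write_bytes(timestamp)
            return True
        except PermissionError:
            # Another process may be writing the same file; wait and retry.
            if attempt + 1 < retries:
                time.sleep(delay_s)
    return False


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        print(touch_atime_best_effort(Path(d) / "entry-atime"))
```

A caller that gets False back could still return the cached executable, since a stale access time only affects which entry is evicted first.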