Open cool-RR opened 1 month ago
Can you try jax 0.4.31, which was just released?
I upgraded to 0.4.31 and now I'm getting different warnings:
C:\Users\Administrator\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:663: UserWarning: Error reading persistent compilation cache entry for 'jit_run_evaluation_rollout': PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_run_evaluation_rollout-ae9e0a0f7ccd64b22e7749a155df1b10cd93374b9af8ff3154d62baafe614222-atime'
warnings.warn(
test_polina/test_dumb_on_dumb.py::test_dumb_on_dumb
C:\Users\Administrator\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:663: UserWarning: Error reading persistent compilation cache entry for 'jit_evaluate_population_vs_fixed_policies': PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_evaluate_population_vs_fixed_policies-f84b39d478dddc89b532239150355f5d699747470f65562e377b663be8e93bcc-atime'
warnings.warn(
There were about 10 more, and all the filenames end with -atime; none end with -cache.
I enabled tracebacks and got this:
polina\cling.py:70: in train
for population, evaluation, outer_analytica_rows, outer_fuzzy_rows in iterator:
polina\pola_algorithm.py:149: in train
yield population, self.env.evaluate(seed, population, i_epoch=0), (), ()
polina\enving\evaluating_enving.py:98: in evaluate
step_aux = self.run_evaluation_rollout(seed, population)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:332: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:190: in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:2739: in bind
return self.bind_with_trace(top_trace, args, params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:433: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:939: in process_primitive
return primitive.impl(*tracers, **params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1730: in _pjit_call_impl
return xc._xla.pjit(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1712: in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1642: in _pjit_call_impl_python
).compile(compile_options)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2295: in compile
executable = UnloadedMeshExecutable.from_hlo(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2807: in from_hlo
xla_executable = _cached_compilation(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2621: in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:345: in compile_or_get_cached
retrieved_executable, retrieved_compile_time = _cache_read(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:658: in _cache_read
return compilation_cache.get_executable_and_time(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compilation_cache.py:215: in get_executable_and_time
executable_and_time = cache.get(cache_key)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\lru_cache.py:112: in get
atime_path.write_bytes(timestamp)
..\..\..\.venvs\polina_env\Lib\site-packages\etils\epath\abstract_path.py:187: in write_bytes
with self.open('wb') as f:
..\..\..\.venvs\polina_env\Lib\site-packages\etils\epath\gpath.py:255: in open
gfile = self._backend.open(self._path_str, mode)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <etils.epath.backend._OsPathBackend object at 0x0000016BA4805220>
path = 'J:\\jaxxy\\jit_run_evaluation_rollout-ff90452405b81e57fa45c6153d6b90b6fd7afae25904fd1330a052f35df7f531-atime'
mode = 'wb'
def open(
self,
path: PathLike,
mode: str,
) -> typing.IO[Union[str, bytes]]:
if 'b' in mode:
encoding = None
else:
encoding = 'utf-8'
> return open(path, mode, encoding=encoding)
E PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_run_evaluation_rollout-ff90452405b81e57fa45c6153d6b90b6fd7afae25904fd1330a052f35df7f531-atime'
Hi, I hit a similar problem and I think I found the cause and a solution.
I think the key thing is that the temporary file the cache writes to in normal filesystems does not have a per-process unique name:
https://github.com/google/jax/blob/fed7efd73003988282744b2f649df493b808c781/jax/_src/gfile_cache.py#L37-L55
All processes will attempt to open the same tmp_path = self._path / f"_temp_{key}", and any but the first will error.
A possible fix is suffixing the temporary file with the hostname of the machine and process id.
I include the hostname because the use case that brought me here is sharing one cache among processes that may be on the same node but in different MPI groups, or on different nodes entirely. In the latter case the process id alone is not unique, so it is also important to differentiate by hostname.
On NFS, a more complex and possibly unnecessary solution is NFS locking, as described here (see D) or as done in flufl.lock.
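A minimal sketch of the hostname-plus-pid suffix, under assumptions: the helper names `unique_tmp_path` and `put` are hypothetical, and this is not the JAX implementation. Each process writes to its own temporary file and then renames it into place, so two writers never open the same temporary path.

```python
import os
import socket
from pathlib import Path


def unique_tmp_path(cache_dir: Path, key: str) -> Path:
    """Return a temp path that differs per host and per process (hypothetical helper)."""
    host = socket.gethostname()
    pid = os.getpid()
    return cache_dir / f"_temp_{key}.{host}.{pid}"


def put(cache_dir: Path, key: str, value: bytes) -> None:
    """Write to a private temp file, then rename it to the final cache entry."""
    tmp_path = unique_tmp_path(cache_dir, key)
    tmp_path.write_bytes(value)
    # os.replace overwrites an existing destination; the last writer wins.
    os.replace(tmp_path, cache_dir / key)
```

On Windows the final rename can still raise PermissionError if another process has the destination open, so a caller may want to treat that as a cache miss rather than a failure.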
@LunarLanding The file you mentioned has been removed in recent versions of JAX. Can you try again with the latest version?
@ayaka14732 I see, I mistakenly navigated to old code. Here’s the code in version 0.4.32, which works on a single machine if you have eviction enabled, but I’m not sure it will work on NFS (since SoftFileLock does not seem to follow all the steps described in the link above regarding NFS; compare with the flufl.lock implementation): https://github.com/google/jax/blob/1594d2f30fdbfebf693aba4a2b264e4a3e52acc6/jax/_src/lru_cache.py#L120-L160
Numba had to deal with the same issue. Instead of making a key that incorporates hostname/pid, they just open a temp file with a name built from a uuid: https://github.com/numba/numba/blob/301ba23116864ef66a2caccfab7c7123e7bcbcfc/numba/core/caching.py#L599-L616
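A minimal sketch of the uuid-named temp file idea, written from scratch rather than copied from Numba; the function name `atomic_write` is hypothetical. The random suffix makes the temporary name unique without needing the hostname or pid.

```python
import os
import uuid
from pathlib import Path


def atomic_write(path: Path, data: bytes) -> None:
    """Write data to a uniquely named temp file, then rename it over path."""
    tmp_path = path.with_name(f"{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        tmp_path.write_bytes(data)
        # The rename is the only step that touches the shared name.
        os.replace(tmp_path, path)
    except OSError:
        # Clean up the private temp file before re-raising (Python 3.8+).
        tmp_path.unlink(missing_ok=True)
        raise
```

Readers then see either the old file or the complete new one, never a partially written entry, at least on local filesystems; behavior on NFS is not guaranteed by this sketch.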
Description
I've got a pytest test suite and I've recently started running it with -n 3, using pytest-xdist, so it'll run on 3 processes in parallel. I sometimes get a warning like this one:
When it does happen, the tests run a lot slower. I'm guessing sometimes the tests attempt to manipulate the cache at the same time and clash with each other, causing the cache to not be used and triggering compilation. I can't reliably reproduce this problem.
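One possible workaround, sketched here as an assumption rather than a confirmed fix, is to give each pytest-xdist worker its own cache directory so the workers never touch the same files. pytest-xdist sets the PYTEST_XDIST_WORKER environment variable (for example gw0) in each worker; the helper name `per_worker_cache_dir` is hypothetical.

```python
import os
from pathlib import Path


def per_worker_cache_dir(base: str) -> Path:
    """Return a cache subdirectory private to the current xdist worker."""
    # Falls back to "main" when the tests run without pytest-xdist.
    worker = os.environ.get("PYTEST_XDIST_WORKER", "main")
    path = Path(base) / worker
    path.mkdir(parents=True, exist_ok=True)
    return path


# Possible use in conftest.py, before any jitted function is called:
# import jax
# jax.config.update("jax_compilation_cache_dir", str(per_worker_cache_dir(r"J:\jaxxy")))
```

The trade-off is that workers no longer share compiled entries, so each worker compiles a given function once on its own.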
System info (python version, jaxlib version, accelerator, etc.)