Open cool-RR opened 1 month ago
Can you try jax 0.4.31 which is just released?
I upgraded to 0.4.31 and now I'm getting different warnings:
C:\Users\Administrator\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:663: UserWarning: Error reading persistent compilation cache entry for 'jit_run_evaluation_rollout': PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_run_evaluation_rollout-ae9e0a0f7ccd64b22e7749a155df1b10cd93374b9af8ff3154d62baafe614222-atime'
warnings.warn(
test_polina/test_dumb_on_dumb.py::test_dumb_on_dumb
C:\Users\Administrator\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:663: UserWarning: Error reading persistent compilation cache entry for 'jit_evaluate_population_vs_fixed_policies': PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_evaluate_population_vs_fixed_policies-f84b39d478dddc89b532239150355f5d699747470f65562e377b663be8e93bcc-atime'
warnings.warn(
There were about 10 more, and all the filenames end with -atime, none end with -cache.
I enabled tracebacks and got this:
polina\cling.py:70: in train
for population, evaluation, outer_analytica_rows, outer_fuzzy_rows in iterator:
polina\pola_algorithm.py:149: in train
yield population, self.env.evaluate(seed, population, i_epoch=0), (), ()
polina\enving\evaluating_enving.py:98: in evaluate
step_aux = self.run_evaluation_rollout(seed, population)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:332: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:190: in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:2739: in bind
return self.bind_with_trace(top_trace, args, params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:433: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\core.py:939: in process_primitive
return primitive.impl(*tracers, **params)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1730: in _pjit_call_impl
return xc._xla.pjit(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1712: in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\pjit.py:1642: in _pjit_call_impl_python
).compile(compile_options)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2295: in compile
executable = UnloadedMeshExecutable.from_hlo(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2807: in from_hlo
xla_executable = _cached_compilation(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\interpreters\pxla.py:2621: in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:345: in compile_or_get_cached
retrieved_executable, retrieved_compile_time = _cache_read(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compiler.py:658: in _cache_read
return compilation_cache.get_executable_and_time(
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\compilation_cache.py:215: in get_executable_and_time
executable_and_time = cache.get(cache_key)
..\..\..\.venvs\polina_env\Lib\site-packages\jax\_src\lru_cache.py:112: in get
atime_path.write_bytes(timestamp)
..\..\..\.venvs\polina_env\Lib\site-packages\etils\epath\abstract_path.py:187: in write_bytes
with self.open('wb') as f:
..\..\..\.venvs\polina_env\Lib\site-packages\etils\epath\gpath.py:255: in open
gfile = self._backend.open(self._path_str, mode)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <etils.epath.backend._OsPathBackend object at 0x0000016BA4805220>
path = 'J:\\jaxxy\\jit_run_evaluation_rollout-ff90452405b81e57fa45c6153d6b90b6fd7afae25904fd1330a052f35df7f531-atime'
mode = 'wb'
    def open(
        self,
        path: PathLike,
        mode: str,
    ) -> typing.IO[Union[str, bytes]]:
      if 'b' in mode:
        encoding = None
      else:
        encoding = 'utf-8'
>     return open(path, mode, encoding=encoding)
E     PermissionError: [Errno 13] Permission denied: 'J:\\jaxxy\\jit_run_evaluation_rollout-ff90452405b81e57fa45c6153d6b90b6fd7afae25904fd1330a052f35df7f531-atime'
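The traceback shows the failure happening while a cache read rewrites the entry's -atime file (atime_path.write_bytes(timestamp)), so a plain read still needs write access to a shared file. As a rough illustration only, and not JAX's actual implementation, a best-effort timestamp update could treat PermissionError as non-fatal. The helper name touch_atime_best_effort and its retries and delay parameters are hypothetical:

```python
import time
from pathlib import Path


def touch_atime_best_effort(atime_path, retries=3, delay=0.01):
    """Try to record a last-access timestamp; return False instead of raising.

    Hypothetical helper for illustration. It is not part of JAX or etils.
    """
    # Store the current time as 8 little-endian bytes.
    timestamp = time.time_ns().to_bytes(8, "little")
    for _ in range(retries):
        try:
            Path(atime_path).write_bytes(timestamp)
            return True
        except PermissionError:
            # Another process may have the file open (this can happen on
            # Windows when several processes touch the same path), so wait
            # briefly and try again.
            time.sleep(delay)
    # Give up quietly: a stale timestamp only affects eviction order,
    # while the cached entry itself is still readable.
    return False
```

The idea is that failing to update the access time should not turn a cache hit into a cache miss.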
Description
I've got a pytest test suite and I've recently started running it with -n 3, using pytest-xdist, so it'll run on 3 processes in parallel. I sometimes get a warning like this one:

When it does happen, the tests run a lot slower. I'm guessing sometimes the tests attempt to manipulate the cache at the same time and clash with each other, causing the cache to not be used and triggering compilation. I can't reliably reproduce this problem.
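One way to test the clash hypothesis is to give each pytest-xdist worker its own cache directory, so that no two processes touch the same files. The sketch below is a possible workaround, not a confirmed fix. It relies on the PYTEST_XDIST_WORKER environment variable that pytest-xdist sets in each worker ("gw0", "gw1", ...). The helper name per_worker_cache_dir is hypothetical, and the commented-out conftest.py lines assume the cache directory is configured through the jax_compilation_cache_dir option:

```python
import os
from pathlib import Path


def per_worker_cache_dir(base_dir, environ=None):
    """Return a cache directory unique to the current pytest-xdist worker.

    Hypothetical helper for illustration; not part of JAX or pytest-xdist.
    """
    environ = os.environ if environ is None else environ
    # pytest-xdist sets PYTEST_XDIST_WORKER in each worker process; it is
    # unset in the controller process and when xdist is not in use.
    worker_id = environ.get("PYTEST_XDIST_WORKER", "main")
    return Path(base_dir) / worker_id


# Possible use in conftest.py (assumption: jax is installed and the cache
# location is set through this config option; the base path is the one
# that appears in the warnings above):
#
#   import jax
#   jax.config.update(
#       "jax_compilation_cache_dir", str(per_worker_cache_dir("J:/jaxxy"))
#   )
```

The trade-off is that each worker compiles every function once on its own, so the first run is slower than with a shared cache. If the warnings disappear with separate directories, that would support the guess that concurrent access is the cause.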
System info (python version, jaxlib version, accelerator, etc.)