maxencefaldor closed this issue 4 months ago
Me too
I do get those as well sometimes. They are XLA/CUDA issues but I think they are usually just warnings. Does the code continue to run after compilation is done?
Otherwise, you can try adding --xla_gpu_autotune_level=0 to the XLA_FLAGS in dreamerv3/jaxagent.py, or try updating JAX to the newest version.
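For reference, here is a minimal standalone sketch of appending that flag to XLA_FLAGS. It is not the code in dreamerv3/jaxagent.py, and the helper name add_xla_flag is hypothetical. It assumes the variable is set before JAX initializes its backend, which is why the usual advice is to set it before importing jax.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Append an XLA flag to the XLA_FLAGS environment variable (hypothetical helper).

    XLA parses XLA_FLAGS when JAX initializes its backend, so call this
    before the first `import jax` in the process.
    """
    existing = os.environ.get("XLA_FLAGS", "").split()
    # Skip the flag if it is already present so repeated calls stay idempotent.
    if flag not in existing:
        existing.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(existing)
    return os.environ["XLA_FLAGS"]


# Autotune level 0 turns off XLA's GPU autotuning.
add_xla_flag("--xla_gpu_autotune_level=0")
# import jax  # Import JAX only after XLA_FLAGS is set.
```

Any flags already present in XLA_FLAGS are preserved; the new one is appended after them.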
Yes, the code continues to run, but I wonder whether it runs slower than it should. I am already using the latest JAX version (0.4.26).
I noticed a decrease in FPS after this update, but I am not sure whether it is related to those errors/warnings.
edit: I didn't mean to close the issue, sorry about that.
The decrease in FPS is probably because the model architecture is different now.
You can save some FLOPs and memory by setting --'(enc|dec).simple.outer' False --'(enc|dec).simple.mults' 1,2,3,4
This gives very similar performance unless the images are more visually complex.
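The '(enc|dec)' part of those flags reads as a regular expression, so a single flag can set the option for both the encoder and the decoder. Here is a minimal sketch of that idea on a flat dictionary of config keys. It assumes keys are matched against the whole pattern with re.fullmatch. The override_config helper, the unrelated.option key, and the None defaults are hypothetical, and DreamerV3's own config code may match patterns differently.

```python
import re


def override_config(config: dict, pattern: str, value) -> dict:
    """Return a copy of a flat config with every key matching `pattern` set to `value`.

    Hypothetical helper: keys whose full name matches the regex are overridden,
    so '(enc|dec).simple.outer' hits both the encoder and decoder entries.
    """
    matched = [key for key in config if re.fullmatch(pattern, key)]
    if not matched:
        raise KeyError(f"No config key matches {pattern!r}")
    updated = dict(config)
    for key in matched:
        updated[key] = value
    return updated


# Illustrative flat config; None stands in for whatever the real defaults are.
config = {
    "enc.simple.outer": None,
    "dec.simple.outer": None,
    "enc.simple.mults": None,
    "dec.simple.mults": None,
    "unrelated.option": None,  # hypothetical key that neither pattern matches
}
config = override_config(config, r"(enc|dec).simple.outer", False)
config = override_config(config, r"(enc|dec).simple.mults", (1, 2, 3, 4))
```

Raising on a pattern that matches nothing catches typos in override flags instead of silently ignoring them.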
Closing this issue because those are XLA/CUDA warnings unrelated to Dreamer itself. I would also be happy if they got fixed.
For anyone interested, changing jax.compute_dtype from bfloat16 to float32 fixed the warnings for me.
However, float32 is slightly slower than bfloat16: throughput drops from 500 to 430 fps/policy on my hardware (about 14% slower).
Did anyone else encounter the same issue after updating to the latest version of DreamerV3?