schneimo opened this issue 1 year ago.
Tracing can sometimes take a while on older GPUs; I've previously waited around 10 minutes on a TitanX workstation.
You can try making the CNN smaller to see if that speeds up compilation. You can also try increasing the resolution incrementally and checking whether the trace time increases.
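A minimal sketch of the suggested experiment, assuming a hypothetical three-layer CNN (not Dreamer's actual encoder; `make_params`, `cnn` and `time_trace_and_compile` are illustrative names). It times tracing/lowering (`jax.jit(...).lower`) separately from XLA compilation (`.compile()`) at several resolutions:

```python
import time

import jax
import jax.numpy as jnp


def make_params(key, channels=(8, 16, 32), in_ch=3):
    """Hypothetical small CNN: one 3x3 conv kernel per stage (HWIO layout)."""
    params = []
    for out_ch in channels:
        key, sub = jax.random.split(key)
        params.append(jax.random.normal(sub, (3, 3, in_ch, out_ch)) * 0.1)
        in_ch = out_ch
    return params


def cnn(params, x):
    # x: (batch, height, width, channels); each stride-2 stage halves H and W.
    for w in params:
        x = jax.lax.conv_general_dilated(
            x, w, window_strides=(2, 2), padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"))
        x = jax.nn.relu(x)
    return x.mean()


def loss_and_grad(params, x):
    # Differentiate with respect to params, as a training step would.
    return jax.value_and_grad(cnn)(params, x)


def time_trace_and_compile(resolution, batch=4):
    """Return (trace+lower seconds, compile seconds) for one resolution."""
    params = make_params(jax.random.PRNGKey(0))
    x = jnp.zeros((batch, resolution, resolution, 3))
    start = time.perf_counter()
    lowered = jax.jit(loss_and_grad).lower(params, x)  # tracing + lowering
    traced = time.perf_counter() - start
    start = time.perf_counter()
    lowered.compile()  # XLA compilation only
    compiled = time.perf_counter() - start
    return traced, compiled


if __name__ == "__main__":
    for res in (64, 128, 256):
        t, c = time_trace_and_compile(res)
        print(f"{res}x{res}: trace {t:.2f}s, compile {c:.2f}s")
```

Because JAX traces with abstract shapes, trace time mostly depends on how many operations the program contains rather than on array sizes. In this fixed-depth sketch it should stay roughly flat; if trace time grows with resolution in the real model, one likely explanation is that the network itself gets deeper or wider as the input gets larger.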
Thanks.
I am not sure whether time and compute power are really the problem. Even after 24 hours, it had still not finished tracing on an A100. But I will test how the tracing time grows with increasing image resolution and report my findings here.
I looked into this a bit more and found that the `train` function of the `Agent` class runs to completion: when I decorate it with an additional timer, the timer gets executed.
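For reference, a timer decorator of the kind described might look like the sketch below; `timed` and the stand-in `Agent` class are hypothetical, not the DreamerV3 code. The print is only reached if the wrapped function returns normally:

```python
import functools
import time


def timed(fn):
    """Report how long a call to `fn` took, but only if it returns normally."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        # Reached only if fn neither hung nor raised.
        wrapper.last_elapsed = time.perf_counter() - start
        print(f"{fn.__qualname__} returned after {wrapper.last_elapsed:.3f}s")
        return result

    wrapper.last_elapsed = None
    return wrapper


class Agent:
    # Hypothetical stand-in; the real Agent.train lives in the DreamerV3 repo.
    @timed
    def train(self, data):
        return sum(data)


if __name__ == "__main__":
    Agent().train([1, 2, 3])
```

One caveat: if `train` is traced under `jax.jit`, Python side effects such as this print run at trace time. The timer firing therefore shows that the Python body finished tracing, not that lowering, compilation or the compiled step completed, which is consistent with the hang sitting in a wrapper around the traced function.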
Furthermore, I narrowed the problem down a bit more, and it seems to arise in the `try` block of `pure` inside the Ninjax module:
https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/ninjax.py#L60-L101
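To see where a seemingly stuck process actually is, one option (a general Python technique, not something Ninjax or DreamerV3 provide) is the standard `faulthandler` module, which can periodically dump the stack of every thread from a watchdog thread. `start_hang_watchdog` and `stop_hang_watchdog` are hypothetical helper names:

```python
import faulthandler
import sys


def start_hang_watchdog(interval_s=600.0, file=sys.stderr):
    """Dump every thread's Python stack to `file` every `interval_s` seconds.

    The dumps come from a separate watchdog thread, so they still appear
    while the main thread is blocked (e.g. inside jax.jit tracing).
    """
    faulthandler.dump_traceback_later(interval_s, repeat=True, file=file)


def stop_hang_watchdog():
    """Cancel the periodic dumps started by start_hang_watchdog."""
    faulthandler.cancel_dump_traceback_later()
```

Calling `start_hang_watchdog()` before the first training step and comparing successive dumps would show whether the main thread stays in the same frame inside `pure` or keeps moving. If the hang is in native code (for example XLA compilation), the deepest Python frame shown should be the call that entered it.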
Hi Danijar,
I am currently trying to use higher image resolutions such as 256x256 with Dreamer. When I simply change the resolution, e.g. for the DM Control Suite, JAX is no longer able to trace/compile the training function.
However, instead of raising an error, the program seems to get stuck at (or just after) the point where it tries to trace the training function with JAX.
I have tested this on a V100 and an A100, with the same result on both. With smaller resolutions (e.g. 128x128 or 64x64), it works as expected.
I tried to debug this, but I have not been able to track it down inside Ninjax or JAX.
Thanks a lot for your help!