Just so I understand - you're training in MJX and evaluating the policy in C MuJoCo (presumably via the python bindings) and seeing unstable physics? Is it possible that there's something else different between the training and eval environments, possibly the initial state? Are you hitting some terminating condition that you're ignoring during the eval? What does the video look like leading up to the instability?
Feel free to post a colab.
I am training in MJX and then evaluating the policy in Python, the same as this part of the Barkour colab. As in the Barkour colab, the training and eval environments are identical, including the initial state. I monitor the termination condition when visualizing the policy and it is not terminating; it simply produces NaN control values. Leading up to the instability (generally one frame), the initial state based on the keyframe I'm using is set and looks correct, and then the control values go to NaN.
Just to reiterate, I am using a pipeline that very closely mimics the Barkour colab. I have used this pipeline for many problems and am reasonably certain that it works correctly. In the past, when I had NaNs they occurred during training due to an unstable simulation, or could be resolved with jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH). In this case, that does not solve the issue, as the NaNs only occur when generating the frames for the video. I'm wondering if this happens when jitting the functions for inference.
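For reference, my inference loop is essentially the one from the Barkour colab; a rough sketch with illustrative names (eval_env, make_inference_fn, params and n_steps stand in for my actual setup):

import jax

# jit the environment and policy functions, then roll out a trajectory for the video
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
jit_inference_fn = jax.jit(make_inference_fn(params))

rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]
for _ in range(n_steps):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)  # this is where the NaN controls appear
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)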
I'll see if I can put together a colab to reproduce this issue, but it does involve a fair amount of training time and I may not be able to open source this just yet (hopefully soon, though). Are there any additional checks I could perform or logs I can provide?
OK, if you really think it's happening somewhere in the inference function, that's a bit surprising to me, but the good news is that it's a pretty small surface area to search - really only a few hundred lines of code or so. You can try removing the @jit so you can trace through, or binary-search for the NaN with jax.debug.print - I bet you can find it that way. Let us know!
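For example, a minimal sketch of that kind of check - report_nans and the call sites below are just illustrative, not part of any existing API:

import jax
import jax.numpy as jnp

def report_nans(label, tree):
    # jax.debug.print emits at runtime, so this works inside jit-compiled functions;
    # prints one line per floating-point leaf in the pytree
    leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves:
        leaf = jnp.asarray(leaf)
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            jax.debug.print(label + jax.tree_util.keystr(path) + ": nan={n}",
                            n=jnp.isnan(leaf).any())

Calling something like report_nans('obs', state.obs) and report_nans('ctrl', ctrl) at a few points in the rollout (or inside the inference function with @jit removed) should narrow down where the first NaN shows up.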
Thanks for the suggestion! I'll spend some time tracking down the error and share what I find.
any luck?
@i1Cps,
I still need to investigate this some more, but I can share what I have figured out so far.
First, make sure that your simulation is stable. Simulations with many contacts, unrealistically high control actions, or highly constrained systems (e.g. the equality constraints creating a loop, as mentioned above) can easily become unstable. This was not the case for me.
What produced NaNs:
What did not produce NaNs:
I will work on tracking this down more in the coming weeks, but hopefully this helps!
I have also seen scenarios where a 4090 produces unstable physics and an A100 does not, given the exact same MJX environment and Python version. I have yet to track down why, but it probably has something to do with matmul precision defaults.
Indeed, we find that setting one of
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH)
jax.config.update("jax_enable_x64", True)
helps on RTX devices. jax_enable_x64 kills perf, but usually the first option (matmul precision) is enough to get rid of NaNs and keep decent performance. Closing this issue as a discussion thread.
Hello,
I have been encountering an issue where my training runs error-free and learns well, but NaN control values are generated at inference when collecting a trajectory to make a video of the task.
I am currently using the following lines to improve the precision and debug NaNs:
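Something along these lines (the matmul precision line is the one discussed elsewhere in this thread; the jax_debug_nans flag is the standard JAX option for catching NaNs and is shown here as an assumption):

jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH)
jax.config.update('jax_debug_nans', True)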
The error generated at inference from MuJoCo is:
The error from the inference is the following:
I'm not sure how training could work well and then generate NaNs at inference, since a NaN value during training would have thrown an error. My model does include a decent number of contacts and two equality constraints that create a loop constraint, but the model appears stable in MuJoCo and during training.
I do have a workaround to fix the issue, which is increasing to 64-bit precision:
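That is, the jax_enable_x64 flag also quoted above, applied before any computation:

jax.config.update("jax_enable_x64", True)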
My main concern here is that the training time and required GPU memory increase drastically. Training for 1 million steps went from 1 min 42 s to 3 min 42 s (on an RTX 4090), and the GPU memory to allocate went from ~20 GB to ~46 GB. Excluding some contacts let me reduce this to 2 min 56 s and get back under 24 GB so I could continue using this GPU.
My pipeline mirrors the Barkour training and inference pipeline very closely.
Some model details that may help (also very similar to the Barkour model), with a sketch of the setup below:
training dt = 0.02
model.opt.timestep = 0.005
integrator = Euler (though I did try RK4 and it didn't help)
eulerdamp = disable
iterations = 1
ls_iterations = 5
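For concreteness, roughly how these options would typically be set on the MuJoCo model before putting it on device (illustrative only; the XML path is a placeholder, not my exact loading code):

import mujoco
from mujoco import mjx

mj_model = mujoco.MjModel.from_xml_path('model.xml')  # placeholder path
mj_model.opt.timestep = 0.005
mj_model.opt.integrator = mujoco.mjtIntegrator.mjINT_EULER
mj_model.opt.iterations = 1
mj_model.opt.ls_iterations = 5
mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_EULERDAMP  # eulerdamp = disable
mjx_model = mjx.put_model(mj_model)
# with training dt = 0.02, each control step spans 4 physics steps (n_frames = 4)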
I am using MuJoCo/MJX = 3.1.6 and Brax = 0.9.4 (though I also tried 0.10.5 and saw the same issues).
Is there a reason that I am encountering this behaviour when performing the inference?
Thanks!