google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

Any way to convert a DynamicJaxprTracer variable back into a NumPy array or print out its values? #1373

Closed: AlexS28 closed this issue 9 months ago

AlexS28 commented 9 months ago

Hi,

I'm a student using MuJoCo (MJX) with Brax to train an RL policy (currently modifying the Barkour tutorial from here: https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb). I'm not sure it's relevant, but I am not using Colab; instead, I am debugging the code in an IDE. I did look at https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html, but those solutions didn't work for me.

My question is as follows:

  1. I am trying to step through the Barkour tutorial code line by line so that I understand what's happening and can adapt it to my own problem and robot. However, debugging has been challenging: once the code is running in JAX (for instance, after a NumPy array has been converted into a jax.numpy array), the variables shown in the debugger and the ones I print become of type DynamicJaxprTracer. When I print one of these variables, I get something like Traced<ShapedArray(float32[465])>with<DynamicJaxprTrace(level=1/0)> instead of the actual values stored in it. Is there any way to print this variable, or to convert it back to a NumPy array (or some other type), so that I can at least see the values inside a DynamicJaxprTracer?
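For context, here is a minimal sketch of why this happens, assuming the code in question runs under jax.jit (as the Barkour training loop does). The function step and its input are hypothetical stand-ins, not taken from the tutorial. A plain print() executes once while JAX traces the function, so it sees an abstract tracer (the exact repr varies between JAX versions). jax.debug.print, described on the JAX debugging page linked above, is compiled into the computation and prints the concrete values each time the function actually runs.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):  # hypothetical stand-in for a jitted environment/policy step
    y = jnp.sin(x) * 2.0
    # Runs only at trace time: prints a tracer such as
    # Traced<ShapedArray(float32[3])>..., not the numbers.
    print("print:", y)
    # Staged into the compiled function: prints the runtime values.
    jax.debug.print("debug.print: {}", y)
    return y

out = step(jnp.arange(3.0))
```

Calling step a second time with the same input shape would skip the plain print() entirely (the traced function is cached), while jax.debug.print fires on every call.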

Thank you very much for any help or clarification on this.

Sincerely,
Alex

btaba commented 9 months ago

Hi @AlexS28, you can use jax.disable_jit (https://jax.readthedocs.io/en/latest/_autosummary/jax.disable_jit.html) for debugging, or you can return early from a function to get its intermediate value. Hope this helps!
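Both suggestions can be sketched as follows. The functions policy_step and policy_step_debug are hypothetical placeholders for whichever jitted function you are stepping through. Inside jax.disable_jit(), jit-decorated functions run eagerly, so intermediates are concrete arrays that print() and an IDE debugger can display, though the code will typically run much slower. Alternatively, returning the intermediate from the jitted function gives you a concrete jax.Array outside the trace, which np.asarray converts to NumPy. Calling np.asarray on a tracer inside a traced function raises an error, so the conversion must happen outside.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def policy_step(obs):  # hypothetical function under inspection
    hidden = jnp.tanh(obs)  # the intermediate we want to look at
    print("hidden:", hidden)  # concrete values only when jit is disabled
    return hidden.sum()

obs = jnp.linspace(-1.0, 1.0, 5)

# Option 1: disable jit so the function body runs eagerly and prints real numbers.
with jax.disable_jit():
    action = policy_step(obs)

# Option 2: temporarily return the intermediate from the jitted function.
@jax.jit
def policy_step_debug(obs):
    hidden = jnp.tanh(obs)
    return hidden.sum(), hidden

action2, hidden = policy_step_debug(obs)
hidden_np = np.asarray(hidden)  # concrete array outside jit, safe to convert
```

Option 2 keeps the code compiled and fast, which can matter when the value you need only appears deep into training.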

AlexS28 commented 9 months ago

Thanks, that works.