SumeetBatra closed this issue 9 months ago.
Hi @SumeetBatra, sorry for the very late reply here. Indeed, the first time the code is run (inside a JAX `jit`), the value in `self.init_pose` gets compiled into the function. Subsequent changes to `self.init_pose` won't affect the compiled code. See here for more info.
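Here is a minimal sketch of that behavior. It uses a hypothetical `PoseEnv` class (not brax's API) that stores `init_pose` on `self` and jits its own reset. The array is read once during tracing and baked in as a constant, so reassigning the attribute afterwards does not change the result:

```python
import jax
import jax.numpy as jnp


class PoseEnv:
    """Hypothetical stand-in for a wrapper that keeps init_pose on self."""

    def __init__(self):
        self.init_pose = jnp.array([0.0, 0.0])
        # self is closed over by the bound method, not passed as a traced input
        self.reset = jax.jit(self._reset)

    def _reset(self, offset):
        # self.init_pose is read while tracing and becomes a compile-time constant
        return self.init_pose + offset


env = PoseEnv()
offset = jnp.zeros(2)
first = env.reset(offset)            # traces and compiles with [0., 0.]
env.init_pose = jnp.array([1.0, 2.0])
second = env.reset(offset)           # same input shape/dtype: cached code, still [0., 0.]
```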
To set a specific pose in the reset, consider passing `init_pose` into the `reset` function (i.e. make `init_pose` a static arg). Alternatively, re-jit the `reset` function after changing `init_pose`.
Hi folks,
I am using brax 0.9.4 and jax 0.4.23. I have a wrapper that I'm using to manually set the initial xy position of the agent with the following methods:
Essentially, I replace the default torso xy pose with the desired one by calling `set_initial_state` followed by `env.reset()`. This works the first time I run it, but subsequent calls set the same pose as the first time I called `set_initial_state`.

With the debugger, I am able to step all the way through to this wrapper the first time `env.reset()` is called, but after that I can't step in past `VectorGymWrapper`'s `reset`. My guess is that this has something to do with `VectorGymWrapper` being jitted, but I'm not sure. Any help would be appreciated!
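On the debugger question: plain Python inside a jitted function, including breakpoints, only executes while JAX traces it. Tracing happens on the first call for a given input shape and dtype. Below is a small sketch with a hypothetical `reset` and a global counter. `jax.debug.print` is staged into the compiled code and runs on every call, while `jax.disable_jit()` runs the function eagerly so a debugger can step into it:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def reset(offset):
    global trace_count
    # Plain Python side effect: runs only while JAX traces the function
    trace_count += 1
    # Staged into the compiled code, so it prints on every call
    jax.debug.print("reset called with offset={o}", o=offset)
    return offset + 1.0


for _ in range(3):
    reset(jnp.zeros(2))       # traced once, then the cached executable is reused
traces_under_jit = trace_count

# With jit disabled, the Python body runs on every call (and breakpoints are hit)
with jax.disable_jit():
    for _ in range(2):
        reset(jnp.zeros(2))
```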