Open bekleyis95 opened 1 month ago
The problem comes from Model::random_act, which is called while the current timestep is below random_timesteps. In Isaac Lab environments the action spaces are defined with bounds of [-inf, inf], so Model::random_act tries to sample from a torch uniform distribution over [-inf, inf]. This produces NaN action values, which is why the objects disappear from the scene.
A quick fix for me was to set the random_timesteps parameter to 0, but I don't know what the permanent fix for this issue is.
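One possible direction for a permanent fix, sketched here with NumPy only and not taken from skrl: replace non-finite bounds with a finite fallback before sampling. The helper name `sample_random_actions`, the `fallback_bound` default of 1.0, and the 8-dimensional action space are all assumptions made for illustration. The appropriate range depends on the environment.

```python
import numpy as np

def sample_random_actions(low, high, batch_size, fallback_bound=1.0, rng=None):
    """Sample uniform random actions, substituting finite bounds for +/-inf.

    Hypothetical helper for illustration; not part of skrl.
    """
    rng = np.random.default_rng() if rng is None else rng
    low = np.asarray(low, dtype=np.float64)
    high = np.asarray(high, dtype=np.float64)
    # Replace any non-finite bound with the assumed fallback range
    safe_low = np.where(np.isfinite(low), low, -fallback_bound)
    safe_high = np.where(np.isfinite(high), high, fallback_bound)
    # Bounds of shape (num_actions,) broadcast against the requested batch shape
    return rng.uniform(safe_low, safe_high, size=(batch_size, low.shape[0]))

# Unbounded action space, as reported for the Isaac Lab environments
low = np.full(8, -np.inf)
high = np.full(8, np.inf)
actions = sample_random_actions(low, high, batch_size=4)
```

With this substitution the sampled actions are always finite, at the cost of picking an arbitrary exploration range for spaces that declare no bounds.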
Description
When I run the example scripts under docs/source/examples/isaaclab, torch_ant_ppo.py and jax_ant_ppo.py start training as expected without any issue. For torch_ant_sac.py, the scene loads, but right before training starts the ant objects disappear, which prevents learning. For jax_ant_sac.py, the program crashes with the following error.
Traceback (most recent call last):
  File "/home/deniz.seven/workspace/skrl/docs/source/examples/isaaclab/jax_ant_sac.py", line 114, in <module>
    trainer.train()
  File "/home/deniz.seven/miniconda3/envs/isaaclab/lib/python3.10/site-packages/skrl/trainers/jax/sequential.py", line 81, in train
    self.single_agent_train()
  File "/home/deniz.seven/miniconda3/envs/isaaclab/lib/python3.10/site-packages/skrl/trainers/jax/base.py", line 172, in single_agent_train
    actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
  File "/home/deniz.seven/miniconda3/envs/isaaclab/lib/python3.10/site-packages/skrl/agents/jax/sac/sac.py", line 301, in act
    return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy")
  File "/home/deniz.seven/miniconda3/envs/isaaclab/lib/python3.10/site-packages/flax/linen/module.py", line 701, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/deniz.seven/miniconda3/envs/isaaclab/lib/python3.10/site-packages/flax/linen/module.py", line 1233, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/deniz.seven/miniconda3/envs/isaaclab/lib/python3.10/site-packages/skrl/models/jax/base.py", line 310, in random_act
    actions = np.random.uniform(low=self.action_space.low[0], high=self.action_space.high[0], size=(inputs["states"].shape[0], self.num_actions))
  File "numpy/random/mtrand.pyx", line 1156, in numpy.random.mtrand.RandomState.uniform
OverflowError: Range exceeds valid bounds
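The OverflowError at the end of the traceback can be reproduced without skrl or Isaac Lab. The snippet below is a minimal sketch that assumes the action space bounds are float32 infinities, as described for these environments. The batch shape (4, 8) is arbitrary.

```python
import numpy as np

# Stand-ins for action_space.low[0] and action_space.high[0] of an unbounded space
low = np.float32(-np.inf)
high = np.float32(np.inf)

try:
    np.random.uniform(low=low, high=high, size=(4, 8))
    raised = False
    message = ""
except OverflowError as exc:
    # NumPy rejects the call because high - low is not finite
    raised = True
    message = str(exc)

print(raised, message)
```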
https://github.com/user-attachments/assets/2eedf1a8-7fa3-491a-9384-f146d412b0cc
What skrl version are you using?
1.2.0
What ML framework/library version are you using?
JAX and torch
Additional system information
Python 3.10.14, Linux