Toni-SM / skrl

Modular reinforcement learning library (on PyTorch and JAX) with support for NVIDIA Isaac Gym, Omniverse Isaac Gym and Isaac Lab
https://skrl.readthedocs.io/
MIT License

When using Isaac Gym, I met some problems #227

Open tobottyx opened 6 days ago

tobottyx commented 6 days ago

Description

""" self._distribution = Normal(mean_actions, log_std.exp()) """ skrl.models.torch.gaussian.GaussianMixin 132 lines

    Expected parameter loc (Tensor of shape (250, 12)) of distribution Normal(loc: torch.Size([250, 12]), scale: torch.Size([250, 12])) to satisfy the constraint Real(), but found invalid values:
    tensor([[nan, nan, nan, ..., nan, nan, nan],
            [nan, nan, nan, ..., nan, nan, nan],
            [nan, nan, nan, ..., nan, nan, nan],
            ...,
            [nan, nan, nan, ..., nan, nan, nan],
            [nan, nan, nan, ..., nan, nan, nan],
            [nan, nan, nan, ..., nan, nan, nan]], device='cuda:0',

What skrl version are you using?

1.3.0

What ML framework/library version are you using?

1.13.0+cu117

Additional system information

Python 3.8, Ubuntu 20.04
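For context on why a single uninitialized entry blows up the whole `(250, 12)` `loc` tensor: NaN propagates through every arithmetic operation it touches. A minimal pure-Python sketch of that propagation (illustrative only, no PyTorch needed):

```python
import math

# One uninitialized (NaN) sample poisons every statistic computed from
# the batch, which is why Normal(loc, scale) later rejects the tensor.
samples = [0.5, -0.2, float("nan"), 0.1]

batch_mean = sum(samples) / len(samples)  # NaN propagates through the sum
print(math.isnan(batch_mean))             # True
```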

asdfGuest commented 6 days ago

Are you sure your memory_size is equal to or greater than rollouts? I have met the same problem using skrl when memory_size is smaller than rollouts.

tobottyx commented 5 days ago

> Are you sure your memory_size is equal to or greater than rollouts? I have met the same problem using skrl when memory_size is smaller than rollouts.

I confirmed that my memory_size is larger than my rollouts, and I found that this is caused by the initialization of the memory.

tobottyx commented 5 days ago
    def create_tensor(self,
                      name: str,
                      size: Union[int, Tuple[int], gym.Space, gymnasium.Space],
                      dtype: Optional[torch.dtype] = None,
                      keep_dimensions: bool = True) -> bool:
        """Create a new internal tensor in memory

        The tensor will have a 3-component shape (memory size, number of environments, size).
        The internal representation will use _tensor_<name> as the name of the class property

        :param name: Tensor name (the name has to follow the python PEP 8 style)
        :type name: str
        :param size: Number of elements in the last dimension (effective data size).
                     The product of the elements will be computed for sequences or gym/gymnasium spaces
        :type size: int, tuple or list of integers, gym.Space, or gymnasium.Space
        :param dtype: Data type (torch.dtype) (default: ``None``).
                      If None, the global default torch data type will be used
        :type dtype: torch.dtype or None, optional
        :param keep_dimensions: Whether or not to keep the dimensions defined through the size parameter (default: ``True``)
        :type keep_dimensions: bool, optional

        :raises ValueError: The tensor name exists already but the size or dtype are different

        :return: True if the tensor was created, otherwise False
        :rtype: bool
        """
        # compute data size
        size = self._get_space_size(size, keep_dimensions)
        # check dtype and size if the tensor exists
        if name in self.tensors:
            tensor = self.tensors[name]
            if tensor.size(-1) != size:
                raise ValueError(f"Size of tensor {name} ({size}) doesn't match the existing one ({tensor.size(-1)})")
            if dtype is not None and tensor.dtype != dtype:
                raise ValueError(f"Dtype of tensor {name} ({dtype}) doesn't match the existing one ({tensor.dtype})")
            return False
        # define tensor shape
        tensor_shape = (self.memory_size, self.num_envs, *size) if keep_dimensions else (self.memory_size, self.num_envs, size)
        view_shape = (-1, *size) if keep_dimensions else (-1, size)
        # create tensor (_tensor_<name>) and add it to the internal storage
        setattr(self, f"_tensor_{name}", torch.zeros(tensor_shape, device=self.device, dtype=dtype))
        # update internal variables
        self.tensors[name] = getattr(self, f"_tensor_{name}")
        self.tensors_view[name] = self.tensors[name].view(*view_shape)
        self.tensors_keep_dimensions[name] = keep_dimensions
        # fill the tensors (float tensors) with NaN
        for tensor in self.tensors.values():
            if torch.is_floating_point(tensor):
                tensor.fill_(float("nan"))
        return True

I think the NaN fill value should be changed to 0.0

Toni-SM commented 18 hours ago

Hi @tobottyx

Filling memories with NaN when the tensors are created is done on purpose, by design, to identify the cases where memory has not been configured or used properly. It is better to get an early NaN exception than to have the algorithm learn from invalid or wrong memory data (0.0 or any other initial value) :sweat_smile:.
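The idea is a fail-fast sentinel: any slot that is still NaN when the data is consumed means the memory was never properly written. A minimal pure-Python sketch of that pattern (`create_storage` and `sample` are hypothetical helpers, not skrl API):

```python
import math

def create_storage(size):
    # fill storage with NaN at creation time, as an "uninitialized" marker
    return [float("nan")] * size

def sample(storage):
    # fail fast if any slot was never written
    if any(math.isnan(v) for v in storage):
        raise ValueError("memory contains uninitialized (NaN) entries")
    return storage

storage = create_storage(4)
for i in range(4):       # write every slot before sampling
    storage[i] = float(i)
print(sample(storage))   # [0.0, 1.0, 2.0, 3.0]
```

With a 0.0 fill instead, the partially written case would silently return valid-looking zeros and the agent would learn from them.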

As @asdfGuest mentioned, in the case of PPO (where all the data collected during the rollout is sampled from memory during training), the memory's memory_size should have the same value (no more, no less) as the agent's rollouts.
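For reference, that rule looks like this in a typical skrl PPO setup (a sketch, not a full script: `env`, `device`, and the models dict are placeholders you would already have in your training code):

```python
from skrl.memories.torch import RandomMemory
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG

ROLLOUTS = 16  # illustrative value

# memory_size must equal the agent's "rollouts" (no more, no less),
# so every NaN-initialized slot is overwritten before PPO samples it
memory = RandomMemory(memory_size=ROLLOUTS, num_envs=env.num_envs, device=device)

cfg = PPO_DEFAULT_CONFIG.copy()
cfg["rollouts"] = ROLLOUTS

agent = PPO(models=models,            # placeholder: your policy/value models
            memory=memory,
            cfg=cfg,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=device)
```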