pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Documentation of BinaryDiscreteTensorSpec can be confusing #2364

Closed: albertbou92 closed this issue 2 months ago

albertbou92 commented 2 months ago

Describe the bug

The documentation of the class BinaryDiscreteTensorSpec can be confusing, as pointed out in https://github.com/pytorch/rl/discussions/2344#discussioncomment-10205806.

Here is the documentation: https://github.com/pytorch/rl/blob/main/torchrl/data/tensor_specs.py#L3228, which says:

"""A binary discrete tensor spec.

Args:
    n (int): length of the binary vector.
    shape (torch.Size, optional): total shape of the sampled tensors.
        If provided, the last dimension must match n.
    device (str, int or torch.device, optional): device of the tensors.
    dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.

Examples:
    >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
    >>> print(spec.zero())
"""

The n argument

At the moment, since BinaryDiscreteTensorSpec inherits from DiscreteTensorSpec, n controls the number of possible output values: n=1 allows only False values, n=2 allows both True and False values, and n>2 also allows both True and False. This is not very intuitive given the explanation in the documentation. Also, is this the desired behaviour?

Additionally, n has to match the last dimension of the shape, and an error is raised otherwise. Is this necessary?
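
To make the value behaviour described above concrete, here is a minimal sketch in plain Python. It does not call torchrl; it only models the assumption stated in this issue, namely that a value is drawn from n categories and then cast to bool. The function name sample_binary_values is hypothetical.

```python
import random


def sample_binary_values(n, size, seed=0):
    """Draw `size` category indices from range(n) and cast each one to bool.

    This models the behaviour described in the issue; it is not torchrl code.
    """
    rng = random.Random(seed)
    return [bool(rng.randrange(n)) for _ in range(size)]


for n in (1, 2, 5):
    # Collect the distinct boolean values observed for this n.
    print(n, sorted(set(sample_binary_values(n, size=200))))
```

Under this model, only n=1 excludes True; every n>=2 yields both values, so n carries no extra information beyond 2.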

Possible simplification

We could remove the n parameter (fixing it to 2) and simplify the init method signature by allowing the user to define the spec shape only with the parameter shape.
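
A rough sketch of what such a shape-only signature could look like, written in plain Python with nested lists instead of tensors. The class name ShapeOnlyBinarySpec and its methods are hypothetical and only illustrate the proposal; this is not the API of any torchrl class.

```python
import random
from dataclasses import dataclass
from typing import Optional, Tuple


def _nested(shape, make):
    """Build a nested list of the given shape, calling make() for each leaf."""
    if not shape:
        return make()
    return [_nested(shape[1:], make) for _ in range(shape[0])]


@dataclass(frozen=True)
class ShapeOnlyBinarySpec:
    """Hypothetical binary spec whose only size parameter is `shape`."""

    shape: Tuple[int, ...]

    def zero(self):
        # Every entry is False; there is no n to keep in sync with shape.
        return _nested(self.shape, lambda: False)

    def rand(self, seed: Optional[int] = None):
        # Each entry is independently True or False with equal probability.
        rng = random.Random(seed)
        return _nested(self.shape, lambda: rng.random() < 0.5)


spec = ShapeOnlyBinarySpec(shape=(5, 4))
zeros = spec.zero()
print(len(zeros), len(zeros[0]))
```

With the number of values fixed at two, the check that n matches the last dimension of shape has nothing left to validate.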

vmoens commented 2 months ago

Can you check #2366 and let me know if that fixes it?

albertbou92 commented 2 months ago

Yes, the class Binary fixes it.