pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Calling `to` fails on MultiDiscreteTensorSpec #2203

Closed: Quinticx closed this issue 4 months ago.

Quinticx commented 4 months ago

Describe the bug

Calling `.to()` on a `MultiDiscreteTensorSpec` fails with `TypeError: MultiDiscreteTensorSpec.__init__() got an unexpected keyword argument 'n'`.

To Reproduce

Steps to reproduce the behavior.

from torchrl.data import MultiDiscreteTensorSpec

actions = MultiDiscreteTensorSpec(nvec=[2])
actions.to(dest="cuda:0")  # needs a CUDA-capable machine; raises the TypeError below
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "../python3.10/site-packages/torchrl/data/tensor_specs.py", line 3212, in to
    return self.__class__(
TypeError: MultiDiscreteTensorSpec.__init__() got an unexpected keyword argument 'n'

Expected behavior

The `MultiDiscreteTensorSpec` should be moved to device `cuda:0`.

System info

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.26.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

Reason and Possible fixes

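In `MultiDiscreteTensorSpec.to` (torchrl/data/tensor_specs.py, line 3212 in the traceback above), the spec is rebuilt with `self.__class__(n=...)`, but `MultiDiscreteTensorSpec.__init__` takes `nvec`, not `n`. Passing `nvec` instead fixes the issue: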

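        # Inside MultiDiscreteTensorSpec.to(): rebuild the spec with the
        # keyword the constructor actually accepts (nvec, not n).
        return self.__class__(
            nvec=self.nvec.to(dest),
            shape=None,
            device=dest_device,
            dtype=dest_dtype,
            mask=mask,
        )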
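The failure mechanism can be reproduced without torch or a GPU. The sketch below uses a hypothetical `ToyMultiDiscreteSpec` class (not torchrl's implementation) whose copy methods rebuild the object through `self.__class__(...)`, the same pattern visible in the traceback. The method names `to_buggy` and `to_fixed` are also made up for illustration.

```python
# Hypothetical toy class (NOT torchrl code) mimicking the pattern in the bug:
# a to()-style method that rebuilds the object via its own constructor.
class ToyMultiDiscreteSpec:
    def __init__(self, nvec, device="cpu"):
        self.nvec = list(nvec)
        self.device = device

    def to_buggy(self, dest):
        # Passes `n`, a keyword __init__ does not define -> TypeError.
        return self.__class__(n=self.nvec, device=dest)

    def to_fixed(self, dest):
        # Passes `nvec`, matching the constructor signature.
        return self.__class__(nvec=self.nvec, device=dest)


spec = ToyMultiDiscreteSpec(nvec=[2])

error_message = None
try:
    spec.to_buggy("cuda:0")
except TypeError as err:
    error_message = str(err)
print(error_message)  # message ends with: unexpected keyword argument 'n'

moved = spec.to_fixed("cuda:0")  # the toy only stores the string; no GPU needed
print(moved.nvec, moved.device)  # [2] cuda:0
```

Because the constructor is called by keyword, any mismatch between the keywords used in `to` and the parameters of `__init__` only surfaces at call time, which is why the bug appears as soon as the spec is moved rather than when it is created.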
