pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] `EnvBase.step_and_maybe_reset(td)` modifies the ('next','observation') data too on partial reset with `NonTensorStack` #2257

Closed jkrude closed 1 week ago

jkrude commented 1 week ago

Describe the bug

For a custom environment with NonTensorData, calling tensordict, tensordict_ = step_and_maybe_reset(tensordict) changes both the (next, observation) entry of the input tensordict (unexpected) and the observation entry of tensordict_, which has been partially reset (expected).

To Reproduce

This environment is hard-coded for batch_size = (2,). The observation space is just a string for simplicity. _step always returns ["B", "Z"] as the next observation, with the first batch entry in a done state but not the second. _reset always returns ["A", "C"] as the initial observation after a reset. (The action is ignored and only included to comply with the spec.)

from typing import Optional

from torchrl.data import CompositeSpec, NonTensorSpec, BinaryDiscreteTensorSpec
from torchrl.envs import EnvBase
from tensordict import TensorDictBase, TensorDict, NonTensorData, NonTensorStack
import torch

class CustomEnv(EnvBase):
    # Custom environment

    def __init__(
        self,
        *,
        device=None,
        batch_size: Optional[torch.Size] = torch.Size([2]),
        run_type_checks: bool = False,
        allow_done_after_reset: bool = False,
    ):
        assert batch_size == (2,)  # hardcoded for minimal example
        super().__init__(
            device=device,
            batch_size=batch_size,
            run_type_checks=run_type_checks,
            allow_done_after_reset=allow_done_after_reset,
        )

        self.observation_spec = CompositeSpec(
            observation=NonTensorSpec(shape=batch_size), shape=batch_size
        )
        self.action_spec = NonTensorSpec(shape=batch_size)
        self.reward_spec: BinaryDiscreteTensorSpec = BinaryDiscreteTensorSpec(
            n=1, dtype=torch.int8, shape=torch.Size([2, 1])
        )

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        done = torch.tensor([True, False], dtype=torch.bool)
        next_observation = NonTensorStack(
            NonTensorData("B"), NonTensorData("Z"), batch_size=(2,)
        )
        return TensorDict(
            {"observation": next_observation, "done": done, "reward": torch.ones((2,))},
            batch_size=(2,),
        )

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        return TensorDict(
            {
                "observation": NonTensorStack(
                    NonTensorData("A"), NonTensorData("C"), batch_size=(2,)
                )
            },
            batch_size=(2,),
        )

    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
        action = NonTensorStack(NonTensorData("+"), NonTensorData("+"), batch_size=(2,))
        if tensordict is None:
            tensordict = TensorDict({}, batch_size=self.batch_size)
        tensordict["action"] = action
        return tensordict

    def _set_seed(self, seed: Optional[int]):
        pass

env = CustomEnv()
td = env.reset()
env.rand_action(td)
out_td, reset_td = env.step_and_maybe_reset(td)
assert out_td is td
assert torch.equal(td["next", "done"], torch.tensor([[True], [False]]))
observation = "observation"
next_observation = ("next", observation)

print(f"{td[next_observation]=}")
print(f"{reset_td[observation]=}")

Output:

td[next_observation]=['A', 'Z']
reset_td[observation]=['A', 'Z']

Expected behavior

After taking one step with out_td, reset_td = env.step_and_maybe_reset(td), we expect the reset not to alter td, in particular td["next","observation"], and we expect reset_td to have its observation reset in the first batch entry but not the second. Specifically, we expect td["next","observation"] = ["B","Z"] and reset_td["observation"] = ["A","Z"].

However, td["next","observation"] and reset_td["observation"] are both ["A", "Z"].
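
The expected semantics can be illustrated without torchrl. The following is a minimal pure-Python sketch, not the library's implementation: `partial_reset` and `partial_reset_inplace` are hypothetical helpers, and plain lists stand in for `NonTensorStack`. The first helper builds the reset observation as a new list, so the next observation is left alone. The second writes into the container it was given, which shows how sharing one container between both entries would yield ['A', 'Z'] in both places. That sharing is an assumption about the cause, not a confirmed diagnosis.

```python
from typing import List, Sequence


def partial_reset(
    next_obs: Sequence[str], reset_obs: Sequence[str], done: Sequence[bool]
) -> List[str]:
    """Return a new list: the reset value where done, the next value otherwise."""
    return [r if d else n for n, r, d in zip(next_obs, reset_obs, done)]


def partial_reset_inplace(
    next_obs: List[str], reset_obs: Sequence[str], done: Sequence[bool]
) -> List[str]:
    """Aliasing variant: overwrites next_obs and returns the same object."""
    for i, d in enumerate(done):
        if d:
            next_obs[i] = reset_obs[i]
    return next_obs


next_obs = ["B", "Z"]   # what _step returns in the report
reset_obs = ["A", "C"]  # what _reset returns in the report
done = [True, False]    # only the first batch entry is done

# Expected behavior: the next observation is untouched.
expected_reset = partial_reset(next_obs, reset_obs, done)
print(f"{next_obs=} {expected_reset=}")

# Aliasing behavior: both names end up pointing at the same modified list.
aliased_next = ["B", "Z"]
aliased_reset = partial_reset_inplace(aliased_next, reset_obs, done)
print(f"{aliased_next=} {aliased_reset=}")
```

In the non-mutating version the two observations are distinct objects, so a later write to the reset data cannot leak back into the ("next", "observation") entry.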

System info

The libraries were installed via pip requirements. We use the nightly releases.

tensordict-nightly>=2024.6.19
torch >= 2.4.0.dev
torchrl-nightly>=2024.6.23

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

2024.6.23 2.0.0 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] linux

Additional context

The problem occurs only for partial resets (not all batch entries are done) and is likely related to pytorch/tensordict#837. Interestingly, with the latest releases (0.4.0 1.26.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux) I get a different wrong result:

td[next_observation]=['B', 'Z']
reset_td[observation]=['B', 'Z']
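
The two reported results fail in different ways, which a small check makes explicit. This is a pure-Python sketch using only the values quoted in this report; `classify_result` is a hypothetical helper, not part of torchrl or tensordict.

```python
from typing import List, Sequence


def classify_result(
    next_seen: Sequence[str],
    reset_seen: Sequence[str],
    next_obs: Sequence[str],
    reset_obs: Sequence[str],
    done: Sequence[bool],
) -> List[str]:
    """List which of the two expectations from the report are violated."""
    expected_reset = [r if d else n for n, r, d in zip(next_obs, reset_obs, done)]
    problems = []
    if list(next_seen) != list(next_obs):
        problems.append("next observation was modified")
    if list(reset_seen) != expected_reset:
        problems.append("reset observation does not match the partial reset")
    return problems


step_obs, init_obs, done = ["B", "Z"], ["A", "C"], [True, False]

# Result reported with the nightly builds.
nightly = classify_result(["A", "Z"], ["A", "Z"], step_obs, init_obs, done)
# Result reported with the latest releases.
release = classify_result(["B", "Z"], ["B", "Z"], step_obs, init_obs, done)

print(f"{nightly=}")
print(f"{release=}")
```

With the nightly builds the reset data is correct but the next observation is overwritten. With the releases the next observation is intact but the done entry is never reset.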


vmoens commented 1 week ago

On it! Thanks for reporting