pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] `_complete_done` always sets missing terminated to `False` #2291

Closed jkrude closed 1 month ago

jkrude commented 1 month ago

Describe the bug

For a custom environment that only outputs a "done" key in _step (as in the example below), the automatically added "terminated" key is always False across all dimensions instead of mirroring the done tensor.

This happens regardless of whether the user defines the done spec or it is added automatically in EnvBase._create_done_specs.

To Reproduce

I tried to keep the example to a minimum. The important part is in _step, where we add the "done" key as torch.tensor([True, False]).

from typing import Optional

import torch
from tensordict import TensorDictBase, TensorDict
from torchrl.data import (
    CompositeSpec,
    BinaryDiscreteTensorSpec,
    UnboundedContinuousTensorSpec,
    OneHotDiscreteTensorSpec,
)
from torchrl.envs import EnvBase

class CustomEnv(EnvBase):

    def __init__(
        self,
        *,
        device=None,
        batch_size: Optional[torch.Size] = torch.Size([2]),
        run_type_checks: bool = False,
        allow_done_after_reset: bool = False,
    ):
        assert batch_size == (2,)  # hardcoded for minimal example
        super().__init__(
            device=device,
            batch_size=batch_size,
            run_type_checks=run_type_checks,
            allow_done_after_reset=allow_done_after_reset,
        )

        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(
                shape=torch.Size(batch_size + (1,))
            ),
            shape=batch_size,
        )
        self.action_spec = OneHotDiscreteTensorSpec(n=2, shape=batch_size)
        self.reward_spec: BinaryDiscreteTensorSpec = BinaryDiscreteTensorSpec(
            n=1, dtype=torch.int8, shape=torch.Size([2, 1])
        )

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        done = torch.tensor([True, False], dtype=torch.bool)
        next_observation = torch.randn(self.observation_spec["observation"].shape)
        return TensorDict(
            {"observation": next_observation, "done": done, "reward": torch.ones((2,))},
            batch_size=(2,),
        )

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        return TensorDict(
            {"observation": torch.randn(self.observation_spec["observation"].shape)},
            batch_size=(2,),
        )

    def _set_seed(self, seed: Optional[int]):
        pass

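def test_():
    env = CustomEnv()
    td = env.reset()
    env.rand_action(td)  # writes a random "action" into td in place
    env.step(td)  # writes the ("next", ...) entries into td in place
    assert env.done_keys == ["done", "terminated"]
    assert torch.equal(td[("next", "done")], torch.tensor([[True], [False]]))
    assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))

Running this with pytest fails on the last assertion:
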
terminated_env_bug.py:62 (test_)
def test_():
        env = CustomEnv()
        td = env.reset()
        env.rand_action(td)
        env.step(td)
        assert env.done_keys == ["done", "terminated"]
        assert torch.equal(td[("next", "done")], torch.tensor([[True], [False]]))
>       assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))
E       assert False
E        +  where False = <built-in method equal of type object at 0x7fabbdc64800>(tensor([[False],\n        [False]]), tensor([[ True],\n        [False]]))
E        +    where <built-in method equal of type object at 0x7fabbdc64800> = torch.equal
E        +    and   tensor([[ True],\n        [False]]) = <built-in method tensor of type object at 0x7fabbdc64800>([[True], [False]])
E        +      where <built-in method tensor of type object at 0x7fabbdc64800> = torch.tensor

Expected behavior

The ("next", "terminated") entry should equal the ("next", "done") entry, as documented for EnvBase._complete_done.
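For reference, here is a minimal sketch of the expected outcome, read off the branches of the _complete_done excerpt quoted further down in this issue: a missing "terminated" copies "done" (and it is an error if "truncated" is present while "terminated" is not), a missing "done" becomes "terminated" OR "truncated", and anything else still missing is zero-filled. This is an assumption-laden simplification, not torchrl code: lists of bools stand in for tensors, a set of key names stands in for the done spec, and complete_done_expected is a hypothetical name.

```python
# Hypothetical sketch of the expected completion rules; not torchrl code.
# Lists of bools stand in for tensors, a set of key names for the done spec.
from typing import Dict, List, Set

Flags = List[bool]


def complete_done_expected(spec_keys: Set[str], data: Dict[str, Flags], batch: int) -> Dict[str, Flags]:
    """Return a copy of `data` with every key in `spec_keys` filled in."""
    out = dict(data)  # leave the caller's dict untouched
    present = set(data)
    both = {"done", "terminated"} <= spec_keys
    if both and "done" in present and "terminated" not in present:
        if "truncated" in present:
            raise RuntimeError("Cannot infer terminated when only done and truncated are present.")
        out["terminated"] = list(data["done"])  # terminated mirrors done
    if both and "terminated" in present and "done" not in present:
        truncated = data.get("truncated", [False] * batch)
        # done = terminated | truncated
        out["done"] = [t or tr for t, tr in zip(data["terminated"], truncated)]
    for key in spec_keys:
        out.setdefault(key, [False] * batch)  # nothing to copy from: fill with zeros
    return out


# The case from this issue: only "done" is returned by _step.
print(complete_done_expected({"done", "terminated"}, {"done": [True, False]}, batch=2))
# {'done': [True, False], 'terminated': [True, False]}
```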

System info

Using torchrl-nightly installed with pip. This should also apply to the main branch, since the relevant code is the same.

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2024.7.3 2.0.0 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] linux

Reason and Possible fixes

The problem seems to come from EnvBase._complete_done. The `for key, item in done_spec.items(False, True):` loop at line 1509 iterates over both the "done" and the "terminated" keys, but only "done" has a value in data. For the "done" key (visited first), data.set("terminated", val) at line 1537 correctly copies the values of data["done"] into "terminated". But when the loop then reaches the "terminated" key, val is None (vals was built before the loop and has no "terminated" entry), so the `elif val is None:` branch runs and data["terminated"] is overwritten with zeros (False).

...
            for key, item in done_spec.items(False, True):  # goes over done and terminated (order is important)
                val = vals.get(key, None)  # will be [[True], [False]] for "done" but None for "terminated"
                if (
                    key == "done"
                    and val is not None
                    and "terminated" in done_spec_keys
                    and "terminated" not in data_keys
                ):
                    if "truncated" in data_keys:
                        raise RuntimeError(
                            "Cannot infer the value of terminated when only done and truncated are present."
                        )
                    data.set("terminated", val)
                elif (
                    key == "terminated"
                    and val is not None
                    and "done" in done_spec_keys
                    and "done" not in data_keys
                ):
                    if "truncated" in data_keys:
                        done = val | data.get("truncated")
                        data.set("done", done)
                    else:
                        data.set("done", val)
                elif val is None:
                    # we must keep this here: we only want to fill with 0s if we're sure
                    # done should not be copied to terminated or terminated to done
                    # in this case, just fill with 0s
                    data.set(key, item.zero(leading_dim))  # overrides the "terminated" key with False
        return data
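To make the ordering problem easier to see outside torchrl, here is a minimal sketch that re-creates the control flow of the quoted loop in plain Python. It is a simplification under stated assumptions: lists of bools stand in for tensors, a list of key names (in spec order) stands in for done_spec, and the function complete_done and its fixed flag are hypothetical. It shows that vals and data_keys are snapshots taken before the loop, so the "terminated" value written while visiting "done" is invisible when "terminated" is visited next. The fixed=True path sketches one possible fix (skip the zero fill when the key was already written earlier in the loop); it is not necessarily the change the maintainers made.

```python
# Hypothetical re-creation of the quoted loop; not torchrl code. Lists of bools
# stand in for tensors and a list of key names (in spec order) for done_spec.
from typing import Dict, List

Flags = List[bool]


def complete_done(spec_keys: List[str], data: Dict[str, Flags], batch: int, fixed: bool = False) -> Dict[str, Flags]:
    """Fill done-like entries of `data` in place, mirroring the quoted branches."""
    vals = {k: data[k] for k in spec_keys if k in data}  # snapshot taken before the loop
    data_keys = set(data)  # also a snapshot: never sees keys written inside the loop
    for key in spec_keys:  # "done" is visited before "terminated"
        val = vals.get(key)
        if key == "done" and val is not None and "terminated" in spec_keys and "terminated" not in data_keys:
            if "truncated" in data_keys:
                raise RuntimeError("Cannot infer terminated when only done and truncated are present.")
            data["terminated"] = list(val)  # correct copy of done
        elif key == "terminated" and val is not None and "done" in spec_keys and "done" not in data_keys:
            truncated = data.get("truncated", [False] * batch)
            data["done"] = [t or tr for t, tr in zip(val, truncated)]
        elif val is None:
            if fixed and key in data:
                continue  # possible fix: keep a value written earlier in this loop
            # zero fill; without the check above this also clobbers the copied "terminated"
            data[key] = [False] * batch
    return data


print(complete_done(["done", "terminated"], {"done": [True, False]}, batch=2))
# {'done': [True, False], 'terminated': [False, False]}
print(complete_done(["done", "terminated"], {"done": [True, False]}, batch=2, fixed=True))
# {'done': [True, False], 'terminated': [True, False]}
```

Judging from the same branches, a workaround until a fix is available should be to return "terminated" explicitly from _step alongside "done": then both keys have a value, and neither reaches the `val is None` branch.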

cc @kurtamohler