Describe the bug
For a custom environment that only outputs the "done" key in `_step`, as in the example, the automatically added "terminated" key is always False across all dimensions instead of mirroring the done tensor.
This happens regardless of whether the user defines the done spec or it is added automatically in `EnvBase._create_done_specs`.
To Reproduce
I tried to keep the example to a minimum.
The important part is in `_step`, where we add the "done" key as `torch.tensor([True, False])`.
```
terminated_env_bug.py:62 (test_)
def test_():
    env = CustomEnv()
    td = env.reset()
    env.rand_action(td)
    env.step(td)
    assert env.done_keys == ["done", "terminated"]
    assert torch.equal(td[("next", "done")], torch.tensor([[True], [False]]))
>   assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))
E assert False
E + where False = <built-in method equal of type object at 0x7fabbdc64800>(tensor([[False],\n [False]]), tensor([[ True],\n [False]]))
E + where <built-in method equal of type object at 0x7fabbdc64800> = torch.equal
E + and tensor([[ True],\n [False]]) = <built-in method tensor of type object at 0x7fabbdc64800>([[True], [False]])
E + where <built-in method tensor of type object at 0x7fabbdc64800> = torch.tensor
```
Expected behavior
The ("next", "terminated") entry should equal the ("next", "done") entry, as documented for `EnvBase._complete_done`.
System info
Using torchrl-nightly installed with pip. This should also apply to the main branch, since the relevant code is the same.
Reason and Possible fixes
The problem seems to come from `EnvBase._complete_done`.
The `for key, item in done_spec.items(False, True):` loop in line 1509 iterates over both the "done" and the "terminated" keys, but only "done" has a value in `data`.
For the "done" key (visited first), `data.set("terminated", val)` in line 1537 correctly sets "terminated" to the values of `data["done"]`.
But then, for the "terminated" key, the `elif val is None:` branch is triggered and `data["terminated"]` is overwritten again.
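To make the ordering problem easier to see, here is a torch-free sketch of that loop. It is my own model, not the torchrl source: nested lists stand in for done tensors of shape (batch_size, 1), the hypothetical helper `zero_done` stands in for `item.zero(leading_dim)`, the hypothetical `complete_done_current` mirrors the branches of the loop, and the "truncated" handling in the "terminated" branch is left out. It assumes that `data_keys` and `vals` are snapshots of `data` taken before the loop starts.

```python
# Torch-free model of the loop in EnvBase._complete_done (a sketch, not the
# torchrl source). Nested lists stand in for done tensors of shape (batch_size, 1).


def zero_done(batch_size):
    """Hypothetical stand-in for `item.zero(leading_dim)`: all False."""
    return [[False] for _ in range(batch_size)]


def complete_done_current(data, done_spec_keys, batch_size):
    """Hypothetical model of the current control flow."""
    # Assumption: both snapshots are taken before the loop starts.
    data_keys = set(data)
    vals = {key: data.get(key) for key in done_spec_keys}
    for key in done_spec_keys:  # "done" is visited before "terminated"
        val = vals.get(key)
        if (
            key == "done"
            and val is not None
            and "terminated" in done_spec_keys
            and "terminated" not in data_keys
        ):
            if "truncated" in data_keys:
                raise RuntimeError(
                    "Cannot infer the value of terminated when only done and truncated are present."
                )
            data["terminated"] = val  # correct copy of "done"
        elif (
            key == "terminated"
            and val is not None
            and "done" in done_spec_keys
            and "done" not in data_keys
        ):
            data["done"] = val  # the "truncated" case is omitted in this sketch
        elif val is None:
            # A key missing from the snapshot lands here even if an earlier
            # iteration just filled it, so the copy of "done" is zeroed.
            data[key] = zero_done(batch_size)
    return data


result = complete_done_current({"done": [[True], [False]]}, ["done", "terminated"], batch_size=2)
print(result["terminated"])  # [[False], [False]], not [[True], [False]]
```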
```python
...
for key, item in done_spec.items(False, True):  # goes over done and terminated (order is important)
    val = vals.get(key, None)  # will be [[True], [False]] for "done" but None for "terminated"
    if (
        key == "done"
        and val is not None
        and "terminated" in done_spec_keys
        and "terminated" not in data_keys
    ):
        if "truncated" in data_keys:
            raise RuntimeError(
                "Cannot infer the value of terminated when only done and truncated are present."
            )
        data.set("terminated", val)
    elif (
        key == "terminated"
        and val is not None
        and "done" in done_spec_keys
        and "done" not in data_keys
    ):
        if "truncated" in data_keys:
            done = val | data.get("truncated")
            data.set("done", done)
        else:
            data.set("done", val)
    elif val is None:
        # we must keep this here: we only want to fill with 0s if we're sure
        # done should not be copied to terminated or terminated to done
        # in this case, just fill with 0s
        data.set(key, item.zero(leading_dim))  # overrides the "terminated" key with False
return data
```
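One possible fix (a suggestion, not a tested torchrl patch) is to skip the zero-fill for keys that an earlier iteration of the loop has already derived from another key. The sketch below models this without torch: nested lists stand in for done tensors, the hypothetical helper `zero_done` stands in for `item.zero(leading_dim)`, the hypothetical `complete_done_patched` and its `filled` set are my own names, and the "truncated" handling in the "terminated" branch is omitted.

```python
# Torch-free model of a possible fix for EnvBase._complete_done (a sketch,
# not a tested torchrl patch). Nested lists stand in for done tensors.


def zero_done(batch_size):
    """Hypothetical stand-in for `item.zero(leading_dim)`: all False."""
    return [[False] for _ in range(batch_size)]


def complete_done_patched(data, done_spec_keys, batch_size):
    """Hypothetical loop that never zero-fills a key it has already derived."""
    data_keys = set(data)
    vals = {key: data.get(key) for key in done_spec_keys}
    filled = set()  # keys written by an earlier iteration of this loop
    for key in done_spec_keys:
        val = vals.get(key)
        if (
            key == "done"
            and val is not None
            and "terminated" in done_spec_keys
            and "terminated" not in data_keys
        ):
            if "truncated" in data_keys:
                raise RuntimeError(
                    "Cannot infer the value of terminated when only done and truncated are present."
                )
            data["terminated"] = val
            filled.add("terminated")
        elif (
            key == "terminated"
            and val is not None
            and "done" in done_spec_keys
            and "done" not in data_keys
        ):
            data["done"] = val  # the "truncated" case is omitted in this sketch
            filled.add("done")
        elif val is None and key not in filled:
            # Skip keys that an earlier iteration already derived from another key.
            data[key] = zero_done(batch_size)
    return data


result = complete_done_patched({"done": [[True], [False]]}, ["done", "terminated"], batch_size=2)
print(result["terminated"])  # [[True], [False]]
```

In this model, re-checking the live keys of `data` in the last branch (`key not in data`) instead of the pre-loop snapshot would have the same effect.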
cc @kurtamohler