pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Shape Mismatch with default_info_dict_reader #1841

Closed: BeFranke closed this issue 7 months ago

BeFranke commented 7 months ago

Describe the bug

When default_info_dict_reader is used to read a single value per step from the info dict, check_env_specs fails with a shape mismatch error.

To Reproduce

The following example uses a safety-gymnasium environment (see https://github.com/PKU-Alignment/safety-gymnasium):

from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv
)
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs.utils import check_env_specs
import safety_gymnasium as gym
from safety_gymnasium.wrappers import SafetyGymnasium2Gymnasium
from gymnasium.wrappers import RecordVideo
from torchrl.envs import default_info_dict_reader

env = SafetyGymnasium2Gymnasium(gym.make('SafetyPointGoal1-v0'))
env = GymWrapper(env)
env = env.set_info_dict_reader(default_info_dict_reader(["cost"]))
env = TransformedEnv(
    env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(
            in_keys=["observation"],
        ),
        StepCounter(),
    ),
)
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

check_env_specs(env)

This leads to the following error:

Traceback (most recent call last):
  File "mre.py", line 33, in <module>
    check_env_specs(env)
  File "<projectdir>/.venv/lib/python3.8/site-packages/torchrl/envs/utils.py", line 465, in check_env_specs
    _per_level_env_check(
  File "<projectdir>/.venv/lib/python3.8/site-packages/torchrl/envs/utils.py", line 397, in _per_level_env_check
    _per_level_env_check(_data0, _data1, check_dtype=check_dtype)
  File "<projectdir>/.venv/lib/python3.8/site-packages/torchrl/envs/utils.py", line 392, in _per_level_env_check
    raise AssertionError(
AssertionError: The shapes of the real and fake tensordict don't match for key cost. Got fake=torch.Size([3, 1]) and real=torch.Size([3]).

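It seems that the single-element values are stacked differently in the two code paths: the fake tensordict (generated from the spec) gives each step's cost a shape of [1], while the real rollout stores it as a scalar, so after stacking three steps the shapes are [3, 1] and [3] respectively.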
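The suspected mismatch can be reproduced with plain torch.stack. This is a minimal sketch that only mirrors the shapes reported in the error; it does not trace torchrl's internals. Stacking three per-step scalars gives shape [3], while stacking three values shaped like a [1] spec gives [3, 1]:

```python
import torch

# "Real" rollout: each step's info dict yields a plain scalar cost (0-d tensor).
real_steps = [torch.tensor(0.0) for _ in range(3)]

# "Fake" rollout: each step's value is drawn from a spec of shape [1].
fake_steps = [torch.zeros(1) for _ in range(3)]

real = torch.stack(real_steps)
fake = torch.stack(fake_steps)

print(real.shape)  # torch.Size([3])
print(fake.shape)  # torch.Size([3, 1])
```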

Expected behavior

The example should run and yield an environment with an additional key "cost" holding one value per step.

System info

Installed via pip install torchrl; Python 3.8; torchrl 0.2.1


vmoens commented 7 months ago

Should be easy to fix! Envs sometimes have trouble dealing with scalar values.
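Until a fix is released, one way to sidestep the scalar handling is to reshape each info value to the per-step spec shape before it is stored. The helper below is hypothetical (it is not part of torchrl) and only illustrates the reshaping, assuming the spec for "cost" has shape [1] as the error message suggests:

```python
import torch


def read_info_keys(info: dict, keys, spec_shape=(1,)):
    """Hypothetical helper: convert selected info-dict entries to float
    tensors whose shape matches the per-step spec (assumed to be (1,))."""
    out = {}
    for key in keys:
        value = torch.as_tensor(info[key], dtype=torch.float32)
        # A Python float becomes a 0-d tensor; reshape it to the spec shape.
        out[key] = value.reshape(spec_shape)
    return out


# Simulate three steps whose info dicts carry a scalar "cost".
infos = [{"cost": 0.0}, {"cost": 1.0}, {"cost": 0.0}]
steps = [read_info_keys(info, ["cost"])["cost"] for info in infos]
cost = torch.stack(steps)
print(cost.shape)  # torch.Size([3, 1])
```

With every step shaped like the spec, the stacked real values have the same [3, 1] shape as the fake ones.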

vmoens commented 7 months ago

By the way, you will probably want to check out the new API for info dicts, which ships in the next release, v0.3 (end of next week):

https://pytorch.org/rl/reference/generated/torchrl.envs.GymLikeEnv.html#torchrl.envs.GymLikeEnv.auto_register_info_dict