pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Recurrent DQN example is broken #1381

Closed smorad closed 1 year ago

smorad commented 1 year ago

Describe the bug

The tutorial at https://github.com/pytorch/rl/blob/main/tutorials/sphinx-tutorials/dqn_with_rnn.py produces an error when run.

To Reproduce

python tutorials/sphinx-tutorials/dqn_with_rnn.py
/Users/smorad/miniforge3/envs/explore/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
in_keys ['embed', 'recurrent_state_h', 'recurrent_state_c', 'is_init']
out_keys ['embed', ('next', 'recurrent_state_h'), ('next', 'recurrent_state_c')]
TransformedEnv(
    env=GymEnv(env=CartPole-v1, batch_size=torch.Size([]), device=cpu),
    transform=Compose(
            ToTensorImage(keys=['pixels']),
            GrayScale(keys=['pixels']),
            Resize(w=84, h=84, interpolation=InterpolationMode.BILINEAR, keys=['pixels']),
            StepCounter(keys=[]),
            InitTracker(keys=[]),
            RewardScaling(loc=0.0000, scale=0.1000, keys=['reward']),
            ObservationNorm(keys=['pixels']),
            TensorDictPrimer(primers={('recurrent_state_h',): UnboundedContinuousTensorSpec(
                shape=torch.Size([1, 128]),
                space=None,
                device=cpu,
                dtype=torch.float32,
                domain=continuous), ('recurrent_state_c',): UnboundedContinuousTensorSpec(
                shape=torch.Size([1, 128]),
                space=None,
                device=cpu,
                dtype=torch.float32,
                domain=continuous)}, default_value=0.0, random=False)))
  0%|          | 0/1000000 [00:00<?, ?it/s]
Let us print the first batch of data.
Pay attention to the key names which will reflect what can be found in this data structure, in particular: the output of the QValueModule (action_values, action and chosen_action_value),the 'is_init' key that will tell us if a step is initial or not, and the recurrent_state keys.
 TensorDict(
    fields={
        action: Tensor(shape=torch.Size([50, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([50, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([50]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([50]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        embed: Tensor(shape=torch.Size([50, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([50, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
                recurrent_state_c: Tensor(shape=torch.Size([50, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
                recurrent_state_h: Tensor(shape=torch.Size([50, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([50]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([50, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
        recurrent_state_c: Tensor(shape=torch.Size([50, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        recurrent_state_h: Tensor(shape=torch.Size([50, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([50]),
    device=cpu,
    is_shared=False)
  0%|          | 50/1000000 [00:00<2:11:58, 126.28it/s]
Traceback (most recent call last):
  File "/Users/smorad/code/explore/torchrl_src/tutorials/sphinx-tutorials/dqn_with_rnn.py", line 378, in <module>
    loss_vals = loss_fn(s)
                ^^^^^^^^^^
  File "/Users/smorad/miniforge3/envs/explore/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/nn/common.py", line 282, in wrapper
    return func(_self, tensordict, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/torchrl_src/torchrl/objectives/dqn.py", line 306, in forward
    target_value = self.value_estimator.value_estimate(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/torchrl_src/torchrl/objectives/value/advantages.py", line 563, in value_estimate
    next_value = self._next_value(tensordict, target_params, kwargs=kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/torchrl_src/torchrl/objectives/value/advantages.py", line 373, in _next_value
    self.value_network(step_td, **kwargs)
  File "/Users/smorad/miniforge3/envs/explore/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/nn/functional_modules.py", line 551, in new_fun
    old_params = _assign_params(
                 ^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/nn/functional_modules.py", line 628, in _assign_params
    return _swap_state(module, params, make_stateless, return_old_tensordict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/nn/functional_modules.py", line 368, in _swap_state
    _old_value = _swap_state(
                 ^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/nn/functional_modules.py", line 368, in _swap_state
    _old_value = _swap_state(
                 ^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/nn/functional_modules.py", line 368, in _swap_state
    _old_value = _swap_state(
                 ^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/nn/functional_modules.py", line 392, in _swap_state
    setattr(model, key, value)
  File "/Users/smorad/miniforge3/envs/explore/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 146, in __setattr__
    super().__setattr__(attr, value)
  File "/Users/smorad/miniforge3/envs/explore/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1635, in __setattr__
    raise TypeError("cannot assign '{}' as parameter '{}' "
TypeError: cannot assign 'torch.FloatTensor' as parameter 'weight_ih_l0' (torch.nn.Parameter or None expected)

Reason and Possible fixes

It might be due to how the torch RNN classes do strange things with their weights for efficiency.

vmoens commented 1 year ago

On it!

vmoens commented 1 year ago

I see where that comes from. When we use the functional RNN, we pass a set of target params that are plain tensors (not nn.Parameters), but the RNN isn't happy to receive a regular tensor instead of a parameter. The solution will be to use nn.Parameters even for detached params, I guess, though it's going to be a bit painful.
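
The behaviour described here can be reproduced without TorchRL or tensordict. The following is only a minimal sketch, not the fix that was applied in the library: the layer sizes and variable names are made up for illustration. It shows that assigning a detached tensor over `weight_ih_l0` raises the same `TypeError` as in the traceback, while wrapping the same tensor in an `nn.Parameter` with `requires_grad=False` is accepted.

```python
import torch
from torch import nn

# Small LSTM with arbitrary, illustrative sizes.
lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

# A detached copy of one weight, similar to what a target network might hold.
detached = lstm.weight_ih_l0.detach().clone()

# Assigning a plain tensor over a registered parameter is rejected.
try:
    lstm.weight_ih_l0 = detached
    plain_tensor_error = None
except TypeError as err:
    plain_tensor_error = str(err)

# Wrapping the tensor in nn.Parameter (without gradients) is accepted.
lstm.weight_ih_l0 = nn.Parameter(detached, requires_grad=False)

# The module still runs a forward pass afterwards.
x = torch.zeros(2, 3, 4)  # (batch, time, features)
out, (h, c) = lstm(x)
```

Here `plain_tensor_error` holds the message from the rejected assignment, and `lstm.weight_ih_l0.requires_grad` is `False` after the second assignment, so the wrapped weight stays out of gradient computation.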

smorad commented 1 year ago

An alternative might be to use an LSTMCell, but it would be slower.
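
For reference, unrolling an `nn.LSTMCell` over the time dimension could look like the sketch below. The sizes and names are illustrative assumptions, and this thread does not confirm whether this route avoids the error. The explicit Python loop over time steps is what would make it slower than the fused `nn.LSTM`.

```python
import torch
from torch import nn

# Illustrative sizes only; not taken from the tutorial.
cell = nn.LSTMCell(input_size=4, hidden_size=8)

x = torch.zeros(2, 3, 4)  # (batch, time, features)
h = torch.zeros(2, 8)     # initial hidden state
c = torch.zeros(2, 8)     # initial cell state

outputs = []
for t in range(x.shape[1]):
    # One recurrent step per time index.
    h, c = cell(x[:, t], (h, c))
    outputs.append(h)

# Stack the per-step hidden states back into (batch, time, hidden).
out = torch.stack(outputs, dim=1)
```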