pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Function `make_tensordict_primer` Overlooks Batch-Locked Envs #2323

Closed. ErcBunny closed this issue 1 month ago.

ErcBunny commented 1 month ago

Describe the bug

Appending the TensorDictPrimer transform created by LSTMModule.make_tensordict_primer triggers a dimension error.

To Reproduce

import torchrl
from torchrl.envs import TransformedEnv, InitTracker, Compose
from torchrl.collectors import SyncDataCollector

env = TransformedEnv(
    original_env,  # a batch-locked (vectorized) environment defined elsewhere
    Compose(
        InitTracker(),
        lstm.make_tensordict_primer(),  # lstm: an LSTMModule defined elsewhere
    ),
)

data_collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=10,
    total_frames=100,
)
...
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torchrl/collectors/collectors.py", line 703, in __init__
    policy_output = self.policy(policy_input)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/nn/common.py", line 289, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/nn/utils.py", line 261, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/nn/sequence.py", line 428, in forward
    tensordict = self._run_module(module, tensordict, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/nn/sequence.py", line 409, in _run_module
    tensordict = module(tensordict, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/nn/common.py", line 289, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/nn/utils.py", line 261, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/nn/sequence.py", line 428, in forward
    tensordict = self._run_module(module, tensordict, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/tensordict/nn/sequence.py", line 409, in _run_module
    tensordict = module(tensordict, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torchrl/modules/tensordict_module/rnn.py", line 650, in forward
    val, hidden0, hidden1 = self._lstm(
  File "/home/lyq/mambaforge/envs/rlgpu/lib/python3.8/site-packages/torchrl/modules/tensordict_module/rnn.py", line 701, in _lstm
    _hidden0_in.transpose(-3, -2).contiguous(),
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)

Expected behavior

The code should run without error; alternatively, when the TensorDictPrimer transform is appended, its shapes should be validated against the env, as sketched below.
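
A minimal sketch of what such a check could look like at append time, assuming the transform stores its specs in a primers composite spec and that env.batch_size reports the env's leading batch dims (both hold in recent torchrl versions); lstm and env are the placeholders from the repro above:

primer = lstm.make_tensordict_primer()
for key, spec in primer.primers.items():
    # For a batched env, each primed spec should start with the env's batch dims,
    # e.g. (num_envs, num_layers, hidden_size) rather than (num_layers, hidden_size).
    if spec.shape[: len(env.batch_size)] != env.batch_size:
        raise ValueError(
            f"Primer spec for {key!r} has shape {spec.shape}, "
            f"which lacks the env batch dims {env.batch_size}."
        )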

System info

torchrl==0.4

Additional context

There is a similar issue at https://github.com/pytorch/rl/issues/1493.

Reason and Possible fixes

The function make_tensordict_primer overlooks batch-locked envs, as its source code is:

def make_tensordict_primer(self):
    # ...
    return TensorDictPrimer(
        {
            # the specs carry no leading batch dimensions, so they do not match
            # envs whose batch_size is larger than ()
            in_key1: UnboundedContinuousTensorSpec(
                shape=(self.lstm.num_layers, self.lstm.hidden_size)
            ),
            in_key2: UnboundedContinuousTensorSpec(
                shape=(self.lstm.num_layers, self.lstm.hidden_size)
            ),
        }
    )

In this case, users can manually add the transform with proper shapes; see the sketch below.
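
For example, something along these lines works as a manually built primer for a vectorized env (a minimal sketch, not from the issue: num_envs, num_layers and hidden_size are illustrative values, and the keys must match the recurrent-state in_keys of your LSTMModule, e.g. "rs_h"/"rs_c" or the defaults "recurrent_state_h"/"recurrent_state_c"):

from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import Compose, InitTracker, TensorDictPrimer, TransformedEnv

num_envs, num_layers, hidden_size = 64, 1, 64  # illustrative values

# Unlike make_tensordict_primer's output, these specs carry the env's leading batch dim.
primer = TensorDictPrimer(
    {
        "recurrent_state_h": UnboundedContinuousTensorSpec(
            shape=(num_envs, num_layers, hidden_size)
        ),
        "recurrent_state_c": UnboundedContinuousTensorSpec(
            shape=(num_envs, num_layers, hidden_size)
        ),
    }
)

env = TransformedEnv(
    original_env,  # the vectorized env from the repro above
    Compose(InitTracker(), primer),
)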

vmoens commented 1 month ago

I'm not exactly sure what is going on in your code, but this works fine for me:

from torchrl.collectors import SyncDataCollector
from torchrl.envs import TransformedEnv, InitTracker
from torchrl.envs import GymEnv
from torchrl.modules import MLP, LSTMModule
from torch import nn
from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
assert env.base_env.batch_locked
lstm_module = LSTMModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
mlp = MLP(num_cells=[64], out_features=1)
policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
policy(env.reset())
env = env.append_transform(lstm_module.make_tensordict_primer())
data_collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=10
)
for data in data_collector:
    print(data)
    break

and the env is batch-locked. Can you give a runnable example perhaps?

ErcBunny commented 1 month ago

Thanks for the reply!

Sorry for misusing the word "batch-locked"; I meant vectorized environments with a batch size larger than one.

My env derives from EnvBase and creates Isaac Gym vectorized envs under the hood. I tried to use the wrapper, but it failed to create the envs (#2292), so I decided to follow the tutorial and define a simple wrapper myself.

A rollout of 16 steps and 64 envs looks like this:

import time

import torch
from torchrl.envs.utils import check_env_specs

original_env = MyEnv(  # custom EnvBase wrapper around Isaac Gym vectorized envs
    cfg=cfg_dict["task"],
    device=cfg["sim_device"],
    graphics_device_id=cfg["graphics_device_id"],
    headless=cfg["headless"],
    force_render=cfg["force_render"],
)
# `env` below is the environment built from `original_env`, as in the snippet above
check_env_specs(env)

target_num_steps = 200
t0 = time.time()
with torch.no_grad():
    rollout_data = env.rollout(target_num_steps)
t1 = time.time()
print(rollout_data)
2024-07-26 18:34:09,664 [torchrl][INFO] check_env_specs succeeded!
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([64, 16, 4]), device=cuda:0, dtype=torch.float32, is_shared=True),
        done: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                observation: TensorDict(
                    fields={
                        depth_image: Tensor(shape=torch.Size([64, 16, 1, 192, 256]), device=cuda:0, dtype=torch.float32, is_shared=True),
                        drone_state: Tensor(shape=torch.Size([64, 16, 18]), device=cuda:0, dtype=torch.float32, is_shared=True)},
                    batch_size=torch.Size([64, 16]),
                    device=cuda:0,
                    is_shared=True),
                reward: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                terminated: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                truncated: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
            batch_size=torch.Size([64, 16]),
            device=cuda:0,
            is_shared=True),
        observation: TensorDict(
            fields={
                depth_image: Tensor(shape=torch.Size([64, 16, 1, 192, 256]), device=cuda:0, dtype=torch.float32, is_shared=True),
                drone_state: Tensor(shape=torch.Size([64, 16, 18]), device=cuda:0, dtype=torch.float32, is_shared=True)},
            batch_size=torch.Size([64, 16]),
            device=cuda:0,
            is_shared=True),
        terminated: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        truncated: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
    batch_size=torch.Size([64, 16]),
    device=cuda:0,
    is_shared=True)

I also tried to modify your code with ParallelEnv:

from torchrl.collectors import SyncDataCollector
from torchrl.envs import TransformedEnv, InitTracker
from torchrl.envs import GymEnv
from torchrl.modules import MLP, LSTMModule
import torch
from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
from torchrl.envs import ParallelEnv

def env_make():
    return GymEnv("Pendulum-v1")

env = TransformedEnv(ParallelEnv(3, env_make), InitTracker())
assert env.base_env.batch_locked

lstm_module = LSTMModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
)
mlp = MLP(num_cells=[64], out_features=1)
policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
policy(env.reset())

env = env.append_transform(lstm_module.make_tensordict_primer())

data_collector = SyncDataCollector(env, policy, frames_per_batch=10)
for data in data_collector:
    print(data)
    break

but it gave me errors when executing env.reset():

...
File "/home/lyq/Developer/isaacgym_workspace/scratch/test_td_primer.py", line 25, in <module>
  policy(env.reset())
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py", line 2120, in reset
  tensordict_reset = self._reset(tensordict, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 809, in _reset
  tensordict_reset = self.base_env._reset(tensordict, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
  return func(*args, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 56, in decorated_fun
  self._start_workers()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 1275, in _start_workers
  process.start()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 121, in start
  self._popen = self._Popen(self)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/context.py", line 224, in _Popen
  return _default_context.get_context().Process._Popen(process_obj)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
  return Popen(process_obj)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
  super().__init__(process_obj)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
  self._launch(process_obj)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 42, in _launch
  prep_data = spawn.get_preparation_data(process_obj._name)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/spawn.py", line 154, in get_preparation_data
  _check_not_importing_main()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/spawn.py", line 134, in _check_not_importing_main
  raise RuntimeError('''
RuntimeError: 
      An attempt has been made to start a new process before the
      current process has finished its bootstrapping phase.

      This probably means that you are not using fork to start your
      child processes and you have forgotten to use the proper idiom
      in the main module:

          if __name__ == '__main__':
              freeze_support()
              ...

      The "freeze_support()" line can be omitted if the program
      is not going to be frozen to produce an executable.

Maybe I am not using parallel envs correctly... So I'm afraid I can't give you a running example, but I hope these tests provide more info on the issue.

vmoens commented 1 month ago

Your rollout looks OK to me.

The error you're seeing with ParallelEnv should be solved if you put the code under an if __name__ == "__main__": guard, as in this example where I wrap a SerialEnv (N=3) in a ParallelEnv (N=2):

from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod

from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs import TransformedEnv, InitTracker, ParallelEnv, SerialEnv
from torchrl.modules import MLP, GRUModule

def make_env():
    return TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())

if __name__ == "__main__":

    gru_module = GRUModule(
        input_size=make_env().observation_spec["observation"].shape[-1],
        hidden_size=64,
        in_keys=["observation", "rs"],
        out_keys=["intermediate", ("next", "rs")])
    mlp = MLP(num_cells=[64], out_features=1)
    policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
    primer = gru_module.make_tensordict_primer()
    env = ParallelEnv(2,
                      lambda primer=primer:
                      SerialEnv(3, make_env).append_transform(primer.clone()))
    reset = env.reset()
    print('reset', reset)
    policy(reset)
    print('reset after policy', reset)

    data_collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=10
    )
    for data in data_collector:
        print("data from rollout", data)
        break

Since I can't repro, I'm closing this, but if you still encounter an error feel free to reopen.

ErcBunny commented 1 month ago

Thanks for the example, but I get errors and cannot complete the rollout. Is it because I am using version 0.4?

11:59:40  |mujoco|lyq@xpg scratch → python test_td_primer.py 
/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py:2989: DeprecationWarning: Your wrapper was not given a device. Currently, this value will default to 'cpu'. From v0.5 it will default to `None`. With a device of None, no device casting is performed and the resulting tensordicts are deviceless. Please set your device accordingly.
  warnings.warn(
/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py:2989: DeprecationWarning: Your wrapper was not given a device. Currently, this value will default to 'cpu'. From v0.5 it will default to `None`. With a device of None, no device casting is performed and the resulting tensordicts are deviceless. Please set your device accordingly.
  warnings.warn(
Process _ProcessNoWarn-2:
Traceback (most recent call last):
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/_utils.py", line 668, in run
    return mp.Process.run(self, *args, **kwargs)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 1765, in _run_worker_pipe_shared_mem
    cur_td = env.reset(
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py", line 2120, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 814, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 4723, in _reset
    expand_as_right(_reset, value), value, prev_val
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/tensordict/utils.py", line 331, in expand_as_right
    raise RuntimeError(
RuntimeError: tensor shape is incompatible with dest shape, got: tensor.shape=torch.Size([3]), dest=torch.Size([1, 64])
Process _ProcessNoWarn-1:
Traceback (most recent call last):
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/_utils.py", line 668, in run
    return mp.Process.run(self, *args, **kwargs)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 1765, in _run_worker_pipe_shared_mem
    cur_td = env.reset(
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py", line 2120, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 814, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 4723, in _reset
    expand_as_right(_reset, value), value, prev_val
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/tensordict/utils.py", line 331, in expand_as_right
    raise RuntimeError(
RuntimeError: tensor shape is incompatible with dest shape, got: tensor.shape=torch.Size([3]), dest=torch.Size([1, 64])
reset TensorDict(
    fields={
        done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        rs: Tensor(shape=torch.Size([2, 3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=cpu,
    is_shared=False)
reset after policy TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([2, 3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs: Tensor(shape=torch.Size([2, 3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        rs: Tensor(shape=torch.Size([2, 3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=cpu,
    is_shared=False)
/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/collectors/collectors.py:618: UserWarning: frames_per_batch (10) is not exactly divisible by the number of batched environments (6),  this results in more frames_per_batch per iteration that requested (12).To silence this message, set the environment variable RL_WARNINGS to False.
  warnings.warn(
Traceback (most recent call last):
  File "test_td_primer.py", line 63, in <module>
    data_collector = SyncDataCollector(
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/collectors/collectors.py", line 633, in __init__
    self._shuttle = self.env.reset()
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py", line 2120, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 59, in decorated_fun
    _check_for_faulty_process(self._workers)
  File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/_utils.py", line 124, in _check_for_faulty_process
    raise RuntimeError(
RuntimeError: At least one process failed. Check for more infos in the log.

vmoens commented 1 month ago

Could be! I'll make the 0.5 release soon (sometime this week); let's see if it fixes it (you can also use the nightlies to check!).

ErcBunny commented 1 month ago

> Could be! I'll make the 0.5 release soon (sometime this week); let's see if it fixes it (you can also use the nightlies to check!).

In v0.5, make_tensordict_primer no longer throws IndexError. Super nice!