pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Calculation of GAE fails with recurrent critic #2372

Closed: thomasbbrunner closed this issue 3 months ago

thomasbbrunner commented 3 months ago

Describe the bug

The calculation of the GAE with a recurrent critic fails with the error:

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
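This error comes from torch.vmap itself. vmap cannot evaluate a Python branch such as `if tensor > 0:` because each batch element could take a different branch. The toy functions below are made up for illustration and are unrelated to torchrl. They reproduce the same message and show the usual branch-free rewrite with torch.where. Which branch inside the recurrent critic actually triggers the error is not identified in this thread.

```python
import torch


def branchy(x):
    # Python `if` on tensor data: vmap cannot pick one branch for the whole batch
    if x.sum() > 0:
        return x
    return -x


def branchless(x):
    # Same computation expressed with torch.where, which vmap can batch
    return torch.where(x.sum() > 0, x, -x)


xs = torch.randn(4, 3)

failed, message = False, ""
try:
    torch.vmap(branchy)(xs)
except RuntimeError as err:
    failed, message = True, str(err)
print("branchy under vmap failed:", failed)

ys = torch.vmap(branchless)(xs)                 # batched over the leading dim
ref = torch.stack([branchless(x) for x in xs])  # explicit per-sample loop
print("branchless matches loop:", torch.allclose(ys, ref))
```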

Setting the flag shifted=True appears to avoid this error.

Is this behavior expected? If so, should we document that shifted=True is required for recurrent critics?
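The shifted flag controls how the value of the next state, V(s_{t+1}), is obtained. The critic can be called once on observations and once on next observations, or it can be called once on all T+1 observations of a contiguous rollout and the result shifted by one step. The sketch below is plain PyTorch, not torchrl's implementation. Its assumptions are a single non-terminating trajectory, a stateless linear critic standing in for the value network, and one `done` flag used both to stop bootstrapping and to cut the trace. Under those assumptions the two routes give the same GAE, using the gamma=0.99 and lmbda=0.95 from the reproduction below.

```python
import torch
from torch import nn


def gae(reward, value, next_value, done, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation over one trajectory; all inputs have shape [T]."""
    not_done = (~done).float()
    # One-step TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    delta = reward + gamma * next_value * not_done - value
    advantage = torch.zeros_like(reward)
    running = torch.zeros(())
    # A_t = delta_t + gamma * lmbda * A_{t+1}, accumulated backwards in time
    for t in reversed(range(reward.shape[0])):
        running = delta[t] + gamma * lmbda * not_done[t] * running
        advantage[t] = running
    return advantage


torch.manual_seed(0)
T, obs_dim = 8, 3
obs = torch.randn(T + 1, obs_dim)          # s_0 .. s_T of one contiguous rollout
reward = torch.randn(T)
done = torch.zeros(T, dtype=torch.bool)
critic = nn.Linear(obs_dim, 1)             # stateless stand-in for the critic

with torch.no_grad():
    # Two calls: one on observations, one on next observations
    value = critic(obs[:-1]).squeeze(-1)
    next_value = critic(obs[1:]).squeeze(-1)
    adv_two_calls = gae(reward, value, next_value, done)

    # One call on all T+1 observations, then shift by one time step
    all_values = critic(obs).squeeze(-1)
    adv_one_call = gae(reward, all_values[:-1], all_values[1:], done)

print(torch.allclose(adv_two_calls, adv_one_call, atol=1e-6))
```

The single-call route only works when the next value lies exactly one step ahead and the same critic parameters are used at t and t+1.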

To Reproduce

Minimal snippet to reproduce the issue:

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import GymEnv, TransformedEnv, transforms
from torchrl.envs.utils import check_env_specs
from torchrl.modules import LSTMModule
from torchrl.objectives.value import GAE

env = GymEnv(env_name="HalfCheetah-v4", device="cpu")
env = TransformedEnv(env)
env.append_transform(transforms.DoubleToFloat(in_keys=["observation"]))
env.append_transform(transforms.InitTracker())
env.append_transform(
    transforms.TensorDictPrimer(
        {
            "recurrent_state_h": UnboundedContinuousTensorSpec(shape=(1, 128)),
            "recurrent_state_c": UnboundedContinuousTensorSpec(shape=(1, 128)),
        }
    )
)
check_env_specs(env)

observation_size = env.observation_spec["observation"].shape[-1]
action_size = env.action_spec.shape[-1]

rnn = LSTMModule(
    input_size=observation_size,
    hidden_size=128,
    num_layers=1,
    device="cpu",
    in_key="observation",
    out_key="features",
)

value_net = TensorDictModule(
    module=nn.Sequential(
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 1),
    ),
    in_keys=["features"],
    out_keys=["state_value"],
)
critic_module = TensorDictSequential(rnn, value_net)

collector = SyncDataCollector(
    env,
    None,
    frames_per_batch=512,
    device="cpu",
)

batch = collector.next()

# With shifted=True, the advantage computation succeeds
advantage_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic_module, shifted=True)
with torch.no_grad():
    advantage_module(batch)

# With shifted=False, the advantage computation fails
advantage_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic_module, shifted=False)
with torch.no_grad():
    # NOTE: raises
    # RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow (...)
    advantage_module(batch)

System info

> pip list | grep torch
torch                          2.4.0
torchrl                        0.5.0

vmoens commented 3 months ago

#2376 should fix it.

You'll still need to pass python_based=True to your LSTMModule.
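The thread does not say what python_based=True changes inside LSTMModule. The sketch below is a hypothetical illustration: the TinyLSTM class is made up and is not torchrl code. It shows one general way a recurrent cell unrolled in a Python loop can stay compatible with torch.vmap. The loop bound comes from a tensor shape rather than tensor values, and episode resets flagged by is_init are applied with torch.where instead of an `if`.

```python
import torch
from torch import nn


class TinyLSTM(nn.Module):
    """Hypothetical single-layer LSTM unrolled in Python (not torchrl's code)."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # One affine map produces the input, forget, cell and output gates
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, is_init, h0, c0):
        # x: [T, input_size], is_init: [T] (bool), h0 and c0: [hidden_size]
        h, c = h0, c0
        outputs = []
        for t in range(x.shape[0]):  # loop bound is a shape, not tensor data
            # Zero the state where a new episode starts, without a Python `if`
            h = torch.where(is_init[t], torch.zeros_like(h), h)
            c = torch.where(is_init[t], torch.zeros_like(c), c)
            i, f, g, o = self.gates(torch.cat([x[t], h], dim=-1)).chunk(4, dim=-1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
            outputs.append(h)
        return torch.stack(outputs)


torch.manual_seed(0)
B, T, D, H = 2, 5, 3, 4
lstm = TinyLSTM(D, H)
x = torch.randn(B, T, D)
is_init = torch.zeros(B, T, dtype=torch.bool)
is_init[0, 0] = True  # first sequence starts a new episode
is_init[1, 3] = True  # second sequence hits a reset mid-way
h0, c0 = torch.randn(B, H), torch.randn(B, H)

with torch.no_grad():
    out = torch.vmap(lstm)(x, is_init, h0, c0)  # batched over the leading dim
    ref = torch.stack([lstm(x[b], is_init[b], h0[b], c0[b]) for b in range(B)])

print(out.shape, torch.allclose(out, ref, atol=1e-5))
```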