Describe the bug
Computing the GAE with a recurrent critic fails with the error:
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
Setting the flag shifted=True seems to prevent this error. My guess is that with shifted=False, GAE evaluates the value network on the root and "next" tensordicts in a single vmap-batched call, and the data-dependent control flow inside LSTMModule (e.g. the is_init handling) is not supported under vmap, whereas with shifted=True the network is called once over the shifted trajectory, so vmap is never used.
Is this behavior expected? If so, should we maybe document that the shifted flag is necessary for recurrent critics?
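For context, here is a hypothetical sketch (not the actual LSTMModule code) of the kind of pattern that triggers the same RuntimeError under vmap:

import torch

def reset_hidden(hidden, is_init):
    # Branching on a tensor's value is data-dependent control flow;
    # under torch.vmap, bool(is_init.any()) raises the RuntimeError quoted above.
    if is_init.any():
        hidden = torch.zeros_like(hidden)
    return hidden

hidden = torch.randn(2, 4)
is_init = torch.tensor([True, False])
torch.vmap(reset_hidden)(hidden, is_init)  # RuntimeError: vmap: ... data-dependent control flow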
To Reproduce
Minimal snippet to reproduce the issue:
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import GymEnv, TransformedEnv, transforms
from torchrl.envs.utils import check_env_specs
from torchrl.modules import LSTMModule
from torchrl.objectives.value import GAE
env = GymEnv(env_name="HalfCheetah-v4", device="cpu")
env = TransformedEnv(env)
env.append_transform(transforms.DoubleToFloat(in_keys=["observation"]))
env.append_transform(transforms.InitTracker())
env.append_transform(
    transforms.TensorDictPrimer(
        {
            "recurrent_state_h": UnboundedContinuousTensorSpec(shape=(1, 128)),
            "recurrent_state_c": UnboundedContinuousTensorSpec(shape=(1, 128)),
        }
    )
)
check_env_specs(env)
observation_size = env.observation_spec["observation"].shape[-1]
action_size = env.action_spec.shape[-1]
rnn = LSTMModule(
    input_size=observation_size,
    hidden_size=128,
    num_layers=1,
    device="cpu",
    in_key="observation",
    out_key="features",
)
value_net = TensorDictModule(
    module=nn.Sequential(
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 1),
    ),
    in_keys=["features"],
    out_keys=["state_value"],
)
critic_module = TensorDictSequential(rnn, value_net)
collector = SyncDataCollector(
    env,
    None,  # no policy: the collector falls back to a random policy
    frames_per_batch=512,
    device="cpu",
)
batch = collector.next()
# With shifted=True the advantage calculation works
advantage_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic_module, shifted=True)
with torch.no_grad():
    advantage_module(batch)
# With shifted=False the advantage calculation fails!
advantage_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic_module, shifted=False)
with torch.no_grad():
    # NOTE: raises
    # RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow (...)
    advantage_module(batch)
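As a possible workaround, evaluating the critic manually and then calling the functional GAE directly also appears to avoid vmap entirely. A minimal sketch, assuming the batch layout produced above (per-step recurrent states and is_init present in the "next" sub-tensordict) and the generalized_advantage_estimate signature as I read it in torchrl 0.5:

from torchrl.objectives.value.functional import generalized_advantage_estimate

with torch.no_grad():
    # Two plain forward passes over root and "next" data: no vmap involved.
    critic_module(batch)
    critic_module(batch["next"])

advantage, value_target = generalized_advantage_estimate(
    gamma=0.99,
    lmbda=0.95,
    state_value=batch["state_value"],
    next_state_value=batch["next", "state_value"],
    reward=batch["next", "reward"],
    done=batch["next", "done"],
    terminated=batch["next", "terminated"],
)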
System info
> pip list | grep torch
torch 2.4.0
torchrl 0.5.0
[x] I have checked that there is no similar issue in the repo (required)