Open levelrin opened 1 week ago
I got the following error when I used GAE with an LSTM-based value network:
RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.
Here is the code I ran:
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE


class ValueNetwork(nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=2,
            hidden_size=1,
            num_layers=1,
            bidirectional=False,
            batch_first=True
        )

    def forward(self, i):
        output, (hidden_state, cell_state) = self.lstm(i)
        return hidden_state


def main():
    value_network = ValueNetwork()
    value_dict_module = TensorDictModule(value_network, in_keys=["observation"], out_keys=["value"])
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_dict_module
    )
    gae.set_keys(
        advantage="advantage",
        value_target="value_target",
        value="value",
    )
    tensor_dict = TensorDict({
        "next": {
            "observation": torch.FloatTensor([
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]]
            ]),
            "reward": torch.FloatTensor([[1], [-1]]),
            "done": torch.BoolTensor([[1], [1]]),
            "terminated": torch.BoolTensor([[1], [1]])
        },
        "observation": torch.FloatTensor([
            [[0, 1], [2, 3]],
            [[4, 5], [6, 7]]
        ])
    }, batch_size=2)
    output_tensor_dict = gae(tensor_dict)
    print(f"output_tensor_dict: {output_tensor_dict}")
    advantage = output_tensor_dict["advantage"]
    print(f"advantage: {advantage}")


main()
The error was caused by this exact line:
output_tensor_dict = gae(tensor_dict)
I then tried passing unbatched input, but GAE does not accept it either. This is the unbatched input I tried:
tensor_dict = TensorDict({
    "next": {
        "observation": torch.FloatTensor([[4, 5], [6, 7]]),
        "reward": torch.FloatTensor([1]),
        "done": torch.BoolTensor([1]),
        "terminated": torch.BoolTensor([1])
    },
    "observation": torch.FloatTensor([[0, 1], [2, 3]])
}, batch_size=[])
And I got this error from GAE:
RuntimeError: Expected input tensordict to have at least one dimensions, got tensordict.batch_size = torch.Size([])
Therefore, I concluded that GAE does not support an LSTM-based value network.
Solution: GAE should support an LSTM-based value network.
Alternative: GAE should accept an unbatched TensorDict as input.
I'm using torchrl version 0.5.0.
I found ticket #2372, which might be related to this issue, but I was not sure how to make my code work.
I see that you're using the Torch LSTM in your snippet. Maybe try with TorchRL's version of it (LSTMModule)?
Also, as described in #2372, you'll need to use python_based=True in your LSTMModule.