pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

GAE does not support LSTM-based value network. #2444

Open levelrin opened 1 week ago

levelrin commented 1 week ago

Motivation

I got the following error when I used GAE with an LSTM-based value network:

RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.

Here is the code I ran:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE

class ValueNetwork(nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=2,
            hidden_size=1,
            num_layers=1,
            bidirectional=False,
            batch_first=True
        )

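    def forward(self, i):
        # i: (batch, time, features) because batch_first=True
        output, (hidden_state, cell_state) = self.lstm(i)
        # hidden_state: (num_layers, batch, hidden_size)
        return hidden_state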

def main():
    value_network = ValueNetwork()
    value_dict_module = TensorDictModule(value_network, in_keys=["observation"], out_keys=["value"])
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_dict_module
    )
    gae.set_keys(
        advantage="advantage",
        value_target="value_target",
        value="value",
    )
    tensor_dict = TensorDict({
        "next": {
            "observation": torch.FloatTensor([
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]]
            ]),
            "reward": torch.FloatTensor([[1], [-1]]),
            "done": torch.BoolTensor([[1], [1]]),
            "terminated": torch.BoolTensor([[1], [1]])
        },
        "observation": torch.FloatTensor([
            [[0, 1], [2, 3]],
            [[4, 5], [6, 7]]
        ])
    }, batch_size=2)

    # This call raises the RuntimeError quoted above.
    output_tensor_dict = gae(tensor_dict)
    print(f"output_tensor_dict: {output_tensor_dict}")
    advantage = output_tensor_dict["advantage"]
    print(f"advantage: {advantage}")

main()

The error was caused by this exact line:

output_tensor_dict = gae(tensor_dict)
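
The wording of the error ("Batching rule not implemented") points at PyTorch's vmap machinery rather than at GAE's arithmetic. A plausible reading, assuming GAE vectorizes the value network over the current and "next" observations with vmap, is that torch.nn.LSTM simply cannot be called under vmap. The sketch below tries to reproduce that in isolation, without TorchRL; the tensor shapes and the names `stacked` and `error_message` are illustrative assumptions, not part of the original report.

```python
import torch
from torch import nn

lstm = nn.LSTM(input_size=2, hidden_size=1, batch_first=True)

# The leading dimension (size 2) is the one vmap maps over.
# The remaining (batch, time, features) shape is what nn.LSTM sees per call.
stacked = torch.randn(2, 2, 2, 2)

error_message = None
try:
    # Calling the fused LSTM kernel under vmap is expected to fail.
    torch.vmap(lambda obs: lstm(obs)[0])(stacked)
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

If this prints the same message as above, the failure is independent of the TensorDict contents and comes from combining vmap with the fused LSTM kernel.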

I then tried passing unbatched input to avoid the batching error, but GAE rejects unbatched input. This is the input I tried:

tensor_dict = TensorDict({
    "next": {
        "observation": torch.FloatTensor([[4, 5], [6, 7]]),
        "reward": torch.FloatTensor([1]),
        "done": torch.BoolTensor([1]),
        "terminated": torch.BoolTensor([1])
    },
    "observation": torch.FloatTensor([[0, 1], [2, 3]])
}, batch_size=[])

And I got this error from GAE:

RuntimeError: Expected input tensordict to have at least one dimensions, got tensordict.batch_size = torch.Size([])

Therefore, I concluded that GAE does not support an LSTM-based value network.

Solution

GAE should support an LSTM-based value network.

Alternatives

GAE should support an unbatched TensorDict as input.

Additional context

I'm using torchrl version 0.5.0.

I found ticket #2372, which might be related to this issue, but I was not sure how to make my code work.


thomasbbrunner commented 11 hours ago

I see that you're using the Torch LSTM in your snippet. Maybe try with TorchRL's version of it (LSTMModule)?

Also, as described in #2372, you'll need to use python_based=True in your LSTMModule.
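
To illustrate why a Python-based LSTM helps here, the following is a minimal sketch of the general idea, assuming the problem is that vmap has no batching rule for the fused LSTM kernel. It is not TorchRL's LSTMModule implementation: the class `PythonLSTM` and its layer names are hypothetical. It unrolls the recurrence in a Python loop using only elementary operations (linear layers, sigmoid, tanh), all of which vmap can handle.

```python
import torch
from torch import nn


class PythonLSTM(nn.Module):
    """Hypothetical single-layer LSTM unrolled in Python (illustration only)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # One linear map each for input and hidden state, producing all 4 gates.
        self.ih = nn.Linear(input_size, 4 * hidden_size)
        self.hh = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x):
        # x: (batch, time, features)
        batch, steps, _ = x.shape
        h = torch.zeros(batch, self.hidden_size, dtype=x.dtype, device=x.device)
        c = torch.zeros(batch, self.hidden_size, dtype=x.dtype, device=x.device)
        for t in range(steps):
            gates = self.ih(x[:, t]) + self.hh(h)
            i, f, g, o = gates.chunk(4, dim=-1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
        # Final hidden state: (batch, hidden_size)
        return h


model = PythonLSTM(input_size=2, hidden_size=1)
# Leading dimension of size 2 is mapped over by vmap.
stacked = torch.randn(2, 3, 4, 2)
values = torch.vmap(model)(stacked)
print(values.shape)
```

The trade-off is speed: a Python loop is generally slower than the fused kernel, which is presumably why it is an opt-in flag rather than the default.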