pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[Feature Request] Information of RNNs expected inputs and outputs difficult to access when part of larger architectures #1771

Open albertbou92 opened 10 months ago

albertbou92 commented 10 months ago

Motivation

When RNNs are used in isolation, creating a TensorDictPrimer transform so the environment populates the TensorDicts with the expected tensors is straightforward:

from torchrl.modules import GRUModule

gru_module = GRUModule(
    input_size=10,
    hidden_size=10,
    num_layers=1,
    in_keys=["input", "recurrent_state", "is_init"],
    out_keys=["output", ("next", "recurrent_state")],
)

transform = gru_module.make_tensordict_primer()
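
The resulting transform can then be appended to the environment so that resets produce the keys the GRU expects. A minimal sketch, assuming a Gym environment as a stand-in:

from torchrl.envs import InitTracker, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

# Sketch only: "CartPole-v1" stands in for whatever environment you use.
env = TransformedEnv(GymEnv("CartPole-v1"))
env.append_transform(InitTracker())  # fills the "is_init" key read by the GRU
env.append_transform(transform)      # primes "recurrent_state" at reset time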

However, when RNNs are part of a larger architecture, this becomes tricky, e.g.:

from torchrl.modules import GRUModule, MLP
from tensordict.nn import TensorDictModule, TensorDictSequential

gru_module = GRUModule(
    input_size=10,
    hidden_size=10,
    num_layers=1,
    in_keys=["input", "recurrent_state", "is_init"],
    out_keys=["features", ("next", "recurrent_state")],
)
head = TensorDictModule(
    MLP(
        in_features=10,
        out_features=10,
        num_cells=[],
    ),
    in_keys=["features"],
    out_keys=["output"],
)
model = TensorDictSequential(gru_module, head)

If you know the architecture, it is still possible to do:

transform = model[0].make_tensordict_primer()

But this is not ideal. Moreover, beyond creating the transform automatically, the user may want to know the required shapes and other information about the model's expected inputs, which now include both the RNN inputs and the inputs of the other modules.

Solution

A solution would be to make it possible to access all the information about a model's expected inputs and outputs through some model specs.

Defining specs during model creation should perhaps not be required, but optionally attaching input specs would make it easier to create the primer transform in these cases.
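
To make the proposal concrete, here is a purely hypothetical sketch of what such an interface could look like (none of these attribute or method names exist in torchrl today):

# Hypothetical API, for illustration only.
specs = model.input_specs              # aggregated from every submodule
print(specs["recurrent_state"])        # shape, dtype, etc. of the expected hidden state

# With this information exposed, the primer could be built without knowing
# where the RNN sits inside the architecture:
transform = specs.make_tensordict_primer()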


albertbou92 commented 10 months ago

The way I solved it in my case was to create a custom spec for every model I have and simply assign it to the model: model.rnn_spec = spec. This way I can access the info in other parts of the code.
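
For illustration, such a workaround could look like this (a sketch; rnn_spec is just an ad-hoc attribute name, and the shape follows the (num_layers, hidden_size) layout of the GRU above):

import torch
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

# Ad-hoc attribute: nothing in torchrl consumes it, so downstream code
# has to look it up explicitly.
model.rnn_spec = CompositeSpec(
    recurrent_state=UnboundedContinuousTensorSpec(
        shape=torch.Size([1, 10]),  # (num_layers, hidden_size)
        dtype=torch.float32,
    )
)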

vmoens commented 7 months ago

What about

from torchrl.envs import Compose


def get_primers_from_model(model):
    primers = []

    def make_primers(module):
        if hasattr(module, "make_tensordict_primer"):
            primers.append(module.make_tensordict_primer())

    # nn.Module.apply visits the module itself and every submodule recursively
    model.apply(make_primers)

    if not primers:
        raise RuntimeError("No submodule exposing `make_tensordict_primer` was found.")
    elif len(primers) == 1:
        return primers[0]
    else:
        return Compose(*primers)
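
Usage would then be independent of where the recurrent module sits in the model. A sketch, reusing the model from the example above:

from torchrl.envs import InitTracker, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

# Collect primers from all recurrent submodules, wherever they live.
primers = get_primers_from_model(model)

env = TransformedEnv(GymEnv("CartPole-v1"))
env.append_transform(InitTracker())
env.append_transform(primers)  # a single primer or a Compose of several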