albertbou92 opened 10 months ago
The way I solved it in my case was to create a custom spec for every model I have and simply assign it to the model (`model.rnn_spec = spec`). This way I can access the info in other parts of the code.
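For illustration, a minimal sketch of that workaround. The `rnn_spec` attribute name is the one used above but is otherwise arbitrary; the state keys and shapes are illustrative, and the spec classes are `Composite`/`Unbounded` in recent TorchRL (older releases name them `CompositeSpec`/`UnboundedContinuousTensorSpec`):

```python
from torchrl.data import Composite, Unbounded

# hypothetical attribute holding the specs of the LSTM's recurrent inputs;
# keys match LSTMModule's default state keys, shapes are illustrative
model.rnn_spec = Composite(
    recurrent_state_h=Unbounded(shape=(1, 32)),
    recurrent_state_c=Unbounded(shape=(1, 32)),
)
```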
What about:

```python
from torchrl.envs import Compose


def get_primers_from_model(model):
    """Collect the TensorDictPrimer transforms of all recurrent submodules."""
    primers = []

    def make_primers(module):
        if hasattr(module, "make_tensordict_primer"):
            primers.append(module.make_tensordict_primer())

    model.apply(make_primers)
    if not primers:
        raise RuntimeError("No submodule exposes a `make_tensordict_primer` method.")
    elif len(primers) == 1:
        return primers[0]
    else:
        return Compose(*primers)
```
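Usage would then look something like this (a sketch, assuming `env` is a `TransformedEnv`):

```python
# collect the primer(s) from the model and attach them to the environment
primer = get_primers_from_model(model)
env.append_transform(primer)
```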
Motivation
When RNNs are used in isolation, creating a TensorDictPrimer transform for the environment to populate the TensorDicts with the expected tensors is pretty straightforward:
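Something along these lines (a sketch; the environment name and hyperparameters are illustrative):

```python
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import LSTMModule

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
lstm = LSTMModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=32,
    in_key="observation",
    out_key="features",
)
# the module knows its own recurrent inputs, so it can build the primer itself
env.append_transform(lstm.make_tensordict_primer())
```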
However, when RNNs are part of a larger architecture, this can become tricky, e.g.:
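A sketch of such a nested model (layer sizes and keys are illustrative):

```python
from torch import nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.modules import LSTMModule

# the LSTM is buried inside a sequential stack
model = TensorDictSequential(
    TensorDictModule(nn.Linear(3, 32), in_keys=["observation"], out_keys=["embed"]),
    LSTMModule(input_size=32, hidden_size=32, in_key="embed", out_key="features"),
    TensorDictModule(nn.Linear(32, 1), in_keys=["features"], out_keys=["action"]),
)
# the container itself has no make_tensordict_primer(), and it does not know
# which of its submodules carry recurrent state
```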
If you know the architecture, it is still possible to do:
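For the sketch above, that means indexing into the known position of the recurrent module:

```python
# works, but silently breaks if the position of the LSTM changes
env.append_transform(model[1].make_tensordict_primer())
```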
But this is not ideal. Besides creating the transform automatically, the user may also be interested in the required shapes and other information about the model's inputs, which now include both the RNN's recurrent inputs and the model's own inputs.
Solution
A solution would be to make it possible to access all the information about the model's expected inputs and outputs through model specs.
Perhaps defining specs should not be required during model creation, but optionally adding input specs would facilitate creating the primer transform in these cases.
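A rough sketch of how such an optional spec could be consumed, reusing the hypothetical `rnn_spec` attribute from the workaround above and `TensorDictPrimer` built from a composite spec:

```python
from torchrl.envs import TensorDictPrimer

# if the model optionally exposes the specs of its recurrent inputs,
# the primer can be built without inspecting the architecture
if getattr(model, "rnn_spec", None) is not None:
    env.append_transform(TensorDictPrimer(model.rnn_spec))
```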
Checklist