Open · btx0424 opened this issue 1 year ago
This is a very good idea, thanks for the suggestion.

I would also find this very helpful! convert_to_functional seems like a very good starting point; it's a shame it is only usable in LossModules. Would it make sense to have a TensorDictModuleEnsemble class that could be applied to any nn.Module or TensorDictModule?
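For illustration, here is a minimal sketch of what such a wrapper could look like, built on torch.func (stack_module_state / functional_call / vmap). The class name and design are hypothetical, not an existing torchrl API:

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap


class TensorDictModuleEnsembleSketch(nn.Module):
    """Hypothetical ensemble wrapper: stacks the states of N copies
    of a module and vmaps a functional call over them."""

    def __init__(self, module: nn.Module, num_copies: int):
        super().__init__()
        copies = [copy.deepcopy(module) for _ in range(num_copies)]
        # re-initialize each copy so the ensemble members are independent
        for m in copies:
            m.apply(self._reset)
        # dicts of stacked tensors with a leading (num_copies,) dimension;
        # note: plain dicts, not registered as nn.Parameters
        self.stacked_params, self.stacked_buffers = stack_module_state(copies)
        # a stateless "meta" template used as the target of functional_call
        self.base = copy.deepcopy(module).to("meta")

    @staticmethod
    def _reset(m: nn.Module) -> None:
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def call_single(params, buffers, x):
            return functional_call(self.base, (params, buffers), (x,))

        # map over the stacked states, broadcast the same input to all members
        return vmap(call_single, in_dims=(0, 0, None))(
            self.stacked_params, self.stacked_buffers, x
        )


ensemble = TensorDictModuleEnsembleSketch(nn.Linear(4, 2), num_copies=5)
out = ensemble(torch.randn(8, 4))  # shape (5, 8, 2): one slice per member
```

One open design question is whether the stacked parameters should be exposed as proper nn.Parameters so that regular optimizers can see them; the sketch above sidesteps that by keeping them as plain dicts.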
Motivation
Model ensembling is appealing in the RL context, with a range of use cases, e.g., critic ensembles and parallel inference for multiple agents that share the same actor architecture. I believe the design of torchrl treats model ensembling via functorch as an important feature (judging by the docs and the design of TensorDictModule). However, the docs/examples/tutorials currently offer little guidance on the best or suggested practice for actually implementing it.
Tutorials or examples on either use case would be helpful.
Solution
A comprehensive example or tutorial on how to leverage model ensembling to perform efficient training and inference.
One tricky thing here is how to perform parallel optimization with an ensembled functional model, since functorch does not (yet) provide a direct solution. A viable approach I figured out was to use torchopt's functional optimizer API and do something like the following.
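Roughly this, assuming torchopt's optax-style functional API (torchopt.adam / torchopt.apply_updates) together with torch.func for the ensembled forward pass and per-member gradients; treat it as a sketch of the pattern, not the exact snippet from the original post:

```python
import copy

import torch
import torchopt
from torch import nn
from torch.func import functional_call, grad, stack_module_state, vmap

num_models, in_dim, out_dim = 4, 8, 1
models = [nn.Linear(in_dim, out_dim) for _ in range(num_models)]

# stack the per-model states into dicts with a leading (num_models,) dim
params, buffers = stack_module_state(models)
# torch.func.grad does not need requires_grad; detaching keeps the
# parameter updates out of the autograd graph
params = {k: v.detach() for k, v in params.items()}
base = copy.deepcopy(models[0]).to("meta")  # stateless template module

def loss_fn(p, b, x, y):
    pred = functional_call(base, (p, b), (x,))
    return ((pred - y) ** 2).mean()

optim = torchopt.adam(lr=1e-3)   # functional optimizer, optax-style
opt_state = optim.init(params)   # optimizer state mirrors the stacked params

x, y = torch.randn(32, in_dim), torch.randn(32, out_dim)

for _ in range(10):
    # per-member gradients: vmap the single-model grad over the stacked dim
    grads = vmap(grad(loss_fn), in_dims=(0, 0, None, None))(params, buffers, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    params = torchopt.apply_updates(params, updates, inplace=False)
```

The nice property is that the optimizer state is itself a pytree shaped like the stacked parameters, so a single update step advances every ensemble member at once.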
But I am very uncertain about the efficiency of the above code (I also tried getting gradients using functorch.grad but found it to be slightly slower) and have been wondering what the best practice is, especially if we want to use torchrl.
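For concreteness, one alternative is to skip functorch.grad entirely and differentiate with plain autograd through a single vmapped forward pass: since the members share no parameters, the gradient of the summed per-member losses recovers each member's gradient in its own slice of the stacked parameters. A self-contained sketch (not necessarily what was benchmarked above):

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

num_models, in_dim, out_dim = 4, 8, 1
models = [nn.Linear(in_dim, out_dim) for _ in range(num_models)]
params, buffers = stack_module_state(models)  # stacked leaves, requires_grad=True
base = copy.deepcopy(models[0]).to("meta")

def member_loss(p, b, x, y):
    pred = functional_call(base, (p, b), (x,))
    return ((pred - y) ** 2).mean()

x, y = torch.randn(32, in_dim), torch.randn(32, out_dim)
losses = vmap(member_loss, in_dims=(0, 0, None, None))(params, buffers, x, y)

# summing is safe: each loss depends only on its own parameter slice,
# so d(sum)/d(stacked param) holds the per-member gradients
grads_flat = torch.autograd.grad(losses.sum(), list(params.values()))
grads = dict(zip(params.keys(), grads_flat))
```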
Additional Context
Some additional examples on critic ensembling, e.g., double Q-functions, would also be of great help.
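For example, a clipped double-Q estimate (TD3/SAC-style) falls out of the same stacking trick: evaluate two Q-networks in one vmapped call and reduce with a min. The sketch below is illustrative; the shapes and names are made up:

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

obs_dim, act_dim, batch = 8, 2, 32
q_nets = [
    nn.Sequential(nn.Linear(obs_dim + act_dim, 64), nn.ReLU(), nn.Linear(64, 1))
    for _ in range(2)
]
q_params, q_buffers = stack_module_state(q_nets)
q_base = copy.deepcopy(q_nets[0]).to("meta")

def q_value(p, b, obs_act):
    return functional_call(q_base, (p, b), (obs_act,))

obs, act = torch.randn(batch, obs_dim), torch.randn(batch, act_dim)
obs_act = torch.cat([obs, act], dim=-1)

qs = vmap(q_value, in_dims=(0, 0, None))(q_params, q_buffers, obs_act)  # (2, batch, 1)
q_target = qs.min(dim=0).values  # pessimistic (clipped) double-Q estimate
```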