pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[Feature Request] Suggestion: Tutorial on Model Ensembling #876

Open btx0424 opened 1 year ago

btx0424 commented 1 year ago

Motivation

Model ensembling is appealing in the RL context, with a range of use cases, e.g., critic ensembles and parallel inference of multiple agents that share the same actor structure. I believe the design of torchrl treats model ensembling via functorch as an important feature (judging from the docs and the design of TensorDictModule). However, there is currently little guidance in the docs/examples/tutorials on the best or suggested practice for actually implementing it.

Tutorials or examples on either use case would be helpful.

Solution

A comprehensive example or tutorial on how to leverage model ensembling to perform efficient training and inference.

One tricky point is how to perform parallel optimization with an ensembled functional model, since functorch does not (yet) provide a direct solution. A viable approach I found is to use torchopt's functional optimizer API and do something like

import torch
from torch import nn

import torchopt
import functorch
from tensordict.nn import TensorDictModule, make_functional

# MyActorModule, num_agents, device, cfg, actor_loss_fn and
# some_collected_data are defined elsewhere; the keys are illustrative
actors = nn.ModuleList([
  TensorDictModule(MyActorModule(), in_keys=["observation"], out_keys=["action"])
  for _ in range(num_agents)
]).to(device)

# this is the functorch way;
# things go similarly with tensordict.nn.make_functional
fmodel, params, buffers = functorch.combine_state_for_ensemble(actors)
actor_opt = torchopt.adam(lr=cfg.lr)
actor_opt_states = functorch.vmap(actor_opt.init)(params)

# perform one optimization step for all ensemble members in parallel
def opt_step(batch, params, actor_opt_states):
  # one loss value per ensemble member
  actor_loss = functorch.vmap(actor_loss_fn)(params, batch)
  # members are independent, so summing yields the per-member gradients
  grads = torch.autograd.grad(actor_loss.sum(), params)
  updates, actor_opt_states = functorch.vmap(actor_opt.update)(grads, actor_opt_states)
  params = torchopt.apply_updates(params, updates)
  return actor_loss, params, actor_opt_states

for batch in some_collected_data:
  _, params, actor_opt_states = opt_step(batch, params, actor_opt_states)

But I am very uncertain about the efficiency of the above code (I also tried computing gradients with functorch.grad but found it slightly slower) and have been wondering what the best practice is, especially when using torchrl.
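For the inference side, the tensordict.nn.make_functional route mentioned in the snippet could look roughly like this. This is only a sketch: the "observation"/"action" keys, batch_size and obs_dim are assumptions, and the stacked-params + vmap pattern mimics what torchrl's loss modules do internally.

import torch
from functorch import vmap
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, make_functional

# one module copy per agent; all copies share the same structure
modules = [
  TensorDictModule(MyActorModule(), in_keys=["observation"], out_keys=["action"])
  for _ in range(num_agents)
]
# make_functional extracts each copy's parameters as a TensorDict
params = torch.stack([make_functional(m) for m in modules], 0)
model = modules[0]  # any copy serves as the "template" module

td = TensorDict({"observation": torch.randn(batch_size, obs_dim)}, batch_size=[batch_size])
# same input for every member, one parameter set per member;
# the output gains a leading dim of size num_agents
out = vmap(model, (None, 0))(td, params)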

Additional Context

Some additional examples of critic ensembling, e.g., double Q-functions, would also be of great help.
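After all, a double Q-function is just a two-member critic ensemble. A minimal sketch along the lines of the functorch ensembling docs (the critic architecture, dimensions and obs/act tensors are placeholders, not torchrl API):

import torch
from torch import nn
import functorch

obs_dim, act_dim, batch_size = 8, 2, 32
critics = [
  nn.Sequential(nn.Linear(obs_dim + act_dim, 64), nn.ReLU(), nn.Linear(64, 1))
  for _ in range(2)
]
fcritic, params, buffers = functorch.combine_state_for_ensemble(critics)

obs = torch.randn(batch_size, obs_dim)
act = torch.randn(batch_size, act_dim)
# evaluate both critics on the same batch in one vmap call: [2, batch_size, 1]
q_values = functorch.vmap(fcritic, (0, 0, None))(params, buffers, torch.cat([obs, act], -1))
# clipped double-Q value as in TD3/SAC
q_min = q_values.min(dim=0).values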


vmoens commented 1 year ago

This is a very good idea, thanks for the suggestion

smorad commented 1 year ago

I would also find this very helpful! convert_to_functional seems like a very good starting point; it's a shame it is only usable in LossModules. Would it make sense to have a TensorDictModuleEnsemble class that could be applied to any nn.Module or TensorDictModule?
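Something like the following, purely hypothetical, sketch (the class name is invented, and it only covers the plain-tensor nn.Module case; a TensorDict-aware version would additionally need to route parameter TensorDicts):

import torch
from torch import nn
import functorch

class ModuleEnsemble(nn.Module):
  """Hypothetical wrapper running N copies of a module in parallel via vmap."""
  def __init__(self, modules):
    super().__init__()
    self.fmodel, params, buffers = functorch.combine_state_for_ensemble(modules)
    # register the stacked parameters so optimizers and .to() see them
    self.params = nn.ParameterList(nn.Parameter(p) for p in params)
    self._stacked_buffers = buffers

  def forward(self, x):
    # broadcast the same input to all members; output gains a leading ensemble dim
    return functorch.vmap(self.fmodel, (0, 0, None))(
      tuple(self.params), self._stacked_buffers, x
    )

# e.g. ModuleEnsemble([MyCritic() for _ in range(4)])(x) -> shape [4, *single_output_shape]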