pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[Feature Request] Decouple target net updater from loss functions #1401

Closed. smorad closed this issue 1 year ago.

smorad commented 1 year ago

Motivation

Currently, the target net updaters require the use of a loss function. It would be cool if we could wrap an arbitrary TensorDictModule in a SoftUpdateModule. For complex policies, this would enable different parts of the policy to update at different rates. In my opinion, this is also slightly more flexible as you do not need to attach your module to a loss function. Another use case could be using soft updates on a policy-gradient policy, which is not currently possible.

Solution

One possible example:

pre = TensorDictModule(lstm, in_keys=['observation'], out_keys=['markov_state'])
pre_target = SoftUpdateModule(pre, eps=0.99)
post = Sequential(mlp, QValueModule())
post_target = SoftUpdateModule(post, eps=0.9)
loss_fn = DQNLoss(post, post_target)
...
tensordict = pre(tensordict)
loss = loss_fn(tensordict)
pre_target.step(tensordict.numel())
post_target.step(tensordict.numel())

The SoftUpdateModule itself should not be too difficult; it could look something like:

import copy
from tensordict.nn import TensorDictModuleBase

class SoftUpdateModule(TensorDictModuleBase):
  def __init__(self, module, eps=0.99):
    super().__init__()
    # mirror the wrapped module's keys so this acts as a drop-in replacement
    self.in_keys = module.in_keys
    self.out_keys = module.out_keys
    self.module = module
    self.target_module = copy.deepcopy(module)
    self.eps = eps

  def forward(self, td):
    # run the lagged target copy rather than the online module
    return self.target_module(td)

  def step(self, steps):
    # soft/polyak update: target <- eps * target + (1 - eps) * online
    # (`steps` is unused in this sketch)
    for p, q in zip(self.module.parameters(), self.target_module.parameters()):
      q.data.lerp_(p.data, 1 - self.eps)

Alternatives

I want to use an LSTM to preprocess a tensordict outside of a policy, but also to have a target LSTM network. I'm not sure how to accomplish this any other way.

Additional context

This might make more sense once we utilize the EnsembleModule class in TD3/SAC etc. I suspect the updaters are coupled with the loss functions because the loss functions are currently responsible for duplicating params for ensembles.

IMO it's more flexible to let the user pass the network and target network to the loss function than the current delay_value=True flag. For example, the target network could be a compressed/distilled copy of the main network weights.
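
Roughly, the comparison would be (the target_network argument is hypothetical, not an existing DQNLoss parameter):

# current API: the loss module clones the target params internally
loss_fn = DQNLoss(value_network, delay_value=True)

# proposed (hypothetical target_network arg): the user supplies the target
# explicitly, e.g. a compressed/distilled copy of the online network
loss_fn = DQNLoss(value_network, target_network=distilled_network)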


vmoens commented 1 year ago

The problem that I see is that currently the loss module takes care of creating the target params for you (something that may otherwise be tedious for newcomers). But because the target params are structured as a tensordict, it is easy to isolate them:

lstm_target = loss_module.target_params["path", "to", "lstm"]

assuming that your module looks like lstm = module.path.to.lstm.

We could use that to create custom updaters with an "in_keys" kind of arg:

lstm_updater = SoftUpdate(loss_module, sub_module=("path", "to", "lstm"))
mlp_updater = SoftUpdate(loss_module, sub_module=("path", "to", "mlp"))

If no sub_module (name to be chosen) is provided, the updater takes all the target params.
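
A rough sketch of what that could look like, assuming the loss module also exposes the matching online params next to target_params (the params attribute below is an assumption, not a confirmed API):

class ScopedSoftUpdate:
    def __init__(self, loss_module, sub_module=None, eps=0.999):
        self.eps = eps
        self.params = loss_module.params  # assumed online counterpart of target_params
        self.target_params = loss_module.target_params
        if sub_module is not None:
            # index the nested tensordict to isolate one sub-module's params
            self.params = self.params[sub_module]
            self.target_params = self.target_params[sub_module]

    def step(self):
        # soft update over the selected leaves only
        for p, q in zip(
            self.params.values(True, True),  # include_nested, leaves_only
            self.target_params.values(True, True),
        ):
            q.data.lerp_(p.data, 1 - self.eps)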

smorad commented 1 year ago

Hmm, that would work in many cases, but I'm not sure it would for me. I'm trying to do my own batching/masking for recurrent models:

lstm = LSTMModule(...)
post = Seq(MLP(...), QValueModule())
loss_fn = DQNLoss(post)
collector = SyncDataCollector(
    env,
    Seq(lstm, post),
    split_trajs=True,
)
segment_length = 200
for data in collector:
  padded = tensordict.pad(data, [0, 0, 0, segment_length - data.shape[-1]])
  buffer.extend(padded)
  batch = buffer.sample()
  batch = lstm(batch)
  unpadded = batch.masked_select(batch[("collector", "mask")])
  loss = loss_fn(unpadded)
  ...

I'm aware I don't have to do this for the built-in LSTM model, but I'd like the freedom to tinker with R2D2 tools like burn-in, fixed-size segments, etc. In that case, I think the LSTM cannot be part of the loss function, as I need to do some unpadding/masking first.
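
For instance, burn-in could be as simple as clearing the first steps of the mask before computing the loss (rough sketch; assumes the mask has [batch, time] shape):

# hypothetical burn-in: unroll the LSTM over the whole segment, but drop
# the first `burn_in` steps from the loss by masking them out
burn_in = 40
mask = batch[("collector", "mask")].clone()
mask[..., :burn_in] = False
loss = loss_fn(batch.masked_select(mask))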

vmoens commented 1 year ago

Got it! So you handle the target params yourself? Where are they used in this code example? If you have your target params and handle them yourself, it is feasible to create a custom updater for these cases. The reason we did not do it is that we thought the loss module was generic enough to handle any model.

Along the same line of thought: can't you wrap all the extra ops in a TensorDictModule(lambda x: foo(x), ...) and flank that with your LSTM to pass it to the loss module?
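
Something like this, where foo and the keys stand in for your extra ops:

# sketch: wrap the custom ops in a TensorDictModule and chain it with the
# LSTM so the whole pipeline can be handed to the loss module
extra_ops = TensorDictModule(
    lambda x: foo(x),
    in_keys=["markov_state"],
    out_keys=["markov_state"],
)
loss_fn = DQNLoss(Seq(lstm, extra_ops, post))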

smorad commented 1 year ago

Where are they used in this code example?

It's a bit ugly, but I was thinking of doing something like

lstm = LSTMModule(lstm=lstm_net, out_key='markov_state')  # lstm_net: the raw nn.LSTM
target_lstm = LSTMModule(lstm=deepcopy(lstm_net), out_key=('next', 'markov_state'))
for data in collector:
  ...
  lstm(batch) # set markov state
  with torch.no_grad():
    target_lstm(batch) # set next markov state w/o grad
...

Along the same line of thought: can't you wrap all the extra ops in a TensorDictModule(lambda x: foo(x), ...) and flank that with your LSTM to pass it to the loss module?

Ooh that sounds like a very cool trick! I'll see if that works. If so, I can upstream a MaskModule if you think that would be useful.
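
A first sketch could be something like this (using TensorDictModuleBase as the base class is a guess on my part):

from tensordict.nn import TensorDictModuleBase

class MaskModule(TensorDictModuleBase):
    # rough sketch: keep only valid (unpadded) steps via a boolean mask key
    def __init__(self, mask_key=("collector", "mask")):
        super().__init__()
        self.in_keys = [mask_key]
        self.out_keys = []
        self.mask_key = mask_key

    def forward(self, tensordict):
        # flattens the batch dims, keeping entries where the mask is True
        return tensordict.masked_select(tensordict.get(self.mask_key))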