pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[Feature Request] Decouple losses from targets #1000

Open matteobettini opened 1 year ago

matteobettini commented 1 year ago

Currently the losses in torchrl compute the value target when forward is called on them.

The problem is that if the loss is called on minibatches, the target will be recomputed for each minibatch.

This is extremely inefficient, as targets can be precomputed once at the beginning of the training iteration.

I am proposing that all losses should have a separate function, loss.compute_value_target(tensordict), which writes the target to the tensordict:

batch = rb.sample(60_000)                # sample one large batch from the replay buffer
loss_module.compute_value_target(batch)  # compute the targets once for the whole batch
for _ in range(n):
    minibatch = subsample(batch, 1000)   # take a minibatch of 1000 transitions
    loss_vals = loss_module(minibatch)   # targets are already present, not recomputed

The forward function of the loss module would then check whether the target is present and, if not, call loss.compute_value_target(tensordict).
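
To make this concrete, here is a minimal sketch of what such a loss skeleton could look like, assuming the value network outputs are already stored in the tensordict; the class, key names and gamma value are illustrative, not the torchrl API:

import torch
from tensordict import TensorDict


class DecoupledLoss(torch.nn.Module):
    """Hypothetical loss skeleton illustrating the proposed split (not the torchrl API)."""

    value_target_key = "value_target"  # assumed key name

    def compute_value_target(self, tensordict: TensorDict) -> TensorDict:
        # Gradient-stop operation: the target is a constant w.r.t. the parameters.
        with torch.no_grad():
            reward = tensordict["next", "reward"]
            not_done = 1.0 - tensordict["next", "done"].float()
            next_value = tensordict["next", "state_value"]  # assumed precomputed
            tensordict[self.value_target_key] = reward + 0.99 * not_done * next_value
        return tensordict

    def forward(self, tensordict: TensorDict) -> TensorDict:
        # Reuse the target if the caller precomputed it, otherwise compute it here.
        if self.value_target_key not in tensordict.keys():
            self.compute_value_target(tensordict)
        value_loss = (
            tensordict["state_value"] - tensordict[self.value_target_key]
        ).pow(2).mean()
        return TensorDict({"loss_value": value_loss}, [])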

Furthermore, in this restructuring, value estimators would be made independent of neural networks: they would just assume they are given a tensordict with all the required keys and write the new keys into it.
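
As a sketch of what such a network-free estimator could look like (the key names and time-major layout are assumptions, not torchrl's actual conventions), here is a GAE-style pass that only reads and writes tensordict keys:

import torch
from tensordict import TensorDict


def gae_estimator(td: TensorDict, gamma: float = 0.99, lmbda: float = 0.95) -> TensorDict:
    """Network-free GAE sketch: reads precomputed keys, writes "advantage"/"value_target".

    Assumed keys, each of shape [T, 1] and written beforehand by the value network:
    "state_value", ("next", "state_value"), ("next", "reward"), ("next", "done").
    """
    reward = td["next", "reward"]
    not_done = 1.0 - td["next", "done"].float()
    value = td["state_value"]
    next_value = td["next", "state_value"]

    delta = reward + gamma * not_done * next_value - value
    advantage = torch.zeros_like(delta)
    running = torch.zeros_like(delta[-1])
    for t in reversed(range(delta.shape[0])):  # backward pass over the trajectory
        running = delta[t] + gamma * lmbda * not_done[t] * running
        advantage[t] = running

    td["advantage"] = advantage
    td["value_target"] = advantage + value
    return td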

This whole update would make it possible to do something like:

[Screenshot (2023-03-29) of algorithm pseudocode: line 12 computes the value targets, lines 13 and 14 compute the losses, line 15 is the polyak target update]

where line 12 is currently not possible in torchrl.

EDIT: This would also provide a better separation between gradient operations (like the actual loss forward) and gradient-stop operations (like loss_module.compute_value_target(batch)).

It would also unify all losses, since they could all compute the targets outside of the forward call (not just PPO and a few others).

smorad commented 1 year ago

In the algorithm you linked, I believe you need to recompute targets for each minibatch because you are using polyak averaging (line 15), so the targets should change slightly with each update.
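
For reference, a polyak/soft update nudges the target parameters after every optimizer step, which is why a target precomputed before the minibatch loop becomes slightly stale. A minimal sketch, with illustrative module names and tau:

import torch
from torch import nn

q_net = nn.Linear(8, 1)      # online network (illustrative)
q_target = nn.Linear(8, 1)   # target network
q_target.load_state_dict(q_net.state_dict())
tau = 0.005                  # illustrative soft-update coefficient


@torch.no_grad()
def polyak_update():
    # After every gradient step the target parameters drift towards the online ones,
    # so targets computed with the previous target network are slightly out of date.
    for p, p_targ in zip(q_net.parameters(), q_target.parameters()):
        p_targ.mul_(1.0 - tau).add_(tau * p)


# called once per optimizer step in the training loop:
polyak_update()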

That said, I think separating target computation from the value loss would be pretty cool for implementing custom algorithms. Then you could override DDPG/SAC to do whatever harebrained target scheme you want (e.g. TQC, etc.). In general, I think it would be beneficial to have more decoupling in the loss functions. That way, we can try small changes without rewriting SACLoss.forward and introducing potential bugs.
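
As an illustration of the kind of override this would enable (building on the hypothetical DecoupledLoss sketch above, not the real torchrl classes), a custom target scheme could be dropped in by redefining a single method, e.g. a TD3-style clipped double-Q target:

import torch


class ClippedDoubleQLoss(DecoupledLoss):  # DecoupledLoss is the sketch above
    """Hypothetical override: take the min over two target-critic predictions."""

    def compute_value_target(self, tensordict):
        with torch.no_grad():
            reward = tensordict["next", "reward"]
            not_done = 1.0 - tensordict["next", "done"].float()
            # Assumes both target-critic predictions were already written to the tensordict.
            next_q = torch.min(
                tensordict["next", "state_value_1"],
                tensordict["next", "state_value_2"],
            )
            tensordict[self.value_target_key] = reward + 0.99 * not_done * next_q
        return tensordict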

matteobettini commented 1 year ago

Yes, ideally it would be nice if each loss could have lines 12, 13, and 14 separated into 3 functions by default:

So that, as you say, users can override just one of these to customize. It would also help separate the gradients.

vmoens commented 1 year ago

I agree that's a feature we can work on, and it can be done on a loss-by-loss basis:

class SACLoss(...):
    def forward(self, td):
        # some preproc
        value_loss = self._value_loss(td_preproc)
        actor_loss = self._actor_loss(td_preproc)
        return TD({...}, [])  # as it is now

    def actor_loss(self, td):
        # some preproc
        actor_loss = self._actor_loss(td_preproc)
        return actor_loss

    def value_loss(self, td):
        # some preproc
        value_loss = self._value_loss(td_preproc)
        return value_loss

Would this work?

matteobettini commented 1 year ago

@vmoens We still want to keep the target computation separate though (that is the goal of this issue). It would be:

class SACLoss(...):
    def forward(self, td):
        td_preproc = self.compute_target(td)
        value_loss = self.value_loss(td_preproc)
        actor_loss = self.actor_loss(td_preproc)
        return TD({...}, [])  # as it is now

    def compute_target(self, td):
        # user can override; skipped if the target has already been computed outside the loss
        # does line 12 and calls the value_estimator
        ...

    def actor_loss(self, td):  # user can override
        # computes actor loss
        return actor_loss

    def value_loss(self, td):  # user can override
        # computes value loss
        return value_loss

matteobettini commented 1 year ago

I also think this could be applicable to all losses.

smorad commented 1 year ago

What about calling it value_objective, since the term "target" is overloaded? "Value target" could refer either to the target network or to the r + gamma * q(s', a') objective (which could be computed using the non-target network if delay_value==False).
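
To spell the distinction out in code (a sketch with made-up names): the objective is always r + gamma * q(s', a'), and delay_value only decides which network evaluates q(s', a'):

def value_objective(reward, not_done, next_q_target_net, next_q_online_net,
                    gamma=0.99, delay_value=True):
    # The objective is always r + gamma * q(s', a'); delay_value only selects
    # whether q(s', a') comes from the target network or the online network.
    next_q = next_q_target_net if delay_value else next_q_online_net
    return reward + gamma * not_done * next_q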

I do not like the idea of branching logic based on whether the value objective key is present in the dict, because it makes it unclear to the user whether or not the value_objective function will be called. There is no way short of dropping a debug/print statement into the library to know if the key will exist in the tensordict before forward. I propose that forward should always call value_objective, actor_loss, value_loss.

vmoens commented 1 year ago

I do not like the idea of branching logic based on whether the value objective key is present in the dict, because it makes it unclear to the user whether or not the value_objective function will be called. There is no way short of dropping a debug/print statement into the library to know if the key will exist in the tensordict before forward. I propose that forward should always call value_objective, actor_loss, value_loss.

Yep I mostly agree.

Here's the refactoring of PPO I had in mind, since with PPO it is common to compute the value first and then re-use it for several iterations over subsamples of the collected batch:

>>> # case 1: PPO loss is called without setting GAE first, and GAE has not been set
>>> ppo_loss = PPOLoss(...)
>>> for batch in collector:
...     rb.extend(batch)
...     for subbatch in rb:
...         ppo_loss(subbatch)
WARNING("you have not set the value estimator in PPO. To suppress this warning, call ppo_loss.make_value_estimator(). For more information, check the doc of the make_value_estimator method")
>>> # case 2: set the value estimator
>>> ppo_loss = PPOLoss(...)
>>> ppo_loss.make_value_estimator(**maybe_some_kwargs)
>>> for batch in collector:
...     rb.extend(batch)
...     for subbatch in rb:
...         ppo_loss(subbatch) # no warning
>>> # case 3: separate GAE calls
>>> ppo_loss = PPOLoss(...)
>>> ppo_loss.make_value_estimator(**maybe_some_kwargs)
>>> for batch in collector:
...     ppo_loss.value_objective(batch)
...     rb.extend(batch)
...     for subbatch in rb:
...         ppo_loss(subbatch) # no warning

We would specifically detail both approaches in the PPO docstring. The same logic would apply to other on-policy algos.

For some other losses such as SAC, computing the target occurs within the loss and it's not easy to think of what that would look like when called from the outside. For DQN I guess the situation would be much clearer though.
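
For DQN, for instance, the externally computed objective would just be the usual bootstrapped target; a hedged sketch (function and key names are illustrative, not the DQNLoss API):

import torch
from torch import nn


@torch.no_grad()
def dqn_value_objective(q_target_net: nn.Module, tensordict, gamma: float = 0.99):
    # Classic DQN objective: r + gamma * max_a Q_target(s', a), masked at termination.
    next_q = q_target_net(tensordict["next", "observation"]).max(dim=-1, keepdim=True).values
    not_done = 1.0 - tensordict["next", "done"].float()
    tensordict["value_target"] = tensordict["next", "reward"] + gamma * not_done * next_q
    return tensordict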

Would that work?

matteobettini commented 1 year ago

@vmoens that looks good to me! Why do you say that in SAC it would not be separable? In SAC, value_objective() would run line 12. It seems separable in all losses to me.

@smorad we need a way to precompute targets (objectives). If you have a better way we can discuss it, but a way is needed imo.

smorad commented 1 year ago

@matteobettini what's wrong with calling value_objective in forward?

def forward(self, td):
  # Write values to tensordict
  self.value_objective(td)
  self.value_loss(td)
  ...

Users can also call it however they want by overriding forward, or calling value_objective, value_loss, policy_loss in some user-defined function.
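
For instance, a user-defined training step could sequence the pieces however it likes (method names as proposed in this thread; subsample is the same hypothetical helper as in the first comment):

def training_step(loss_module, batch, optimizer, n_minibatches=10, minibatch_size=1000):
    loss_module.value_objective(batch)  # compute targets once for the whole batch
    for _ in range(n_minibatches):
        minibatch = subsample(batch, minibatch_size)
        loss = loss_module.value_loss(minibatch) + loss_module.policy_loss(minibatch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()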

smorad commented 1 year ago

For some other losses such as SAC, computing the target occurs within the loss and it's not easy to think of what that would look like when called from the outside. For DQN I guess the situation would be much clearer though.

Would that work?

Sorry, for SAC, why couldn't the target computation just be moved from forward to value_objective? _get_value_v2 seems to be exactly this, no? Is it because _get_value_v2 requires access to the policy to compute the log probs/entropy loss term?