matteobettini opened this issue 1 year ago
So in the algorithm you linked, I believe you need to recompute targets for each minibatch because you are using Polyak averaging (line 15), so the targets should change slightly with each update.
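For reference, a Polyak (soft) update looks roughly like the sketch below (tau value is illustrative); this is why targets computed through the target network drift a little after every optimisation step:

import torch

@torch.no_grad()
def polyak_update(net, target_net, tau=0.005):
    # Soft update: the target network moves a small step towards the online
    # network after each optimisation step, so targets computed from it
    # change slightly between minibatch updates.
    for p, p_targ in zip(net.parameters(), target_net.parameters()):
        p_targ.mul_(1 - tau).add_(p, alpha=tau)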
That said, I think separating target computation from the value loss would be pretty cool for implementing custom algorithms. Then you could override DDPG/SAC to do whatever harebrained target scheme you want (e.g. TQC, etc.). In general, I think it would be beneficial to have more decoupling in the loss functions. That way, we can try small changes without rewriting SACLoss.forward and introducing potential bugs.
Yes, ideally it would be nice if each loss could have lines 12, 13, and 14 separated into 3 functions by default:
compute_target
loss_actor (implemented only when an actor is present)
loss_value
So that, as you say, users can override just one of these to customize. It would also help separate the gradients.
I agree, that's a feature we can work on, and it can be done on a loss-by-loss basis:
class SACLoss(...):
    def forward(self, td):
        # some preproc -> td_preproc
        value_loss = self._value_loss(td_preproc)
        actor_loss = self._actor_loss(td_preproc)
        return TD({...}, [])  # as it is now

    def actor_loss(self, td):
        # some preproc -> td_preproc
        actor_loss = self._actor_loss(td_preproc)
        return actor_loss

    def value_loss(self, td):
        # some preproc -> td_preproc
        value_loss = self._value_loss(td_preproc)
        return value_loss
Would this work?
@vmoens We still want to keep the target computation separate, though (that is the goal of this issue). It would be:
class SACLoss(...):
    def forward(self, td):
        td_preproc = self.compute_target(td)
        value_loss = self.value_loss(td_preproc)
        actor_loss = self.actor_loss(td_preproc)
        return TD({...}, [])  # as it is now

    def compute_target(self, td):
        # user can override; skipped if the target has already been computed outside the loss
        # does line 12 and calls the value_estimator
        ...

    def actor_loss(self, td):  # user can override
        # computes actor loss
        return actor_loss

    def value_loss(self, td):  # user can override
        # computes value loss
        return value_loss
I think this could also be applicable to all losses.
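As a rough illustration of the customisation this would enable, a user could subclass the proposed API and override only the target scheme; the class name and _my_custom_target helper below are hypothetical, not existing TorchRL API:

import torch

class MyCustomTargetSACLoss(SACLoss):  # SACLoss with the proposed split above
    def compute_target(self, td):
        # Swap in a custom target scheme (e.g. a TQC-style quantile target)
        # while reusing actor_loss and value_loss unchanged.
        with torch.no_grad():
            td.set("value_target", self._my_custom_target(td))  # hypothetical helper
        return td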
What about value_objective, since the term "target" is overloaded? "Value target" could refer either to the target network or to the r + gamma * q(s', a') objective (which could be computed using the non-target network if delay_value==False).
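To make the distinction concrete, a minimal sketch of the r + gamma * q(s', a') objective, where delay_value only decides which critic provides q(s', a'); the function and argument names are illustrative:

import torch

def value_objective(reward, done, next_obs, next_action, qnet, qnet_target, gamma, delay_value=True):
    # The "objective" is r + gamma * q(s', a'); whether q comes from the
    # target network or the online network depends on delay_value.
    q = qnet_target if delay_value else qnet
    with torch.no_grad():
        next_q = q(next_obs, next_action)
    return reward + gamma * (1 - done.float()) * next_q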
I do not like the idea of branching logic based on whether the value objective key is present in the dict, because it makes it unclear to the user whether or not the value_objective function will be called. There is no way, short of dropping a debug/print statement into the library, to know if the key will exist in the tensordict before forward. I propose that forward should always call value_objective, actor_loss, value_loss.
Yep I mostly agree.
Here's the refactoring of PPO I had in mind, since with PPO it is common to compute the value first and then re-use it for several iterations over subsamples of the collected batch:
>>> # case 1: PPO loss is called without setting GAE first, and GAE has not been set
>>> ppo_loss = PPOLoss(...)
>>> for batch in collector:
... rb.extend(batch)
... for subbatch in rb:
... ppo_loss(subbatch)
WARNING("you have not set the value estimator in PPO. To suppress this warning, call ppo_loss.make_value_estimator(). For more information, check the doc of the make_value_estimator method")
>>> # case 2: set the value estimator
>>> ppo_loss = PPOLoss(...)
>>> ppo_loss.make_value_estimator(**maybe_some_kwargs)
>>> for batch in collector:
... rb.extend(batch)
... for subbatch in rb:
... ppo_loss(subbatch) # no warning
>>> # case 3: separate GAE calls
>>> ppo_loss = PPOLoss(...)
>>> ppo_loss.make_value_estimator(**maybe_some_kwargs)
>>> for batch in collector:
... ppo_loss.value_objective(batch)
... rb.extend(batch)
... for subbatch in rb:
... ppo_loss(subbatch) # no warning
We would specifically detail both approaches in the PPO docstring. The same logic would apply to other on-policy algos.
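For case 3, forward would presumably skip the value estimator when the advantage has already been written into the tensordict. A rough sketch of that check, with illustrative key and helper names rather than the actual PPOLoss code:

from tensordict import TensorDict

def forward(self, tensordict):
    if "advantage" not in tensordict.keys():
        # cases 1/2: compute the value objective on the fly (possibly warning
        # if no value estimator was explicitly configured)
        self.value_objective(tensordict)
    # case 3: "advantage" was precomputed on the full batch and is reused here
    loss_objective = self._loss_objective(tensordict)  # hypothetical helper
    loss_critic = self._loss_critic(tensordict)        # hypothetical helper
    return TensorDict({"loss_objective": loss_objective, "loss_critic": loss_critic}, [])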
For some other losses such as SAC, computing the target occurs within the loss and it's not easy to think of what that would look like when called from the outside. For DQN I guess the situation would be much clearer though.
Would that work?
@vmoens that looks good to me! Why do you say that in SAC it would not be separable? In SAC, value_objective() would run line 12. It seems separable in all losses to me.
@smorad we need a way to precompute targets (objectives). If you have a better way we can discuss it, but a way is needed imo.
@matteobettini what's wrong with calling value_objective in forward?
def forward(self, td):
    # Write values to tensordict
    self.value_objective(td)
    self.value_loss(td)
    ...
Users can also call it however they want by overriding forward, or by calling value_objective, value_loss, policy_loss in some user-defined function.
For some other losses such as SAC, computing the target occurs within the loss and it's not easy to think of what that would look like when called from the outside. For DQN I guess the situation would be much clearer though.
Would that work?
Sorry, for SAC, why couldn't the target computation just be moved from forward to value_objective? _get_value_v2 seems to be exactly this, no? Is it because _get_value_v2 requires access to the policy to compute the log probs/entropy loss term?
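For context, a rough sketch of why the SAC value objective touches the policy: the bootstrapped value at s' uses an action sampled from the current policy and its log-probability (the entropy term). The names below are illustrative, not the exact TorchRL internals:

import torch

def sac_value_objective(reward, done, next_obs, actor, qnet_target, gamma, alpha):
    # The SAC target needs the policy: a' is sampled from pi(.|s') and the
    # entropy bonus uses log pi(a'|s'), so this cannot be computed from the
    # critic alone.
    with torch.no_grad():
        next_action, next_log_prob = actor(next_obs)  # sample + log-prob
        next_q = qnet_target(next_obs, next_action)   # min over critics in practice
        next_value = next_q - alpha * next_log_prob
        return reward + gamma * (1 - done.float()) * next_value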
Currently, the losses in TorchRL compute the value target when forward is called on them.
The problem is that if the loss is called on minibatches, the target will be recomputed for each minibatch.
This is extremely inefficient, as targets can be precomputed once at the beginning of the training iteration.
I am proposing that losses should all have a separate function
loss.compute_value_target(tensordict)
which writes the target to the tensordict. The forward function of the loss module would then check whether the target is present and, if not, call loss.compute_value_target(tensordict).
Furthermore, in this restructuring, value estimators would be made independent of neural networks and would just assume that they are given a tensordict with all the desired keys, writing the new keys into it.
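A minimal sketch of what such a network-independent estimator could look like, here a TD(0)-style target that only reads and writes tensordict keys (the key names are illustrative):

import torch
from tensordict import TensorDict

def td0_value_estimator(tensordict: TensorDict, gamma: float) -> TensorDict:
    # No network calls here: the estimator only reads keys the caller has
    # already filled in (e.g. by running the critic beforehand) and writes
    # the target back into the tensordict.
    reward = tensordict.get(("next", "reward"))
    done = tensordict.get(("next", "done"))
    next_value = tensordict.get(("next", "state_value"))
    target = reward + gamma * (~done) * next_value
    tensordict.set("value_target", target)
    return tensordict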
This whole update would make it possible to do something like the linked algorithm, where its line 12 is currently not possible in TorchRL.
EDIT: This would also provide a better separation between gradient operations (like the actual loss forward) and gradient-stop operations (like
loss_module.compute_value_target(batch)
). It would also unify all losses, since they could all compute the targets outside of the forward call (not just PPO and a few others).
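A minimal sketch of the training loop this would enable under the proposed compute_value_target API (the collector, replay buffer, optimiser and loss output keys are illustrative assumptions):

import torch

for batch in collector:
    # Gradient-stop step: compute targets once per collected batch.
    with torch.no_grad():
        loss_module.compute_value_target(batch)
    replay_buffer.extend(batch)
    for _ in range(num_epochs):
        for minibatch in replay_buffer:
            # Gradient step: targets are already in the tensordict, so the
            # loss does not recompute them for every minibatch.
            loss_dict = loss_module(minibatch)
            loss = loss_dict["loss_actor"] + loss_dict["loss_value"]
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()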