N00bcak opened 1 month ago
Any chance this is solved by #2198? If so let's redirect the discussion to #2186
I don't think this issue relates to the mode or the mean of the distribution (as I think those are not used in SAC, but I could be wrong).
The `logp` seems to be the core of these instabilities. I also experienced that in the past. Clamping tricks are helpful, but we have to be careful about how we do this. I would suggest looking around at how others implement this and see what works best while still being a bit mathematically grounded.
For example, this is RLlib's implementation, with some arbitrary constants in the code: https://github.com/ray-project/ray/blob/e6e21ac2bba8b88c66c88b553a40b21a1c78f0a4/rllib/models/torch/torch_distributions.py#L275-L284
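A minimal sketch of that kind of clamping trick (the epsilon value and helper name here are illustrative, not RLlib's actual code):

```python
import torch
from torch.distributions import Normal

_ATANH_EPS = 1e-6  # illustrative constant, not RLlib's actual value

def squashed_log_prob(base_dist: Normal, actions: torch.Tensor) -> torch.Tensor:
    # Clamp squashed actions away from +/-1 so atanh stays finite.
    clamped = actions.clamp(-1.0 + _ATANH_EPS, 1.0 - _ATANH_EPS)
    gaussian_actions = torch.atanh(clamped)
    # Gaussian log-likelihood of the pre-tanh value...
    log_prob = base_dist.log_prob(gaussian_actions).sum(dim=-1)
    # ...minus the change-of-variables correction for the tanh squashing.
    log_prob = log_prob - torch.log(1.0 - clamped.pow(2) + _ATANH_EPS).sum(dim=-1)
    return log_prob
```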
This is Stable-Baselines3's:
```python
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
    # Inverse tanh
    # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
    # We use numpy to avoid numerical instability
    if gaussian_actions is None:
        # It will be clipped to avoid NaN when inversing tanh
        gaussian_actions = TanhBijector.inverse(actions)
    # Log likelihood for a Gaussian distribution
    log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions)
    # Squash correction (from original SAC implementation)
    # this comes from the fact that tanh is bijective and differentiable
    log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1)
    return log_prob
```
Very similar to RLlib's, but without the intermediate clamping trick.
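To see why the `self.epsilon` term in the squash correction matters, a quick illustration (the epsilon value here is assumed, on the order of SB3's default):

```python
import torch

a = torch.tensor([0.999999999])      # rounds to exactly 1.0 in float32
print(torch.log(1 - a ** 2))          # tensor([-inf]): 1 - a**2 is exactly 0
print(torch.log(1 - a ** 2 + 1e-6))   # tensor([-13.8155]): finite, ~log(1e-6)
```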
OK, got it. I played a lot with the Tanh transform back in the day, and the TL;DR is that anything you do (clamp or no clamp) will degrade performance for someone. What about giving the option to use the "safe" tanh (with clamping) or not? Another option is to cast values from float32 to float64, do the tanh, and cast back to float32. This could also be controlled via a flag in the `TanhNormal` constructor.
```python
>>> x = torch.full((1,), 10.0)
>>> x.tanh().atanh()
tensor([inf])
>>> x.double().tanh().atanh().float()
tensor([10.])
```
Note that in practice this is unlikely to help in many cases, since casting to float after tanh() still screws up everything:
```python
>>> x.double().tanh().float().double().atanh().float()
tensor([inf])
```
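A hypothetical helper sketching the two options side by side (the function and flag names are illustrative, not TorchRL's API):

```python
import torch

def stable_atanh(y: torch.Tensor, clamp: bool = True, upcast: bool = False,
                 eps: float = 1e-6) -> torch.Tensor:
    # Hypothetical sketch: both branches correspond to the options discussed above.
    if upcast:
        # Only helps if y was produced and kept in float64; a float32 tanh output
        # that has already saturated to +/-1 cannot be recovered (see example above).
        return torch.atanh(y.double()).to(y.dtype)
    if clamp:
        # "Safe" inverse: keep the argument strictly inside (-1, 1).
        y = y.clamp(-1.0 + eps, 1.0 - eps)
    return torch.atanh(y)
```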
I like the idea of letting the user choose between the mathematically pure and the empirically more stable version with a flag. I wouldn't call it `safe`, though, as that name is already used in other contexts; what about `clamp_logp`?
Describe the bug
When training on `PettingZoo/MultiWalker-v9` with Multi-Agent Soft Actor-Critic, all losses (`loss_actor`, `loss_qvalue`, `loss_alpha`) explode after ~1M environment steps at most. This phenomenon occurs regardless of (reasonable) hyperparameter and gradient clipping threshold choice.
To Reproduce
Expected behavior
Loss values stay within ~ +/- 10^2 throughout training and do not increase to ~ +/- 10^x where x >> 1.
System info
Reason and Possible fixes
Though the environment's observation space is not normalized and carries unbounded entries, the issue does not appear to entirely arise from poor observation scaling, since adding a `torchrl.envs.ObservationNorm` does not mitigate the issue.
Debugging reveals that unusually large and negative values for `log_prob` are somehow being fed into the `SACLoss` calculations from the reimplementation of `torch.distributions.transforms.TanhTransform`: https://github.com/pytorch/rl/blob/3e6cb8419df56d9263d1daa48f9c3be5f01eaea6/torchrl/modules/distributions/continuous.py#L289-L382
Since this reimplementation does not change much from the original `TanhTransform`, it is plausible that the reimplementation is NOT the root cause of the error. Nevertheless, replacing the reimplementation with an alternative variant gets rid of the issue altogether, although such a fix flies in the face of this comment from the PyTorch devs.
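To illustrate the failure mode (not the exact setup from the training run above), here is how a saturated action yields a non-finite log-probability through PyTorch's stock `TanhTransform`, which the reimplementation closely follows; the distribution parameters are illustrative:

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

# With a large pre-tanh mean, the sampled action saturates to exactly 1.0 in
# float32; evaluating log_prob of that action then goes through atanh(1.0) = inf,
# producing a non-finite value that propagates into the SAC losses.
base = Normal(torch.tensor([10.0]), torch.tensor([0.1]))
dist = TransformedDistribution(base, [TanhTransform()])
action = dist.sample()
print(action)                  # tensor([1.])
print(dist.log_prob(action))   # non-finite (nan)
```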
Checklist