Closed: dHonerkamp closed this issue 2 years ago.
Hi there @dHonerkamp,
Thanks for opening this issue!
The support of a tanh-transformed distribution is always (-1, 1). The tanh transformation maps all of R into (-1, 1), and its inverse is undefined outside that interval, so if you make the mean of your base distribution 50, samples will simply be squashed to values just below 1 by the tanh transformation.
In other words, you won't be able to recover any values outside of (-1, 1); querying them is undefined behavior for a tanh-transformed distribution.
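For instance, in plain Python (illustrative only, not from the original thread):

import math
math.tanh(50.0)   # 1.0 in double precision: 50 saturates the squashing
math.atanh(1.0)   # raises ValueError (math domain error): 50 cannot be recovered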
We should likely raise an error for this in our library, just as torch does when I try to run your example, but I think deprecating our internal distributions in favor of the native torch and tf ones is on our roadmap anyway.
Here is the stack trace when I recreate your example using the native torch TanhTransform and Normal distributions:
In [3]: import torch
In [4]: from torch.distributions import TransformedDistribution, TanhTransform, Normal
In [5]: basedist = Normal(50, 0.1)
In [6]: transforms = [TanhTransform()]
In [7]: newdist = TransformedDistribution(basedist, transforms)
In [10]: newdist.log_prob(torch.Tensor([50.]))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-10-4ce6bf5374d1> in <module>
----> 1 newdist.log_prob(torch.Tensor([50.]))
~/miniconda3/envs/rllib37/lib/python3.7/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
136 """
137 if self._validate_args:
--> 138 self._validate_sample(value)
139 event_dim = len(self.event_shape)
140 log_prob = 0.0
~/miniconda3/envs/rllib37/lib/python3.7/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
275 assert support is not None
276 if not support.check(value).all():
--> 277 raise ValueError('The value argument must be within the support')
278
279 def _get_checked_instance(self, cls, _instance=None):
ValueError: The value argument must be within the support
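By contrast, evaluating log_prob on a value that lies inside the support, e.g. a sample drawn from the distribution itself, works fine (an illustrative continuation, not part of the original session):

In [11]: sane = TransformedDistribution(Normal(0.0, 1.0), [TanhTransform()])
In [12]: a = sane.sample()   # lies in (-1, 1)
In [13]: sane.log_prob(a)    # finite log-probability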
Some older implementations of SAC implement your proposed solution, but many, including the original SoftLearning implementation of the algorithm, have moved to using TFP and torch.distributions: https://github.com/rail-berkeley/softlearning/blob/13cf187cc93d90f7c217ea2845067491c3c65464/softlearning/policies/gaussian_policy.py
I'm going to go ahead and close this for now, and open a separate issue about deprecating our internally maintained distributions in favor of the TFP and torch.distributions implementations. Please comment and tag me if you have any questions :)
Thanks! @avnishn
CC @sven1977 @gjoliver
What is the problem?
SAC calculates the Gaussian log-probability from clamped values, which can produce very large values when the tanh saturates, and as a consequence exploding gradients (and, eventually, NaN values).
The following is based on the calculations in actor_critic_loss from sac_torch_policy.py, which uses the TorchSquashedGaussian distribution. First we sample an action a ~ tanh(Normal(loc, std)); we then calculate its log-probability via the distribution's logp, which internally calls self._unsquash(x) on the squashed action, and that is where the issue lies. For large values, the clipping of save_normed_values means that the unsquashed values are also clipped to the range [-7.2, 7.2]. But the agent can still learn a mean for the distribution that is arbitrarily large, so the calculated log-probabilities quickly explode for means larger than 7.2. A simplified sketch of the relevant logic follows below.
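This is a paraphrase of TorchSquashedGaussian's logp and _unsquash for illustration, not the verbatim RLlib code; the class name here is hypothetical, SMALL_NUMBER stands in for RLlib's epsilon constant and is assumed to be 1e-6, and a trailing action dimension is assumed:

import torch

SMALL_NUMBER = 1e-6  # assumed value of RLlib's epsilon constant

class SquashedGaussianSketch:
    def __init__(self, loc, scale, low=-1.0, high=1.0):
        self.dist = torch.distributions.Normal(loc, scale)
        self.low, self.high = low, high

    def _unsquash(self, values):
        # Rescale actions from [low, high] back to [-1, 1].
        normed = (values - self.low) / (self.high - self.low) * 2.0 - 1.0
        # Clamp away from +/-1 so atanh stays finite. Since
        # atanh(1 - 1e-6) ~= 7.25, the result is confined to ~[-7.2, 7.2]
        # no matter how extreme the pre-squash value was.
        saved = torch.clamp(normed, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER)
        return torch.atanh(saved)

    def logp(self, x):
        # x is the squashed action; it gets unsquashed (and clipped) first.
        u = self._unsquash(x)
        # Base Gaussian log-prob evaluated at the *clipped* u: if the
        # learned mean is, say, 50, this term is huge in magnitude.
        log_prob_gaussian = self.dist.log_prob(u)
        # Change-of-variables correction for the tanh squashing.
        return log_prob_gaussian.sum(-1) - torch.log(
            1 - torch.tanh(u) ** 2 + SMALL_NUMBER).sum(-1)

With loc = torch.tensor([50.0]), logp of any sampled action is dominated by dist.log_prob evaluated at u ~= 7.25, i.e. roughly -(50 - 7.25)**2 / 2 under unit scale.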
Solution
Calculate log_prob_gaussian from the actually drawn Gaussian values, not the clipped ones, i.e. change the lines in sac_torch_policy.py that use policy_t and policy_tp1 to calculate logp so that they instead use the originally drawn samples from the Gaussian before squashing them. A minimal sketch of this change is shown below.
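A minimal sketch of the proposed computation, with hypothetical variable names rather than the actual sac_torch_policy.py diff:

import torch
from torch.distributions import Normal

SMALL_NUMBER = 1e-6
dist = Normal(torch.tensor([50.0]), torch.tensor([0.1]))  # extreme mean on purpose
u = dist.rsample()   # pre-squash Gaussian sample, kept around
a = torch.tanh(u)    # squashed action in (-1, 1)
# Log-prob of the squashed action via the change-of-variables formula,
# computed directly from u, so no atanh / clipping is ever needed:
logp = dist.log_prob(u).sum(-1) - torch.log(1 - a ** 2 + SMALL_NUMBER).sum(-1)

Because u is the sample itself, dist.log_prob(u) stays well-behaved no matter how large the learned mean is.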
to calculate logp but rather the originally drawn samples from the gaussian before squashing them.Ray version and other system information (Python version, TensorFlow version, OS): Version: 2.0.0.dev0
Reproduction (REQUIRED)
Exploding log probabilities for the mean of the distribution:
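A minimal illustration of the effect (a sketch in plain torch mirroring the clipped unsquash described above, not the issue's original snippet):

import torch
from torch.distributions import Normal

SMALL_NUMBER = 1e-6
for mean in [1.0, 5.0, 10.0, 100.0]:
    dist = Normal(torch.tensor([mean]), torch.tensor([1.0]))
    a = torch.tanh(dist.sample())  # squashed action, saturates toward 1
    # Unsquash with clipping, as in _unsquash above: u is capped at ~7.25.
    u = torch.atanh(torch.clamp(a, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER))
    # The Gaussian log-prob of the clipped u explodes as the mean grows.
    print(mean, dist.log_prob(u).item())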
If the code snippet cannot be run by itself, the issue will be closed with "needs-repro-script".