ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[rllib] SAC numerical instability #14878

Closed dHonerkamp closed 2 years ago

dHonerkamp commented 3 years ago

What is the problem?

SAC calculates the Gaussian log probability based on clamped values, which can become very large in magnitude once the tanh saturates and, as a consequence, lead to exploding gradients (and eventually NaN values).

The following is based on the calculations in actor_critic_loss from sac_torch_policy.py, which uses the TorchSquashedGaussian distribution: first we sample an action a ~ tanh(Normal(loc, std)), then we calculate the log probability as follows:

    def logp(self, x: TensorType) -> TensorType:
        # Unsquash values (from [low,high] to ]-inf,inf[)
        unsquashed_values = self._unsquash(x)
        # Get log prob of unsquashed values from our Normal.
        log_prob_gaussian = self.dist.log_prob(unsquashed_values)
        # For safety reasons, clamp somehow, only then sum up.
        log_prob_gaussian = torch.clamp(log_prob_gaussian, -100, 100)
        log_prob_gaussian = torch.sum(log_prob_gaussian, dim=-1)
        # Get log-prob for squashed Gaussian.
        unsquashed_values_tanhd = torch.tanh(unsquashed_values)
        log_prob = log_prob_gaussian - torch.sum(
            torch.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), dim=-1)
        return log_prob

The issue here is in self._unsquash(x):

    def _unsquash(self, values: TensorType) -> TensorType:
        normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \
                        1.0
        # Stabilize input to atanh.
        save_normed_values = torch.clamp(normed_values, -1.0 + SMALL_NUMBER,
                                         1.0 - SMALL_NUMBER)
        unsquashed = atanh(save_normed_values)
        return unsquashed

For large values, the clipping of save_normed_values means that the unsquashed values are also clipped, to roughly [-7.25, 7.25] (i.e. atanh of ±(1 - SMALL_NUMBER)). But the agent can still learn an arbitrarily large mean for the distribution. As a result, the calculated log probabilities quickly explode once the mean grows beyond this range.
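For illustration, here is a minimal sketch of the effect (assuming SMALL_NUMBER = 1e-6, as defined in ray.rllib.utils.numpy; the printed values are approximate):

    import torch
    from torch.distributions import Normal

    SMALL_NUMBER = 1e-6  # assumed value of ray.rllib.utils.numpy.SMALL_NUMBER

    # The clamp in _unsquash bounds atanh's input to +/-(1 - SMALL_NUMBER),
    # so the unsquashed values can never leave [-atanh(1 - 1e-6), atanh(1 - 1e-6)].
    bound = torch.atanh(torch.tensor(1.0 - SMALL_NUMBER))
    print(bound)  # ~7.25

    # If the learned mean drifts to e.g. 50, the clipped value sits ~43 standard
    # deviations away (for std=1), so the Gaussian log-prob and its gradient
    # w.r.t. the mean grow without bound as the mean keeps increasing.
    mean = torch.tensor(50.0, requires_grad=True)
    log_prob = Normal(mean, 1.0).log_prob(bound)
    log_prob.backward()
    print(log_prob.item())   # ~ -914
    print(mean.grad.item())  # ~ -43, grows linearly with the mean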

Solution

Calculate log_prob_gaussian from the actually drawn Gaussian values, not the clipped ones. I.e., change the following lines in sac_torch_policy.py:

        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1), policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

so that logp is not computed from policy_t and policy_tp1, but from the originally drawn Gaussian samples before they are squashed.
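A rough sketch of what this would look like using plain torch.distributions (names are illustrative, not the actual RLlib code paths): keep the pre-squash Gaussian sample and compute logp from it directly, instead of round-tripping the squashed action through _unsquash.

    import torch
    from torch.distributions import Normal

    SMALL_NUMBER = 1e-6  # assumed, matching RLlib's constant

    def sample_and_logp(loc, scale):
        dist = Normal(loc, scale)
        gaussian_sample = dist.rsample()       # pre-squash sample, kept around
        action = torch.tanh(gaussian_sample)   # squashed action in (-1, 1)
        # Gaussian log-prob of the *original* sample (no clipping needed) ...
        log_prob = dist.log_prob(gaussian_sample).sum(dim=-1)
        # ... plus the change-of-variables correction for the tanh squashing.
        log_prob -= torch.log(1.0 - action ** 2 + SMALL_NUMBER).sum(dim=-1)
        return action, log_prob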

Ray version and other system information (Python version, TensorFlow version, OS): Version: 2.0.0.dev0

Reproduction (REQUIRED)

Exploding log probabilities for the mean of the distribution:

import torch
from ray.rllib.models.torch.torch_action_dist import TorchSquashedGaussian

mean = 50
action_dist = TorchSquashedGaussian(torch.Tensor([mean, 0.1]), None)
recovered_mean = action_dist._unsquash(action_dist._squash(torch.Tensor([mean])))
action_dist.logp(recovered_mean)
>> tensor(-87.2919)

If the code snippet cannot be run by itself, the issue will be closed with "needs-repro-script".

avnishn commented 2 years ago

Hi there @dHonerkamp,

Thanks for opening this issue!

The support of a tanh-transformed distribution will always be [-1, 1]. The tanh transformation maps everything into this interval and its inverse is not defined outside of it, so if you make the mean of your distribution 50, that value gets squashed into [-1, 1] by the tanh transformation.

In other words, you won't be able to recover any values outside of [-1, 1], as this is undefined behavior for a tanh-transformed distribution.
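For example (just illustrating the float32 saturation, nothing RLlib-specific):

    import torch

    # tanh saturates numerically long before 50: the squashed value is exactly
    # 1.0 in float32, so the inverse transform is no longer defined.
    print(torch.tanh(torch.tensor(50.0)))   # tensor(1.)
    print(torch.atanh(torch.tensor(1.0)))   # tensor(inf)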

We should probably raise an error in this library, just like torch does when I try to run your example. That said, I think it is on our roadmap to deprecate our internal distributions in favor of the native torch and TF ones.

Here is the stack trace if I try to recreate your example using the native torch TanhTransform and Normal distribution.

In [4]: from torch.distributions import TransformedDistribution, TanhTransform, Normal

In [5]: basedist = Normal(50, 0.1)

In [6]: transforms = [TanhTransform()]

In [7]: newdist = TransformedDistribution(basedist, transforms)

In [10]: newdist.log_prob(torch.Tensor([50.]))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-10-4ce6bf5374d1> in <module>
----> 1 newdist.log_prob(torch.Tensor([50.]))

~/miniconda3/envs/rllib37/lib/python3.7/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
    136         """
    137         if self._validate_args:
--> 138             self._validate_sample(value)
    139         event_dim = len(self.event_shape)
    140         log_prob = 0.0

~/miniconda3/envs/rllib37/lib/python3.7/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    275         assert support is not None
    276         if not support.check(value).all():
--> 277             raise ValueError('The value argument must be within the support')
    278 
    279     def _get_checked_instance(self, cls, _instance=None):

ValueError: The value argument must be within the support

Some older implementations of SAC implement your solution, but many have since moved to TFP and torch.distributions, including the original SoftLearning implementation of the algorithm: https://github.com/rail-berkeley/softlearning/blob/13cf187cc93d90f7c217ea2845067491c3c65464/softlearning/policies/gaussian_policy.py
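For reference, staying within the support works fine with the native distributions, e.g. evaluating log_prob on a sample drawn from the transformed distribution itself (a minimal sketch, not the SoftLearning code):

    import torch
    from torch.distributions import Normal, TanhTransform, TransformedDistribution

    base = Normal(torch.zeros(2), torch.ones(2))
    # cache_size=1 reuses the pre-tanh value instead of re-inverting it,
    # which avoids atanh blowing up near the boundary.
    dist = TransformedDistribution(base, [TanhTransform(cache_size=1)])

    action = dist.rsample()           # squashed sample in (-1, 1)
    log_prob = dist.log_prob(action)  # finite, no manual clamping needed
    print(action, log_prob)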

I'm going to go ahead and close this for now, and open a separate issue about deprecating our internally maintained distributions in favor of the TFP and torch.distributions implementations. Please comment and tag me if you have any questions :)

Thanks! @avnishn

CC @sven1977 @gjoliver