pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[BUG] Numerical Instability issues with `torchrl.modules.TanhNormal` #2199

Open N00bcak opened 1 month ago

N00bcak commented 1 month ago

Describe the bug

When training on PettingZoo/MultiWalker-v9 with Multi-Agent Soft Actor-Critic, all losses (loss_actor, loss_qvalue, loss_alpha) explode after ~1M environment steps at most.

This phenomenon occurs regardless of (reasonable) hyperparameter choices and gradient-clipping thresholds.

To Reproduce

from copy import deepcopy
import tqdm
import numpy as np
from gymnasium.spaces import Box

import logging
import math

import torch
from torch import nn
import torch.distributions as D

from torchrl.data.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage 

from torchrl.envs import (
    check_env_specs,
    PettingZooEnv, 
    ParallelEnv,
    GymEnv
)

from torchrl.modules import AdditiveGaussianWrapper, ProbabilisticActor, TanhNormal
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import (
    MultiAgentMLP,
    MultiAgentNetBase
)
from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector, RandomPolicy

from torchrl.objectives import SACLoss, SoftUpdate

from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs import EnvCreator, TransformedEnv, Compose, Transform, RewardSum, ObservationNorm, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform

import multiprocessing as mp

EPS = 1e-7
class SMACCNet(MultiAgentNetBase): 
    '''
    This is an MLP policy network for MultiAgent SAC.

    This is just a more limited version of MultiAgentMLP.
    (https://pytorch.org/rl/main/_modules/torchrl/modules/models/multiagent.html)
    '''
    def __init__(self, 
                n_agent_inputs: int | None,
                n_agent_outputs: int,
                n_agents: int,
                centralised: bool,
                share_params: bool,
                device = 'cpu',
                activation_class = nn.Tanh,
                **kwargs):

        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class
        self.device = device

        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device = device,
            **kwargs,
        )

    # Copied over from MultiAgentMLP.
    def _pre_forward_check(self, inputs):
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        # If the model is centralized, agents have full observability
        if self.centralised:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def init_net_params(self, net):
        def init_layer_params(layer):
            if isinstance(layer, nn.Linear):
                weight_gain = 1. / (100 if layer.out_features == self.n_agent_outputs else 1)
                torch.nn.init.xavier_uniform_(layer.weight, gain = weight_gain)
                if 'bias' in layer.state_dict():
                    torch.nn.init.zeros_(layer.bias)
        net.apply(init_layer_params)
        return net

    def _build_single_net(self, *, device, **kwargs):
        n_agent_inputs = self.n_agent_inputs
        if self.centralised and n_agent_inputs is not None:
            n_agent_inputs = self.n_agent_inputs * self.n_agents

        model = nn.Sequential(
            nn.Linear(n_agent_inputs, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Linear(300, self.n_agent_outputs)
        ).to(self.device) # Bandaid fix to use MultiSyncDataCollector

        model = self.init_net_params(model)

        return model

class TqdmLoggingHandler(logging.Handler):
    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        try:
            msg = self.format(record)
            tqdm.tqdm.write(msg)
            self.flush()
        except Exception:
            self.handleError(record)  

# Main Function
if __name__ == "__main__":
    logging.basicConfig(level = logging.INFO)
    logger = logging.getLogger(__name__)
    logger.propagate = False
    logger.addHandler(TqdmLoggingHandler())

    mp.set_start_method("spawn", force = True)

    NUM_AGENTS = 3
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 8
    EXPLORATION_STEPS = 30000
    MAX_EPISODE_STEPS = 1000
    DEVICE = "cuda"
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 256
    LR = 1e-4
    UPDATE_STEPS_PER_EXPLORATION = 1500
    WARMUP_STEPS = 0 #int(2e5)
    TRAIN_TIMESTEPS = int(1e7)
    EVAL_INTERVAL = 1 #int(9e4 // EXPLORATION_STEPS) # Every 500k steps or so, evaluate once.

    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # https://pytorch.org/rl/stable/tutorials/multiagent_competitive_ddpg.html
    # More tutorials: https://pytorch.org/tutorials/advanced/pendulum.html
    # Toy test: https://pettingzoo.farama.org/environments/sisl/multiwalker/
    def env_fn(mode, parallel = True):
        def base_env_fn():
            return PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = NUM_AGENTS, 
                                    terminate_reward = -5.0,
                                    forward_reward = 1.0,
                                    fall_reward = -1.0,
                                    shared_reward = False, 
                                    max_cycles = MAX_EPISODE_STEPS, 
                                    render_mode = mode, 
                                    device = "cpu"
                                )

        if parallel:
            # Don't use.
            # https://discuss.pytorch.org/t/pettingzoo-trouble-running-multiple-marl-environments-in-parallel/203706/

            env = lambda: ParallelEnv(num_workers = 4,  # noqa: E731
                                        create_env_fn = base_env_fn, 
                                        device = "cpu",
                                        mp_start_method = "spawn",
                                        serial_for_single = True
                                    )
        else:
            env = base_env_fn # noqa: E731

        def env_with_transforms():
            # dummy_env = base_env_fn()
            # dummy_obs_transform = ObservationNorm(in_keys = [("walker", "observation")], standard_normal = True)
            # dummy_env = TransformedEnv(dummy_env, dummy_obs_transform)
            # dummy_obs_transform.init_stats(10000)

            init_env = env()
            # obs_transform = ObservationNorm(loc = dummy_obs_transform.loc + EPS,
            #                                 scale = dummy_obs_transform.scale + EPS,
            #                                 in_keys = [("walker", "observation")], 
            #                                 standard_normal = True
            #                             )
            init_env = TransformedEnv(init_env, Compose(
                                            StepCounter(max_steps = MAX_EPISODE_STEPS),
                                            RewardSum(
                                                in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)], 
                                                out_keys = [("walker", "episode_reward")] * NUM_AGENTS, 
                                                reset_keys = ["_reset"] * NUM_AGENTS
                                            ),
                                            # obs_transform
                                        )
                                    )
            # del dummy_env, dummy_obs_transform
            return init_env

        return env_with_transforms

    train_env = env_fn(None, parallel = False)()

    if train_env.is_closed:
        train_env.start()
    eval_env = env_fn("rgb_array", parallel = False)()
    video_recorder = VideoRecorder(
                                    CSVLogger("multiwalker-toy-test", video_format = "mp4"), 
                                    tag = "rendered", 
                                    in_keys = ["pixels_record"]
                                )

    # Call the parent's render function
    eval_env.append_transform(PixelRenderTransform(out_keys = ["pixels_record"]))
    eval_env.append_transform(video_recorder)

    if eval_env.is_closed:
        eval_env.start()

    check_env_specs(train_env)
    check_env_specs(eval_env)

    print(f"Action: {train_env.full_action_spec}, Reward: {train_env.full_reward_spec}, Done: {train_env.full_done_spec}, Observation: {train_env.full_observation_spec}")

    print(f"group_map: {train_env.group_map}")
    print(f"Action: {train_env.action_keys}, Reward: {train_env.reward_keys}, Done: {train_env.done_keys}")

    # NOTE: The input and output spaces to be fed in are on a PER-AGENT basis.
    # Basically, if you have 16 agents observing 3D velocity and outputting speed (the magnitude), 
    # n_agent_inputs = 3, n_agent_outputs = 1.  
    obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["walker", "action"].shape[-1]

    policy_net = nn.Sequential(
                        SMACCNet(n_agent_inputs = obs_dim,
                          n_agent_outputs = 2 * action_dim, 
                          n_agents = NUM_AGENTS,
                          centralised = False, 
                          share_params = True, 
                          device = "cpu",
                          activation_class = nn.LeakyReLU, 
                        ),
                        NormalParamExtractor(),
                    )

    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = NUM_AGENTS,
                          centralised = True,
                          share_params = True, 
                          device = "cpu",
                          activation_class = nn.LeakyReLU, 
                        )

    # Hook our networks to TensorDictModules so they can be a part of the TensorDict pipeline...
    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("walker", "observation")],
                                            # NOTE: These outputs must match with the parameter names of the 
                                            # distribution you are using!
                                            out_keys = [("walker", "loc"), ("walker", "scale")]
                                        )

    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                        in_keys = [("walker", "observation"), ("walker", "action")],
                                        out_keys = [("walker", "obs_act")]
                                    )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("walker", "obs_act")],
                                            out_keys = [("walker", "state_action_value")]
                                        )

    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["walker", "action"],
        in_keys = [("walker", "loc"), ("walker", "scale")],
        out_keys = [("walker", "action")],
        # TanhNormal is based on PyTorch's TanhTransform, which, as far as we know,
        # implements a numerically stable log-det-Jacobian.
        distribution_class = TanhNormal,
        distribution_kwargs = {
            "min": train_env.full_action_spec["walker", "action"].space.low,
            "max": train_env.full_action_spec["walker", "action"].space.high,
        },
        return_log_prob = True,
    )

    with torch.no_grad():
        fake_td = train_env.fake_tensordict()
        policy_actor(fake_td)

    dora = AdditiveGaussianWrapper(
        policy = policy_actor,
        action_key = ("walker", "action"),
        sigma_init = 0.3,
        sigma_end = 0.1,
        annealing_num_steps = TRAIN_TIMESTEPS // 2
    )

    critic_actor = TensorDictSequential(
                            obs_act_module, critic_net_td_module
                        ) 

    collector = MultiSyncDataCollector(
                    [env_fn(None, parallel = False) for _ in range(NUM_EXPLORE_WORKERS)], 
                    policy = dora,
                    frames_per_batch = BATCH_SIZE,
                    max_frames_per_traj = 0,
                    total_frames = TRAIN_TIMESTEPS,
                    device = "cpu",
                    reset_at_each_iter = False
                )

    replay_buffer = TensorDictReplayBuffer(
        storage = LazyMemmapStorage(
            REPLAY_BUFFER_SIZE, device = "cpu",
        ),  # We will store up to memory_size multi-agent transitions
        sampler = RandomSampler(),
        batch_size = BATCH_SIZE,  # We will sample batches of this size
    )

    sac_loss = SACLoss(policy_actor.to(DEVICE), 
                        qvalue_network = critic_actor.to(DEVICE), 
                        num_qvalue_nets = 2,
                        loss_function = "l2",
                        delay_qvalue = True,
                        alpha_init = 0.1
                        )
    sac_loss.set_keys(
        action = ("walker", "action"),
        state_action_value = ("walker", "state_action_value"),
        reward = ("walker", "reward"),
        done = ("walker", "done"),
        terminated = ("walker", "terminated"),
    )

    sac_loss.make_value_estimator(gamma = VALUE_GAMMA)

    polyak_updater = SoftUpdate(sac_loss, tau = 0.005) 

    critic_params = list(sac_loss.qvalue_network_params.flatten_keys().values())
    actor_params = list(sac_loss.actor_network_params.flatten_keys().values())

    optimizer_actor = torch.optim.Adam(
        actor_params,
        lr = LR,
        weight_decay = 5e-4,
        eps = EPS,
        betas = (0.9, 0.98)
    )
    optimizer_critic = torch.optim.Adam(
        critic_params,
        lr = LR,
        weight_decay = 5e-4,
        eps = EPS,
        betas = (0.9, 0.98)
    )
    optimizer_alpha = torch.optim.Adam(
        [sac_loss.log_alpha],
        lr = LR,
        eps = EPS,
        betas = (0.9, 0.98)
    )

    # breakpoint()
    num_frames = 0
    pbar = tqdm.tqdm(total = TRAIN_TIMESTEPS)
    total_frames = 0
    backprop_ctr = 0
    train_rews, ep_lengths = [], []
    EXPLORATION_BATCHES = EXPLORATION_STEPS // BATCH_SIZE
    for i, tensordict in enumerate(collector):

        collector.update_policy_weights_()

        pbar.update(tensordict.numel())

        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        total_frames += current_frames

        backprop_ctr += 1
        # Optimization steps
        if total_frames >= WARMUP_STEPS and backprop_ctr > EXPLORATION_BATCHES:
            backprop_ctr = 0
            losses = TensorDict({}, batch_size = [UPDATE_STEPS_PER_EXPLORATION])
            alphas = TensorDict({}, batch_size = [UPDATE_STEPS_PER_EXPLORATION])
            for j in range(UPDATE_STEPS_PER_EXPLORATION):
                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()

                if str(sampled_tensordict.device) != DEVICE:
                    sampled_tensordict = sampled_tensordict.to(DEVICE, non_blocking = False)
                else:
                    sampled_tensordict = sampled_tensordict.clone()

                try:
                    # Compute loss
                    loss_td = sac_loss(sampled_tensordict)
                except KeyError:
                    raise Exception(f"wtf {sampled_tensordict}\n{obs_act_module(sampled_tensordict)['walker', 'obs_act']}")

                actor_loss = loss_td["loss_actor"]
                q_loss = loss_td["loss_qvalue"]
                alpha_loss = loss_td["loss_alpha"]

                # Update actor
                optimizer_actor.zero_grad()
                actor_loss.backward()
                actor_grad_norm = torch.nn.utils.clip_grad_norm_(actor_params, max_norm = MAX_GRAD_NORM)
                optimizer_actor.step()

                # Update critic
                optimizer_critic.zero_grad()
                q_loss.backward()
                q_grad_norm = torch.nn.utils.clip_grad_norm_(critic_params, max_norm = MAX_GRAD_NORM)
                optimizer_critic.step()

                # Update alpha
                optimizer_alpha.zero_grad()
                alpha_loss.backward()
                alpha_grad_norm = torch.nn.utils.clip_grad_norm_([sac_loss.log_alpha], max_norm = MAX_GRAD_NORM)
                optimizer_alpha.step()

                losses[j] = loss_td.select(
                    "loss_actor", "loss_qvalue", "loss_alpha"
                ).detach()

                alphas[j] = loss_td.select("alpha")

                # Update qnet_target params
                polyak_updater.step()

            # Some other stuff I ripped out from https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py
            episode_end = (
                tensordict["next", "done"]
                if tensordict["next", "done"].any()
                else tensordict["next", "truncated"]
            )
            opening_banner = "-" * 10 + f" Batch {i + 1} " + "-" * 10

            def get_mean(src, key): 
                return src.get(key).mean().item()

            logger.info(opening_banner)
            logger.info(f"Average Actor Loss: {get_mean(losses, 'loss_actor')}")
            logger.info(f"Average Q Loss: {get_mean(losses, 'loss_qvalue')}")
            logger.info(f"Average Alpha: {get_mean(alphas, 'alpha')} (Loss: {get_mean(losses, 'loss_alpha')})")
            logger.info("-" * len(opening_banner))

            ep_length = tensordict['next', 'step_count'][episode_end].to(dtype = torch.float64)

            if ep_length.numel():
                ep_lengths.append(ep_length.mean().item())

        agent_terminated = torch.stack(
            [
                tensordict["next", "walker", "done"][:, agent_id, 0]
                if tensordict["next", "walker", "done"][:, agent_id, 0].any()
                else tensordict["next", "walker", "truncated"][:, agent_id, 0]
                for agent_id in range(NUM_AGENTS)
            ], 
            dim = 1
        )
        train_reward = tensordict['next', 'walker', 'episode_reward'][agent_terminated]

        if train_reward.numel():
            train_rews.append(train_reward.mean().item())

        if not ((i + 1) % (EVAL_INTERVAL * EXPLORATION_BATCHES)):
            logger.info(
                        f"Mean Train Reward Across Past {EVAL_INTERVAL} Collections: " +
                        (
                            f"{sum(train_rews) / len(train_rews)}" 
                            if len(train_rews) 
                            else f"NA (Training starts @ {WARMUP_STEPS} steps)"
                        )
                    )
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_rollout = eval_env.rollout(
                    MAX_EPISODE_STEPS,
                    policy_actor,
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )

                mean_eval_length = eval_rollout["next", "step_count"][-1].to(dtype = torch.float64).mean().item()
                logger.info(f"Mean Eval Reward: {eval_rollout['next', 'walker', 'episode_reward'][-1].mean().item()}")
                logger.info(f"Eval Length: {mean_eval_length}")
            ep_reward_list = []
            train_rews = []
            eval_env.transform.dump()

    collector.shutdown()
    train_env.close()

Expected behavior

Loss values stay within ~ +/- 10^2 throughout training and do not increase to ~ +/- 10^x where x >> 1.

System info

>>> import torchrl, numpy, sys
>>> print(f"TorchRL: {torchrl.__version__}\nNumPy: {numpy.__version__}\nPython3 Ver: {sys.version}\nPlatform: {sys.platform}")
TorchRL: 0.4.0 
NumPy: 1.25.0 
Python3 Ver: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] 
Platform: linux
> lsb_release -a
Distributor ID: Ubuntu
Description:    Ubuntu 22.04.3 LTS
Release:    22.04
Codename:   jammy

Reason and Possible fixes

Although the environment's observation space is not normalized and contains unbounded entries, the issue does not appear to arise entirely from poor observation scaling, since adding a torchrl.envs.ObservationNorm transform does not mitigate it.

Debugging reveals that unusually large and negative values for log_prob are somehow being fed into the SACLoss calculations from the reimplementation of torch.distributions.transforms.TanhTransform. https://github.com/pytorch/rl/blob/3e6cb8419df56d9263d1daa48f9c3be5f01eaea6/torchrl/modules/distributions/continuous.py#L289-L382
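
For concreteness, here is a minimal standalone snippet illustrating the mechanism (it uses the stock torch.distributions.TanhTransform rather than TorchRL's subclass, but the arithmetic is the same): once a sampled action saturates tanh in float32, the inverse is lost and the squashed log_prob becomes non-finite or extremely negative.

import torch
import torch.distributions as D

# A pre-tanh sample far in the Gaussian's tail saturates tanh in float32.
x = torch.tensor([10.0])
y = torch.tanh(x)                      # rounds to exactly 1.0 in float32
print(torch.atanh(y))                  # tensor([inf]): the inverse is lost

# With argument validation off (as is typical in training code), the squashed
# log_prob silently becomes non-finite instead of raising an error.
dist = D.TransformedDistribution(
    D.Normal(0.0, 1.0), [D.transforms.TanhTransform()], validate_args=False
)
print(dist.log_prob(y))                # nan instead of a finite value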

Since this reimplementation differs little from the original TanhTransform, it is plausible that the reimplementation itself is NOT the root cause. Nevertheless, replacing it with the alternative variant below gets rid of the issue altogether:


class CustomTanhTransform(D.transforms.TanhTransform):

    def _inverse(self, y):
        # from stable_baselines3's `common.distributions.TanhBijector`
        """
        Inverse of Tanh

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x ) / (1 - x))
        """

        y = y.clamp(-1. + EPS, 1. - EPS)
        return 0.5 * (y.log1p() - (-y).log1p())

    def log_abs_det_jacobian(self, x, y):
        # From PyTorch `TanhTransform`
        '''
        tl;dr log(1-tanh^2(x)) = log(sech^2(x)) 
                               = 2log(2/(e^x + e^(-x))) 
                               = 2(log2 - log(e^x/(1 + e^(-2x)))
                               = 2(log2 - x - log(1 + e^(-2x)))
                               = 2(log2 - x - softplus(-2x)) 
        '''

        return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))

class TanhNormalStable(D.TransformedDistribution):
    def __init__(self, loc, scale, event_dims = 1):
        self._event_dims = event_dims
        self._t = [
            CustomTanhTransform()
        ]
        self.update(loc, scale)

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - D.utils._sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        log_prob = log_prob + D.utils._sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )

        log_prob = torch.clamp(log_prob, min = math.log10(EPS)) # <- **CLAMPING THIS SEEMS TO RESOLVE THE ISSUE**
        return log_prob

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        self.loc = loc
        self.scale = scale
        if (
            hasattr(self, "base_dist")
            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
        ):
            self.base_dist.base_dist.loc = self.loc
            self.base_dist.base_dist.scale = self.scale
        else:
            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
            super().__init__(base, self._t)

    @property
    def mode(self):
        m = self.base_dist.base_dist.mean
        for t in self.transforms:
            m = t(m)
        return m

although such a fix flies in the face of this comment from the PyTorch devs.

vmoens commented 1 month ago

Any chance this is solved by #2198? If so, let's redirect the discussion to #2186.

matteobettini commented 1 month ago

I don't think this issue relates to the mode or the mean of the distribution (as I think those are not used in SAC, but I could be wrong).

The logp seems to be the core of these instabilities. I also experienced that in the past. Clamping tricks are helpful, but we have to be careful about how we do this. I would suggest looking around at how others implement this and seeing what works best while still being somewhat mathematically grounded.

For example this is rllib's implementation, with some arbitrary constants in the code https://github.com/ray-project/ray/blob/e6e21ac2bba8b88c66c88b553a40b21a1c78f0a4/rllib/models/torch/torch_distributions.py#L275-L284
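
For reference, the general shape of those clamped implementations is roughly the sketch below (this is not rllib's actual code; the clamp bound, the epsilon, and the assumption that base_dist is a per-dimension diagonal Normal are all placeholders): clamp the squashed action strictly inside (-1, 1) before inverting, and add a small epsilon inside the log of the squash correction.

import torch

EPS = 1e-6               # placeholder constant, not rllib's value
BOUND = 1.0 - EPS        # keep squashed actions strictly inside (-1, 1)

def squashed_gaussian_log_prob(base_dist, squashed_action):
    """Sketch of a tanh-squashed Gaussian log-prob with defensive clamping."""
    a = squashed_action.clamp(-BOUND, BOUND)   # avoid atanh(+/-1) = +/-inf
    pre_tanh = torch.atanh(a)                  # finite by construction
    # base_dist: per-dimension (diagonal) Normal over the pre-tanh action
    logp = base_dist.log_prob(pre_tanh).sum(dim=-1)
    # change-of-variables correction: log|d tanh(x)/dx| = log(1 - tanh(x)^2)
    logp = logp - torch.log(1.0 - a.pow(2) + EPS).sum(dim=-1)
    return logp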

matteobettini commented 1 month ago

This is Stable Baselines3's:

    def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
        # Inverse tanh
        # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
        # We use numpy to avoid numerical instability
        if gaussian_actions is None:
            # It will be clipped to avoid NaN when inversing tanh
            gaussian_actions = TanhBijector.inverse(actions)

        # Log likelihood for a Gaussian distribution
        log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions)
        # Squash correction (from original SAC implementation)
        # this comes from the fact that tanh is bijective and differentiable
        log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1)
        return log_prob

Very similar to rllib's, but without the intermediate clamping trick.

vmoens commented 1 month ago

OK, got it. I played a lot with the Tanh transform back in the day, and the TL;DR is that anything you do (clamp or no clamp) will degrade performance for someone. What about giving the option to use the "safe" tanh (with clamping) or not? Another option is to cast values from float32 to float64, do the tanh, and cast back to float32. This could also be controlled via a flag in the TanhNormal constructor.

>>> x = torch.full((1,), 10.0)
>>> x.tanh().atanh()
tensor([inf])
>>> x.double().tanh().atanh().float()
tensor([10.])

Note that in practice this is unlikely to help in many cases, since casting back to float32 after tanh() still screws everything up:

>>> x.double().tanh().float().double().atanh().float()
tensor([inf])
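
For illustration, that second option could look roughly like the sketch below (the class name and the upscale flag are hypothetical, not an existing TorchRL API): the cast lives in the inverse, which is where atanh is fragile.

import torch
import torch.distributions as D

class UpcastTanhTransform(D.transforms.TanhTransform):
    """Hypothetical variant: run the fragile atanh inverse in float64."""

    def __init__(self, upscale: bool = True, cache_size: int = 0):
        super().__init__(cache_size=cache_size)
        self.upscale = upscale

    def _inverse(self, y):
        if self.upscale:
            # atanh in double precision, then cast back to the original dtype.
            # Caveat from above: if the forward tanh already rounded y to +/-1
            # in float32, the lost precision cannot be recovered here.
            return super()._inverse(y.double()).to(y.dtype)
        return super()._inverse(y)
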
matteobettini commented 1 month ago

I like the idea of letting the user choose between the mathematically pure version and the empirically more stable one with a flag. I wouldn't call it `safe`, though, as that name is already used in other contexts; what about `clamp_logp`?
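
For concreteness, one way such a flag could be wired in at construction time (a sketch only: the factory function, the flag name, and the floor value are hypothetical, not an existing TorchRL API):

import math
import torch
import torch.distributions as D

def make_tanh_normal(loc, scale, clamp_logp: bool = False,
                     logp_min: float = math.log(1e-7)):
    """Hypothetical factory: tanh-squashed Normal with an optional log_prob floor."""
    base = D.Independent(D.Normal(loc, scale), 1)
    dist = D.TransformedDistribution(base, [D.transforms.TanhTransform()])

    if clamp_logp:
        raw_log_prob = dist.log_prob  # bound method of this instance

        def clamped_log_prob(value):
            # floor the log-prob instead of letting it run off towards -inf
            return raw_log_prob(value).clamp_min(logp_min)

        dist.log_prob = clamped_log_prob  # instance-level override, sketch only
    return dist

# e.g. dist = make_tanh_normal(torch.zeros(3), torch.ones(3), clamp_logp=True)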