DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io

log_std filled with NaNs when using PPO with use_sde=True #1593

Closed: anirudhs001 closed this issue 9 months ago

anirudhs001 commented 1 year ago

🐛 Bug

The log_std tensor gets filled completely with NaNs, which causes a ValueError during training with PPO. I have tried both use_expln=True and use_expln=False, as mentioned in https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/146.

The TensorBoard graphs just before the exception: [image]

Code example

# The environment : 
import torch
import math
import gym
import numpy as np
import collections
from gym.envs.mujoco.ant_v4 import AntEnv
from gym import spaces
import time
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3 import SAC
import stable_baselines3
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
np.random.seed(43)

def vec_norm(vec):
    return sum([v ** 2 for v in vec]) ** 0.5

class ModAntEnv_V2(AntEnv):

    def __init__(self, render_mode=None) -> None:
        super().__init__(render_mode=render_mode)

        self.targ_dist = 0.1
        self.targ = 0.1, 0
        self.speed_targ, self.theta_targ, self.ang_speed = 1, 0, 0
        self.jnt_angles_repository = collections.deque(maxlen = 3)
        self.actions_repository = collections.deque(maxlen = 2)
        self.sensor_data_repository = collections.deque(maxlen = 3)
        self.jnt_names = ['ankle_1', 'ankle_2', 'ankle_3',
                       'ankle_4', 'hip_1', 'hip_2', 'hip_3', 'hip_4']
        self.jnt_adrs = [self.model.jnt(jnt_name).qposadr[0] for jnt_name in self.jnt_names]

        self.control = (1, 0)
        self.step_idx = 0

        # observation space with 8 joint angles with values in [-pi, pi]
        # 2 control values + 3 * size 8 joint angles + 3 * size 6 values from gyro and accelerometer + 2 * size 8 actions
        self.observation_space = spaces.Box(
            low=-np.pi, high=np.pi, shape=(2+24+18+16,), dtype=np.float32)
        self.seed = lambda x: 0

    def get_angles_from_matrix(self, rot_mat):
        '''
        returns angles in radians
        '''
        O_x = np.arctan2(rot_mat[3*2+1], rot_mat[3*2+2])
        O_y = np.arctan2(-rot_mat[3*2], np.sqrt(rot_mat[3*2+1]
                         ** 2+rot_mat[3*2+2]**2))
        O_z = np.arctan2(rot_mat[3*1], rot_mat[0])
        return (O_x, O_y, O_z)

    def get_jnt_angles(self):
        # return super()._get_obs()
        jnt_angles = []
        for adr in self.jnt_adrs:
            jnt_angles.append(self.data.qpos[adr])
        return np.array(jnt_angles)

    def get_sensor_data(self):
        # pdb.set_trace()
        gyro_data = self.data.sensor('gyro').data.copy()
        accel_data = self.data.sensor('accel').data.copy()
        return np.concatenate([gyro_data, accel_data])

    def get_curr_pos_and_angle(self):
        x, y, z = self.get_body_com("torso")[:3].copy()
        ori_mat = self.data.body("torso").xmat.copy()
        angles_curr = self.get_angles_from_matrix(ori_mat)
        alpha, beta, gamma = angles_curr
        return x, y, gamma

    def update_targ(self, control):
        '''
        control : (vel, rot) tuple 
            vel in [0,1]
            rot in {-1, 0, 1}
        '''
        return  # NOTE: target updates are disabled; the code below is unreachable
        self.speed_targ, self.theta_targ, self.ang_speed = control
        self.theta_targ *= np.pi  # to radians
        curr_x, curr_y, curr_gamma = self.get_curr_pos_and_angle()
        targ_theta = curr_gamma + self.theta_targ
        targ_theta = targ_theta + \
            ((2 * np.pi) if targ_theta < -np.pi else 0) + \
            ((-2 * np.pi) if targ_theta > np.pi else 0)
        # distance is proportional to speed
        targ_x = curr_x + self.targ_dist * \
            self.speed_targ * math.cos(targ_theta)
        targ_y = curr_y + self.targ_dist * \
            self.speed_targ * math.sin(targ_theta)
        self.targ = (targ_x, targ_y)

    def custom_reward(self, state_pre, state_post, actions, weights=[1, 1e2, 1e3, 0, 0, 0]):
        vel_control, rot_control = self.control
        x_pre, y_pre, gamma_pre = state_pre
        x_post, y_post, gamma_post = state_post

        # Rewards
        # w_survive, w_dist, w_speed, w_dir, w_angular_speed, w_cost_energy = weights
        w_survive, w_rot, w_vel, w_cost_energy, w_jitter1, w_jitter2 = weights
        # reward for survival
        r_survive = 1
        r_survive *= w_survive

        # rotation reward
        theta_rot = gamma_post - gamma_pre
        theta_rot = theta_rot + 2*np.pi if theta_rot < -np.pi else theta_rot - 2*np.pi if theta_rot > np.pi else theta_rot
        r_rot = theta_rot * rot_control
        r_rot *= w_rot

        # distance reward
        dist = np.sqrt((x_post - x_pre) ** 2 + (y_post - y_pre) ** 2)
        motion_dir = np.arctan2(y_post - y_pre, x_post - x_pre)
        angle_diff = motion_dir - gamma_pre
        # dist_comp = dist * (math.e**np.cos(angle_diff))
        dist_comp = dist * np.cos(angle_diff)
        r_dist = dist_comp if rot_control == 0 else -dist * 1e-1
        r_dist *= w_vel * vel_control

        # energy cost
        cost_energy = 0.5 * np.sum(actions ** 2)
        cost_energy *= (1-vel_control) * w_cost_energy

        #action smoothness penalty
        cost_jitter = w_jitter1 * np.sum(np.abs(actions - self.actions_repository[-1]))**2 + \
            w_jitter2 * np.sum(np.abs(
                actions - 2*self.actions_repository[-1] + self.actions_repository[-2]))**2
        # total reward
        # distance travelled should be penalized 

        r = r_survive + r_rot + r_dist - cost_energy - cost_jitter
        r_comps = [r_survive, r_rot, r_dist, cost_energy, cost_jitter]
        return r, r_comps 

    def step(self, action, repeat_steps=2):
        self.step_idx += 1
        state_pre = self.get_curr_pos_and_angle()
        for _ in range(repeat_steps):
            self.do_simulation(action, self.frame_skip)
            is_done = self.terminated
            if is_done:
                break
        # obs = self._get_obs()
        state_post = self.get_curr_pos_and_angle()
        r, r_comps = self.custom_reward(state_pre, state_post, action)
        # print('state_pre: ', state_pre, 'state_post: ', state_post)
        # print('r_comps: ', r_comps)
        jnt_angles = self.get_jnt_angles()
        sensor_data = self.get_sensor_data()
        self.jnt_angles_repository.append(jnt_angles)
        self.actions_repository.append(action)
        self.sensor_data_repository.append(sensor_data)

        # concatenate the joint-angle, sensor-data and action repositories (all collections.deque)
        jnt_angles_flat = np.array(self.jnt_angles_repository).flatten()
        sensor_data_flat = np.array(self.sensor_data_repository).flatten()
        actions_flat = np.array(self.actions_repository).flatten()
        control = np.array(self.control)
        obs = np.concatenate((control, jnt_angles_flat, sensor_data_flat, actions_flat))

        # return obs, (r, r_comps), is_done, {'episode':None}
        return obs, r, is_done, {'episode':None}

    def reset(self, control=None):
        # clear both repositories
        self.jnt_angles_repository.clear()
        self.actions_repository.clear()
        self.step_idx = 0
        if control is not None:
            self.control = control
        else:
            #randomly sample control
            rot_control = np.random.choice([-1, 0, 1])
            # rot_control = 0
            # vel_control = np.random.choice(np.linspace(0, 1, 10)) if rot_control == 0 else 0
            vel_control = 1 if rot_control == 0 else 0
            self.control = (vel_control, rot_control)

        # reset the simulation first so the observation reflects the post-reset state
        super().reset()
        # make observation with 8 joint angles the same way as in the step function
        for _ in range(2):
            self.jnt_angles_repository.append(np.zeros(8))
            self.actions_repository.append(np.zeros(8))
            self.sensor_data_repository.append(np.zeros(6))
        self.jnt_angles_repository.append(self.get_jnt_angles())
        self.sensor_data_repository.append(self.get_sensor_data())
        jnt_angles_flat = np.array(self.jnt_angles_repository).flatten()
        sensor_data_flat = np.array(self.sensor_data_repository).flatten()
        actions_flat = np.array(self.actions_repository).flatten()
        control = np.array(self.control)
        obs = np.concatenate((control, jnt_angles_flat, sensor_data_flat, actions_flat))
        return obs

n_steps=100_000_000
n_envs=16
policy_kwargs = dict(activation_fn=torch.nn.Tanh,
                     net_arch=dict(pi=[64, 64], vf=[32, 32]),
                     squash_output=True,
                     use_expln=True)
env = make_vec_env(ModAntEnv_V2, n_envs=n_envs, vec_env_cls=SubprocVecEnv)
learner = lambda *args, **kwargs: PPO(*args, policy='MlpPolicy', env=env, n_steps=1024,
                                      batch_size=128, policy_kwargs=policy_kwargs, **kwargs)
model = learner(verbose=1, device='cpu', use_sde=True)
model.learn(total_timesteps=n_steps)
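
Not part of the original script: a minimal monitoring callback (a sketch; it assumes the policy exposes a log_std attribute, which holds for the Gaussian/gSDE policies) that records the range of log_std after every rollout, so the divergence is visible in TensorBoard before the run dies:

from stable_baselines3.common.callbacks import BaseCallback

class LogStdMonitor(BaseCallback):
    """Debugging sketch: record min/max of the policy's log_std each iteration."""

    def _on_rollout_end(self) -> None:
        log_std = self.model.policy.log_std.detach()
        self.logger.record("debug/log_std_min", log_std.min().item())
        self.logger.record("debug/log_std_max", log_std.max().item())

    def _on_step(self) -> bool:
        # Nothing to do per step; required by the BaseCallback interface.
        return True

# e.g. model.learn(total_timesteps=n_steps, callback=LogStdMonitor())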

Relevant log output / Error message

--------------------------------------
| rollout/                |          |
|    ep_len_mean          | 1.11e+04 |
|    ep_rew_mean          | 1.16e+04 |
| time/                   |          |
|    fps                  | 1291     |
|    iterations           | 200      |
|    time_elapsed         | 3965     |
|    total_timesteps      | 5120000  |
| train/                  |          |
|    approx_kl            | 33.19446 |
|    clip_fraction        | 0.999    |
|    clip_range           | 0.2      |
|    entropy_loss         | -55.1    |
|    explained_variance   | 0.924    |
|    learning_rate        | 0.0003   |
|    loss                 | -0.0401  |
|    n_updates            | 1990     |
|    policy_gradient_loss | -0.0652  |
|    std                  | 0.152    |
|    value_loss           | 1.05     |
--------------------------------------
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 1.07e+04      |
|    ep_rew_mean          | 1.13e+04      |
| time/                   |               |
|    fps                  | 1291          |
|    iterations           | 201           |
|    time_elapsed         | 3984          |
|    total_timesteps      | 5145600       |
| train/                  |               |
|    approx_kl            | 46686410000.0 |
|    clip_fraction        | 0.996         |
|    clip_range           | 0.2           |
|    entropy_loss         | -90.4         |
|    explained_variance   | 0.864         |
|    learning_rate        | 0.0003        |
|    loss                 | 0.254         |
|    n_updates            | 2000          |
|    policy_gradient_loss | 0.093         |
|    std                  | 0.152         |
|    value_loss           | 1.65          |
-------------------------------------------
---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 1.08e+04  |
|    ep_rew_mean          | 1.13e+04  |
| time/                   |           |
|    fps                  | 1291      |
|    iterations           | 202       |
|    time_elapsed         | 4003      |
|    total_timesteps      | 5171200   |
| train/                  |           |
|    approx_kl            | 29.075882 |
|    clip_fraction        | 0.991     |
|    clip_range           | 0.2       |
|    entropy_loss         | -119      |
|    explained_variance   | 0.937     |
|    learning_rate        | 0.0003    |
|    loss                 | 0.195     |
|    n_updates            | 2010      |
|    policy_gradient_loss | -0.0658   |
|    std                  | 0.151     |
|    value_loss           | 2.42      |
---------------------------------------
Traceback (most recent call last):
  File "main.py", line 129, in <module>
    model.learn(total_timesteps=n_steps, callback=[checkpoint_callback])
  File "/home/anirudh/venv/lib/python3.7/site-packages/stable_baselines3/ppo/ppo.py", line 314, in learn
    progress_bar=progress_bar,
  File "/home/anirudh/venv/lib/python3.7/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 281, in learn
    self.train()
  File "/home/anirudh/venv/lib/python3.7/site-packages/stable_baselines3/ppo/ppo.py", line 208, in train
    self.policy.reset_noise(self.batch_size)
  File "/home/anirudh/venv/lib/python3.7/site-packages/stable_baselines3/common/policies.py", line 540, in reset_noise
    self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
  File "/home/anirudh/venv/lib/python3.7/site-packages/stable_baselines3/common/distributions.py", line 504, in sample_weights
    self.weights_dist = Normal(th.zeros_like(std), std)
  File "/home/anirudh/venv/lib/python3.7/site-packages/torch/distributions/normal.py", line 50, in __init__
    super(Normal, self).__init__(batch_shape, validate_args=validate_args)
  File "/home/anirudh/venv/lib/python3.7/site-packages/torch/distributions/distribution.py", line 56, in __init__
    f"Expected parameter {param} "
ValueError: Expected parameter scale (Tensor of shape (64, 8)) of distribution Normal(loc: torch.Size([64, 8]), scale: torch.Size([64, 8])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan]], grad_fn=<ExpBackward0>)

System Info

Installed via pip. Output of python -c 'import stable_baselines3 as sb3; sb3.get_system_info()':

- OS: Linux-5.18.10-76051810-generic-x86_64-with-Pop-22.04-jammy # 202207071639~1659403207~22.04~cb5f582 SMP PREEMPT_DYNAMIC Tue A
- Python: 3.7.15
- Stable-Baselines3: 2.0.0
- PyTorch: 1.10.0+cu113
- GPU Enabled: True
- Numpy: 1.21.6
- Cloudpickle: 1.2.1
- Gymnasium: 0.28.1
- OpenAI Gym: 0.26.2

araffin commented 1 year ago

Hello, does that happen without gSDE? (and with ReLU instead of Tanh)
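
In code, that ablation would look roughly like this (a sketch, not part of the original comment, reusing torch, env and n_steps from the report above; squash_output is dropped because it is only meant for gSDE):

policy_kwargs_ablation = dict(activation_fn=torch.nn.ReLU,
                              net_arch=dict(pi=[64, 64], vf=[32, 32]))
model = PPO('MlpPolicy', env, n_steps=1024, batch_size=128, use_sde=False,
            policy_kwargs=policy_kwargs_ablation, verbose=1, device='cpu')
model.learn(total_timesteps=n_steps)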

anirudhs001 commented 1 year ago

Hey, nope, it doesn't happen without gSDE (or at least it hasn't happened so far; the current run has done 20k updates over 60 million timesteps). I haven't tried using ReLU with gSDE yet.

ZNLXJZ commented 1 year ago

The same problem also occurred when I used PPO, but I assumed it was caused by a mistake in the parts of the code I had modified. Now I am not sure.

araffin commented 9 months ago

use_expln=True should solve the problem, see https://github.com/DLR-RM/rl-baselines3-zoo/issues/427#issuecomment-1835541907
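
For context, a rough paraphrase of the idea behind use_expln=True (based on my reading of SB3's gSDE distribution code; the actual implementation may differ in details): instead of std = exp(log_std), which overflows for large positive log_std, the expln transform only grows logarithmically above zero, so the standard deviation stays finite:

import torch as th

def expln_std(log_std: th.Tensor, eps: float = 1e-6) -> th.Tensor:
    # exp() below zero, log1p() + 1 above zero: smooth, always positive,
    # and it avoids the exp() overflow that can turn std into inf/NaN.
    below = th.exp(log_std) * (log_std <= 0)
    above = (th.log1p(log_std * (log_std > 0) + eps) + 1.0) * (log_std > 0)
    return below + above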

arekkulczycki commented 6 months ago

I had this same error, but in my case it was caused by mixed precision: I was running the training inside a with th.xpu.amp.autocast(enabled=True, dtype=th.bfloat16) block. During the AdamW optimizer step, the first parameter found in self.policy.optimizer.param_groups[0]['params'][0] (which I think is the loss, but I'm not sure) was being set to NaN for some reason. This happened sometimes after many hours of training, apparently when the loss value was getting high (around -40).

Edit: after removing the autocast I got the same error again, just 40 minutes of training later. I have use_sde=False, unlike in this issue. Shall I open another issue?

Traceback (most recent call last):
  File "/home/arek/projects/arek-chess/arek_chess/training/run.py", line 267, in <module>
    path = LOG_PATH if not args.env else os.path.join(LOG_PATH, args.env)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arek/projects/arek-chess/arek_chess/training/run.py", line 147, in train
    model.learn(total_timesteps=TOTAL_TIMESTEPS, reset_num_timesteps=reset_num_timesteps, tb_log_name=env_name, callback=TensorboardActionHistogramCallback())  # progress_bar=True
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arek/.pyenv/versions/env/lib/python3.11/site-packages/stable_baselines3/ppo/ppo.py", line 316, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "/home/arek/.pyenv/versions/env/lib/python3.11/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 309, in learn
    self.train()
  File "/home/arek/.pyenv/versions/env/lib/python3.11/site-packages/stable_baselines3/ppo/ppo.py", line 217, in train
    values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arek/.pyenv/versions/env/lib/python3.11/site-packages/stable_baselines3/common/policies.py", line 745, in evaluate_actions
    distribution = self._get_action_dist_from_latent(latent_pi)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arek/.pyenv/versions/env/lib/python3.11/site-packages/stable_baselines3/common/policies.py", line 701, in _get_action_dist_from_latent
    return self.action_dist.proba_distribution(mean_actions, self.log_std)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arek/.pyenv/versions/env/lib/python3.11/site-packages/stable_baselines3/common/distributions.py", line 164, in proba_distribution
    self.distribution = Normal(mean_actions, action_std)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arek/.pyenv/versions/env/lib/python3.11/site-packages/torch/distributions/normal.py", line 56, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File "/home/arek/.pyenv/versions/env/lib/python3.11/site-packages/torch/distributions/distribution.py", line 68, in __init__
    raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (261760, 1)) of distribution Normal(loc: torch.Size([261760, 1]), scale: torch.Size([261760, 1])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan],
        [nan],
        [nan],
        ...,
        [nan],
        [nan],
        [nan]], device='xpu:0', grad_fn=<AddmmBackward0>)
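
A small helper (hypothetical, not part of SB3) can narrow down which tensor turns NaN first by scanning the policy's parameters and gradients, e.g. right after the crash or from a custom callback:

import torch as th

def find_nan_tensors(policy):
    # Return the names of policy parameters (or their gradients) containing NaNs.
    bad = []
    for name, param in policy.named_parameters():
        if th.isnan(param).any():
            bad.append(name)
        if param.grad is not None and th.isnan(param.grad).any():
            bad.append(name + ".grad")
    return bad

# Usage: print(find_nan_tensors(model.policy))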