DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug] SAC crashes when float64 action is used #1145

Closed by araffin 1 year ago

araffin commented 2 years ago

🐛 Bug

Not to forget: despite a warning from the env checker, SB3's SAC throws a rather unhelpful error when an action of type float64 is used.

We should either fix the bug by casting the action, move the model to float64, or at least throw a clearer error.

I found the bug while playing with envpool and HalfCheetah-v3.

To Reproduce

I will complete this later; basically, any env with an action space of type float64 triggers it.
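
For instance, a minimal sketch of such an env (the class name and values below are illustrative, not from the report):

import gym
import numpy as np
from gym import spaces


class Float64ActionEnv(gym.Env):
    """Toy env whose Box action space uses float64, enough to hit the crash described above."""

    def __init__(self):
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
        # The problematic part: an action space declared with dtype=np.float64
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float64)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, {}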

Relevant log output / Error message

Traceback (most recent call last):
  File "train.py", line 4, in <module>
    train()
  File "/home/antonin/Documents/dlr/rl/torchy-zoo/rl_zoo3/train.py", line 260, in train
    exp_manager.learn(model)
  File "/home/antonin/Documents/dlr/rl/torchy-zoo/rl_zoo3/exp_manager.py", line 243, in learn
    model.learn(self.n_timesteps, **kwargs)
  File "/home/antonin/Documents/dlr/rl/torchy-baselines/stable_baselines3/sac/sac.py", line 305, in learn
    progress_bar=progress_bar,
  File "/home/antonin/Documents/dlr/rl/torchy-baselines/stable_baselines3/common/off_policy_algorithm.py", line 353, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/antonin/Documents/dlr/rl/torchy-baselines/stable_baselines3/sac/sac.py", line 250, in train
    current_q_values = self.critic(replay_data.observations, replay_data.actions)
  File "/home/antonin/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/antonin/Documents/dlr/rl/torchy-baselines/stable_baselines3/common/policies.py", line 860, in forward
    return tuple(q_net(qvalue_input) for q_net in self.q_networks)
  File "/home/antonin/Documents/dlr/rl/torchy-baselines/stable_baselines3/common/policies.py", line 860, in <genexpr>
    return tuple(q_net(qvalue_input) for q_net in self.q_networks)
  File "/home/antonin/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/antonin/miniconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/antonin/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/antonin/miniconda3/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Float but found Double

System Info

Latest SB3.

tobirohrer commented 2 years ago

I would like to work on this one and can prepare a PR.

araffin commented 2 years ago

I would like to work on this one and can prepare a PR.

Please do =). The only thing I'm not sure about yet is what to fix (cast the action or cast the model), what do you think?
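
For context, the two options look roughly like this (a standalone sketch with a stand-in critic network, not SB3 code):

import numpy as np
import torch as th

critic = th.nn.Linear(6, 1)                 # stand-in for the float32 critic network
action = np.zeros(6, dtype=np.float64)      # float64 action coming from the env

# Option 1: cast the action down to float32 so it matches the model's weights
q_value = critic(th.as_tensor(action.astype(np.float32)))

# Option 2: move the model itself to float64 so it accepts double-precision inputs
critic = critic.to(th.float64)
q_value = critic(th.as_tensor(action))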

tobirohrer commented 2 years ago

A quick note on this issue:

  1. DDPG and TD3 seem to have the same issue. PPO and A2C don't seem to have that issue.
  2. I reproduced it with the following code snippet:

    import gym
    import numpy as np

    from stable_baselines3 import SAC, DDPG

    env = gym.make("Pendulum-v1")

    # Just set the same action space but use np.float64 as dtype instead of np.float32
    env.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float64)

    model = DDPG("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=100)

  3. My system info:

    OS: macOS-10.16-x86_64-i386-64bit
    Python: 3.9.13
    Stable-Baselines3: 1.7.0a1
    PyTorch: 1.13.0
    GPU Enabled: False
    Numpy: 1.23.4
    Gym: 0.21.0

  4. The stack trace I got is slightly different from yours, but the cause is probably the same:

    Traceback (most recent call last):
      File "/Users/tobi/git/stable-baselines3/test.py", line 10, in <module>
        model.learn(total_timesteps=10000, log_interval=4)
      File "/Users/tobi/git/stable-baselines3/stable_baselines3/sac/sac.py", line 299, in learn
        return super().learn(
      File "/Users/tobi/git/stable-baselines3/stable_baselines3/common/off_policy_algorithm.py", line 353, in learn
        self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
      File "/Users/tobi/git/stable-baselines3/stable_baselines3/sac/sac.py", line 250, in train
        current_q_values = self.critic(replay_data.observations, replay_data.actions)
      File "/Users/tobi/opt/miniconda3/envs/sb3dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/tobi/git/stable-baselines3/stable_baselines3/common/policies.py", line 860, in forward
        return tuple(q_net(qvalue_input) for q_net in self.q_networks)
      File "/Users/tobi/git/stable-baselines3/stable_baselines3/common/policies.py", line 860, in <genexpr>
        return tuple(q_net(qvalue_input) for q_net in self.q_networks)
      File "/Users/tobi/opt/miniconda3/envs/sb3dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/tobi/opt/miniconda3/envs/sb3dev/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
        input = module(input)
      File "/Users/tobi/opt/miniconda3/envs/sb3dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/tobi/opt/miniconda3/envs/sb3dev/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
        return F.linear(input, self.weight, self.bias)
    RuntimeError: mat1 and mat2 must have the same dtype

araffin commented 2 years ago

DDPG seems to have the same issue.

Sure, I would say most off-policy algorithms with continuous actions should be affected.

behradkhadem commented 1 year ago

@araffin Any updates? I still get the same error with TD3 on a robotics environment. Is there something we can do to mitigate the issue? In my environment, the env_checker runs without error, but when I run TD3 I get the RuntimeError: mat1 and mat2 must have the same dtype error.

Full error:

  File "/home/behradx/projects/gym_envs_urdf/examples/To Be Deleted/point_robot_rl_sb3.py", line 83, in <module>
    model.learn(total_timesteps=TIMESTEPS, 
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/td3/td3.py", line 216, in learn
    return super().learn(
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 330, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/td3/td3.py", line 169, in train
    next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/common/policies.py", line 934, in forward
    return tuple(q_net(qvalue_input) for q_net in self.q_networks)
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/common/policies.py", line 934, in <genexpr>
    return tuple(q_net(qvalue_input) for q_net in self.q_networks)
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype

araffin commented 1 year ago

Hello, the bug is still there, but we would welcome a PR that solves it. There is an easy fix, which is to use float32 instead of float64 for the action space (a gym wrapper can do that).

behradkhadem commented 1 year ago

Hello, the bug is still there, but we would welcome a PR that solves it. There is an easy fix, which is to use float32 instead of float64 for the action space (a gym wrapper can do that).

A few questions first:

  1. What is the name of the wrapper you mentioned? I'm using gymnasium and I didn't find what you described.
  2. Aren't np.float and friends deprecated? Aren't we supposed to use vanilla Python data types instead of NumPy ones? (source)
  3. Are there any guidelines on how I can contribute to this project? I'd be happy to!

araffin commented 1 year ago

  1. I meant writing a custom gym wrapper
  2. Only for the scalar types; NumPy dtypes are still valid for arrays as far as I know
  3. We have a Contributing guide for that, take a look at the README ;) https://github.com/DLR-RM/stable-baselines3#how-to-contribute

behradkhadem commented 1 year ago

  1. I meant writing a custom gym wrapper
  2. Only for the scalar types; NumPy dtypes are still valid for arrays as far as I know
  3. We have a Contributing guide for that, take a look at the README ;) https://github.com/DLR-RM/stable-baselines3#how-to-contribute

I used the wrapper below and got the algorithm running. It may help others in the future:

import gym
import numpy as np


class Float32ActionWrapper(gym.Wrapper):
    """Exposes a float32 version of the wrapped env's Box action space."""

    def __init__(self, env):
        super().__init__(env)
        # Re-declare the action space with float32 bounds and dtype
        self.action_space = gym.spaces.Box(
            low=self.env.action_space.low.astype(np.float32),
            high=self.env.action_space.high.astype(np.float32),
            dtype=np.float32,
        )

    def step(self, action):
        # Cast the incoming action before forwarding it to the wrapped env
        action = action.astype(np.float32)
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

Remember to use this wrapper like wrapper_env = Float32ActionWrapper(old_env).
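
Note that this step() signature follows the old gym API; with gymnasium, whose step() returns five values, a sketch of the same wrapper would look like this (illustrative, not from the thread):

import gymnasium as gym
import numpy as np


class Float32ActionWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.action_space = gym.spaces.Box(
            low=self.env.action_space.low.astype(np.float32),
            high=self.env.action_space.high.astype(np.float32),
            dtype=np.float32,
        )

    def step(self, action):
        # gymnasium's step() returns (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = self.env.step(action.astype(np.float32))
        return obs, reward, terminated, truncated, info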

Currently I'm working on an important project and don't have spare time, but I'd be glad to help in the future. :)

tobirohrer commented 1 year ago

I would like to work on this one and can prepare a PR.

Please do =). The only thing I'm not sure about yet is what to fix (cast the action or cast the model), what do you think?

I guess the core of the issue is how the ReplayBuffer and RolloutBuffer store actions and observations. The RolloutBuffer used by on-policy algorithms always casts actions and observations to float32; the ReplayBuffer, on the other hand, doesn't. I would propose fixing this issue for now by casting actions and observations to float32 in the ReplayBuffer as well. I will prepare a PR for this in the upcoming days.

Anyhow, I would propose that we look deeper into this issue later (in a separate PR) because: 1) there is a lot of casting back and forth going on, which might affect performance somewhat, and 2) from the outside, it is not clear that actions and observations are internally cast down to float32. We should either warn the user or, even better, add "full support" for float64 actions and observations by moving the models to float64.
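
A rough sketch of the proposed casting idea (a standalone illustration with hypothetical names, not the actual SB3 ReplayBuffer code):

import numpy as np


class CastingReplayBuffer:  # hypothetical, for illustration only
    def __init__(self, buffer_size, action_dim):
        # Mirror the RolloutBuffer and store actions as float32,
        # regardless of the action space's native dtype (e.g. float64)
        self.actions = np.zeros((buffer_size, action_dim), dtype=np.float32)
        self.pos = 0

    def add(self, action):
        # np.asarray(..., dtype=np.float32) downcasts float64 actions on insertion,
        # so the critic later receives tensors that match its float32 weights
        self.actions[self.pos] = np.asarray(action, dtype=np.float32)
        self.pos = (self.pos + 1) % len(self.actions)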

araffin commented 1 year ago

I would propose fixing this issue for now by casting actions and observations to float32 in the ReplayBuffer as well.

Sounds good, but only for the Box case (+ add a note in the doc). (This would prevent using float16 though, which is why I also agree with part 2.)
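
Concretely, the Box-only restriction could look roughly like this (a sketch with a hypothetical helper function, not actual SB3 code):

import numpy as np
from gym import spaces


def buffer_action_dtype(action_space):
    # Hypothetical helper: only force float32 for continuous (Box) action spaces;
    # all other spaces keep their native dtype
    if isinstance(action_space, spaces.Box):
        return np.float32
    return action_space.dtype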

I would propose that we look deeper into this issue later (in a separate PR) because:

I also agree on that ;) (but I personally have no time for it; adding a wrapper has been a good and simple enough solution so far).