I would like to work on this one and can prepare a PR
Please do =), the only thing I'm not sure about yet is what to fix (cast the action or cast the model), what do you think?
A quick note on this issue:

1. The error only appears when the action space uses `np.float64` as dtype instead of `np.float32`.
2. It can be reproduced with:

```python
import gym
import numpy as np

from stable_baselines3 import SAC, DDPG

env = gym.make("Pendulum-v1")
# Override the action space so that it uses np.float64 instead of np.float32
env.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float64)

model = DDPG("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100)
```
3. My system info:
   - OS: macOS-10.16-x86_64-i386-64bit
   - Python: 3.9.13
   - Stable-Baselines3: 1.7.0a1
   - PyTorch: 1.13.0
   - GPU Enabled: False
   - Numpy: 1.23.4
   - Gym: 0.21.0
4. The stack trace I got is slightly different from yours, but the cause is probably the same:

```
Traceback (most recent call last):
  File "/Users/tobi/git/stable-baselines3/test.py", line 10, in <module>
```
DDPG seems to have the same issue.
Sure, I would say most off-policy algorithms with continuous actions should be affected.
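For illustration, a quick sketch along those lines (assuming the same float64 `Pendulum-v1` setup as in the snippet above; the timestep budget is only chosen so that at least one training update happens):

```python
import gym
import numpy as np

from stable_baselines3 import DDPG, SAC, TD3

for algo_cls in (SAC, DDPG, TD3):
    env = gym.make("Pendulum-v1")
    # Force a float64 action space, as in the reproduction snippet above
    env.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float64)
    model = algo_cls("MlpPolicy", env, verbose=0)
    try:
        # Run long enough for at least one gradient update, where the dtype mismatch surfaces
        model.learn(total_timesteps=1_000)
    except RuntimeError as exc:
        print(f"{algo_cls.__name__}: {exc}")
```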
@araffin Any updates? I still get the same error for a robotics environment that is using TD3. Is there something we could do to mitigate the issue?
In my environment the `env_checker` runs without error, but when I'm running TD3 I get the `RuntimeError: mat1 and mat2 must have the same dtype` error.
Full error:

```
File "/home/behradx/projects/gym_envs_urdf/examples/To Be Deleted/point_robot_rl_sb3.py", line 83, in <module>
model.learn(total_timesteps=TIMESTEPS,
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/td3/td3.py", line 216, in learn
return super().learn(
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 330, in learn
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/td3/td3.py", line 169, in train
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/common/policies.py", line 934, in forward
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/stable_baselines3/common/policies.py", line 934, in <genexpr>
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
input = module(input)
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/behradx/anaconda3/envs/SB3/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype
```
Hello, the bug is still there, but we would welcome a PR that solves it. There is an easy fix, which is to use float32 instead of float64 for the action space (a gym wrapper can do that).
A few questions first:

- What do you mean by "a gym wrapper can do that", should I write a custom wrapper?
- Aren't `np.float` and so on deprecated? Aren't we supposed to use vanilla Python data types instead of NumPy? (source)
- How should I contribute the fix?
- I meant writing a custom gym wrapper
- Only for the scalar types; numpy dtypes are still valid for arrays as far as I know (see the short example after this list)
- We have a Contributing guide for that, take a look at the README ;) https://github.com/DLR-RM/stable-baselines3#how-to-contribute
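To make the scalar-vs-array distinction concrete, a small sketch (plain NumPy, nothing SB3-specific): the deprecated names are scalar aliases like `np.float`, while explicit array dtypes such as `np.float32`/`np.float64` remain valid.

```python
import numpy as np

# The scalar alias np.float was deprecated (NumPy 1.20) and later removed;
# use the built-in float (or np.float64 explicitly) for scalars instead.
x = float(1.0)

# Explicit dtypes are still the standard way to control array precision:
actions64 = np.zeros(3, dtype=np.float64)
actions32 = actions64.astype(np.float32)
print(actions32.dtype)  # float32
```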
I used the wrapper below and got the algorithm running. It may help others in the future:
```python
import gym
import numpy as np


class Float32ActionWrapper(gym.Wrapper):
    """Cast the action space and incoming actions down to float32."""

    def __init__(self, env):
        super().__init__(env)
        self.action_space = gym.spaces.Box(
            low=self.env.action_space.low.astype(np.float32),
            high=self.env.action_space.high.astype(np.float32),
            dtype=np.float32,
        )

    def step(self, action):
        # Ensure the action passed to the underlying env is float32
        action = action.astype(np.float32)
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
```
Remember to use this wrapper like `wrapper_env = Float32ActionWrapper(old_env)`.
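For completeness, a hedged end-to-end sketch of how the wrapper above could be used; the float64 `Pendulum-v1` action space is only simulated here to mirror the reports above, and `Float32ActionWrapper` is the class from the previous snippet:

```python
import gym
import numpy as np

from stable_baselines3 import TD3

env = gym.make("Pendulum-v1")
# Simulate an environment whose action space uses float64
env.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float64)

# After wrapping, SB3 sees a float32 Box action space, so stored actions match the network dtype
env = Float32ActionWrapper(env)

model = TD3("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=1_000)
```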
Currently I'm working on an important project and don't have spare time, but I'd be glad to help in the future. :)
> I would like to work on this one and can prepare a PR
>
> Please do =), the only thing I'm not sure about yet is what to fix (cast the action or cast the model), what do you think?
I guess the core of the issue is how `ReplayBuffer` and `RolloutBuffer` store actions and observations. The `RolloutBuffer` for on-policy algorithms always casts the actions and observations to float32; the `ReplayBuffer`, on the other hand, doesn't. I would propose fixing this issue for now by casting actions and observations to float32 in the `ReplayBuffer` as well. I would prepare an MR for this in the upcoming days.
Anyhow, I would propose that we look deeper into this issue later (in a separate MR), because: 1) there is a lot of casting back and forth going on, which might affect performance somewhat; 2) from the outside, it is not clear that the actions and observations are internally cast down to float32. We should either warn the user or, even better, add "full support" for float64 actions and observations by moving the models to float64.
> I would propose fixing this issue for now by casting actions and observations to float32 in the ReplayBuffer as well.
Sounds good, but only for the `Box` case (+ add a note in the doc). (This would prevent using float16 though, which is why I also agree with part 2.)
> I would propose that we look deeper into this issue later (in a separate MR), because:
I also agree on that ;) (but I personally have no time for that; adding a wrapper has been a good and simple enough solution so far).
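For illustration only, a minimal sketch of the direction discussed above (storing `Box` actions as float32 in the replay buffer, mirroring the `RolloutBuffer`); this is an assumption about how such a fix could look, built on the public `ReplayBuffer` and `replay_buffer_class` API, not the actual patch:

```python
import numpy as np
from gym import spaces

from stable_baselines3.common.buffers import ReplayBuffer


class Float32ReplayBuffer(ReplayBuffer):
    """Sketch: keep Box actions in float32 storage so sampled batches match the networks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if isinstance(self.action_space, spaces.Box) and self.actions.dtype == np.float64:
            # Re-allocate the action storage as float32; later assignments into
            # this array are cast automatically when transitions are added.
            self.actions = self.actions.astype(np.float32)


# Usage sketch (names are placeholders):
# model = TD3("MlpPolicy", env, replay_buffer_class=Float32ReplayBuffer)
```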
🐛 Bug
Not to forget: despite a warning from the env checker, SB3 SAC throws a not-so-nice error when using actions of type float64.
We should either fix the bug with casting, move the model to float64, or at least throw a nicer error.
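As a rough illustration of the "nicer error" option (purely a sketch; the function name and where such a check would live are assumptions, not existing SB3 API):

```python
import numpy as np
from gym import spaces


def check_box_action_dtype(action_space) -> None:
    """Illustrative early check: fail with a clear message for non-float32 Box action spaces."""
    if isinstance(action_space, spaces.Box) and action_space.dtype != np.float32:
        raise ValueError(
            f"The action space dtype is {action_space.dtype}, but SB3 networks use float32. "
            "Please cast the action space to np.float32, e.g. with a gym wrapper."
        )
```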
I found the bug while playing with envpool and `HalfCheetah-v3`.

To Reproduce
I will complete that later; basically any env with an action space of type float64.
Relevant log output / Error message
System Info
Latest SB3.