sweetice / Deep-reinforcement-learning-with-pytorch

PyTorch implementation of DQN, AC, ACER, A2C, A3C, PG, DDPG, TRPO, PPO, SAC, TD3 and ....
MIT License
3.73k stars, 837 forks

SAC Bugs #25

Open ZiyiLiubird opened 3 years ago

ZiyiLiubird commented 3 years ago

In SAC.py and SAC_BipedalWalker-v2.py, the following code:

class NormalizedActions(gym.ActionWrapper):
    def _action(self, action):
        low = self.action_space.low
        high = self.action_space.high

        action = low + (action + 1.0) * 0.5 * (high - low)
        action = np.clip(action, low, high)

        return action

    def _reverse_action(self, action):
        low = self.action_space.low
        high = self.action_space.high

        action = 2 * (action - low) / (high - low) - 1
        action = np.clip(action, low, high)

        return action

should now be changed as follows:

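import gym
import numpy as np

class NormalizedActions(gym.ActionWrapper):
    # Newer Gym calls the public action() hook from ActionWrapper.step()
    def action(self, action):
        low = self.action_space.low
        high = self.action_space.high

        # Rescale from [-1, 1] to the environment's [low, high] bounds
        action = low + (action + 1.0) * 0.5 * (high - low)
        action = np.clip(action, low, high)

        return action

    def reverse_action(self, action):
        low = self.action_space.low
        high = self.action_space.high

        # Map from [low, high] back to [-1, 1]; clip to that range, not to [low, high]
        action = 2 * (action - low) / (high - low) - 1
        action = np.clip(action, -1.0, 1.0)

        return action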

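This is needed to match the latest OpenAI Gym core.py, whose ActionWrapper.step() calls the public action() method, so overriding only _action() leaves action() unimplemented.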
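The arithmetic inside the two methods is an affine map from the [-1, 1] range the policy is assumed to output onto the environment's [low, high] box, plus its inverse. Below is a minimal NumPy-only sketch of that round trip; `scale_action` and `unscale_action` are hypothetical standalone helpers (not part of the repository or of Gym), and the inverse clips to [-1, 1] rather than [low, high], on the assumption that its output is meant to be a normalized action.

```python
import numpy as np

def scale_action(action, low, high):
    """Map an action from [-1, 1] to [low, high] (same formula as action())."""
    action = low + (action + 1.0) * 0.5 * (high - low)
    return np.clip(action, low, high)

def unscale_action(action, low, high):
    """Map an action from [low, high] back to [-1, 1], clipping to [-1, 1]."""
    action = 2.0 * (action - low) / (high - low) - 1.0
    return np.clip(action, -1.0, 1.0)

# Hypothetical bounds, chosen so the rescaling is visible
low = np.array([-2.0, 0.0])
high = np.array([2.0, 10.0])

normalized = np.array([0.5, -1.0])
scaled = scale_action(normalized, low, high)      # [1.0, 0.0]
recovered = unscale_action(scaled, low, high)     # back to [0.5, -1.0]

# An out-of-range input is clipped to the normalized range
clipped = unscale_action(np.array([3.0, 10.0]), low, high)  # [1.0, 1.0]
```

Applying one function after the other returns the original normalized action, which is a quick sanity check that the two formulas really are inverses.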

ZiyiLiubird commented 3 years ago

Otherwise, the following error is raised:

Traceback (most recent call last):
  File "SAC.py", line 308, in <module>
    main()
  File "SAC.py", line 288, in main
    next_state, reward, done, info = env.step(np.float64(action))
  File "/Users/Shared/anaconda3/envs/Pytorch/lib/python3.8/site-packages/gym/core.py", line 285, in step
    return self.env.step(self.action(action))
  File "/Users/Shared/anaconda3/envs/Pytorch/lib/python3.8/site-packages/gym/core.py", line 288, in action
    raise NotImplementedError
NotImplementedError
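The traceback follows from how the wrapper dispatches: step() calls self.action(), so a subclass that only defines _action() inherits the base action(), which raises NotImplementedError. The sketch below reproduces this with a simplified mock. `Env`, `ActionWrapper`, and the two subclasses are hypothetical stand-ins written for illustration, not Gym's real classes.

```python
class Env:
    """Hypothetical environment whose step() just echoes the action it receives."""
    def step(self, action):
        return action

class ActionWrapper:
    """Simplified mock of the dispatch seen in the traceback (not gym.ActionWrapper)."""
    def __init__(self, env):
        self.env = env

    def step(self, action):
        # Mirrors the traceback line: return self.env.step(self.action(action))
        return self.env.step(self.action(action))

    def action(self, action):
        raise NotImplementedError

class OldStyleWrapper(ActionWrapper):
    def _action(self, action):  # never called by step()
        return action * 2

class NewStyleWrapper(ActionWrapper):
    def action(self, action):  # overrides the hook step() actually calls
        return action * 2

def try_step(wrapper, action):
    """Return the step result, or the exception name if action() is unimplemented."""
    try:
        return wrapper.step(action)
    except NotImplementedError:
        return "NotImplementedError"

old_result = try_step(OldStyleWrapper(Env()), 3)  # "NotImplementedError"
new_result = try_step(NewStyleWrapper(Env()), 3)  # 6
```

Renaming the override from _action to action is therefore enough; nothing else in the wrapper needs to change.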

hshhsjsj commented 2 years ago

You are amazing!

zhaoyanghandd commented 2 years ago

RuntimeError: Found dtype Double but expected Float. How can this be solved?
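One likely cause, offered as an assumption since the thread does not confirm it: NumPy arrays default to float64 (and SAC.py line 288 explicitly wraps the action in np.float64), and `torch.from_numpy` keeps that dtype, so the resulting Double tensors meet float32 network parameters or outputs. A minimal sketch, assuming the training batch is assembled from NumPy arrays before conversion, casts everything to float32 first. `to_float32_batch` is a hypothetical helper, not part of the repository.

```python
import numpy as np

def to_float32_batch(states, actions, rewards):
    """Stack hypothetical replay-buffer samples into float32 arrays."""
    return (
        np.asarray(states, dtype=np.float32),
        np.asarray(actions, dtype=np.float32),
        # Column vector so rewards line up with per-sample Q-values
        np.asarray(rewards, dtype=np.float32).reshape(-1, 1),
    )

# Inputs are float64, NumPy's default floating-point dtype
states = [np.zeros(3), np.ones(3)]
actions = [np.float64(0.5), np.float64(-0.5)]
rewards = [1.0, 0.0]

s, a, r = to_float32_batch(states, actions, rewards)
```

Passing `s` to `torch.from_numpy` then yields a `torch.float32` tensor; calling `.float()` on an existing Double tensor is the equivalent fix on the PyTorch side.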