Starlight0798 / gymRL

PyTorch deep reinforcement learning (DRL) based on gym (PPO, PPG, DQN, SAC, DDPG, TD3, and other algorithms)

DDQN with a multi-dimensional state space #1

Closed LemonPasserby closed 7 months ago

LemonPasserby commented 7 months ago

Hi, I created my own env whose state has shape (7, 4). When I run your DDQN algorithm on it, I get this error:

    "C:\ProgramData\Anaconda3\envs\stableBaseline\lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
        return F.linear(input, self.weight, self.bias)
    RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x4 and 7x256)

I checked the code, and the problem is in this part:


    def choose_action(self, state):
        self.sample_count += 1
        self.epsilon = self.cfg.epsilon_end + (self.cfg.epsilon_start - self.cfg.epsilon_end) * \
                       np.exp(-1. * self.sample_count / self.cfg.epsilon_decay)
        if random.uniform(0, 1) > self.epsilon:
            state = torch.tensor(np.array([state]), device=self.cfg.device, dtype=torch.float32)
            action = self.policy_net(state).argmax(dim=1).item()
        else:
            action = random.randrange(self.cfg.n_actions)
        return action
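
For context, the shape mismatch can be reproduced on its own: if the first fully connected layer is built with n_states = env.observation_space.shape[0] (7 here, which matches the 7x256 in the error), it expects 7 input features, while the batched state tensor has shape (1, 7, 4). A minimal sketch (the layer names and the 256 hidden width are assumptions based on the error message, not the repo's exact code):

    import torch
    import torch.nn as nn

    state = torch.rand(7, 4)              # one (7, 4) observation from the custom env
    batched = state.unsqueeze(0)          # (1, 7, 4), like np.array([state]) above

    fc1 = nn.Linear(7, 256)               # first layer if n_states = observation_space.shape[0]
    # fc1(batched) raises: mat1 and mat2 shapes cannot be multiplied (7x4 and 7x256)

    fc1_flat = nn.Linear(7 * 4, 256)      # n_states must be 28 once the state is flattened
    out = fc1_flat(batched.flatten(start_dim=1))   # input (1, 28) -> output (1, 256)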

I then made the following changes: `cfg.n_states = env.observation_space.shape[0] * env.observation_space.shape[1]`, plus the choose_action function:

    def choose_action(self, state):
        print("选取动作")
        self.sample_count += 1
        self.epsilon = self.cfg.epsilon_end + (self.cfg.epsilon_start - self.cfg.epsilon_end) * \
                       np.exp(-1. * self.sample_count / self.cfg.epsilon_decay)
        if random.uniform(0, 1) > self.epsilon:
            state = torch.tensor(state.flatten(), device=self.cfg.device, dtype=torch.float32)
            state = state.unsqueeze(0)
            action = self.policy_net(state).argmax(dim=1).item()
        else:
            action = random.randrange(self.cfg.n_actions)

        return action

and the update function:

    def update(self):

        if self.memory.size() < self.cfg.batch_size:
            return 0
        print("开始更新")
        (state_batch, action_batch, reward_batch, next_state_batch,
            done_batch), idxs_batch, is_weight_batch = self.memory.sample()
        action_batch = action_batch.type(torch.long).view(-1, 1)
        # flatten each sampled (7, 4) state to (28,) so it matches the fully connected net
        state_batch = state_batch.view(self.cfg.batch_size, -1)
        next_state_batch = next_state_batch.view(self.cfg.batch_size, -1)
        q_value = self.policy_net(state_batch).gather(1, action_batch).squeeze(1)

The code has now run a few dozen times without errors, but since I'm not very familiar with the internals of reinforcement learning code, I'd like to ask whether these changes are problematic. Anyone who sees this thread later is also welcome to join the discussion. Looking forward to your reply!

Starlight0798 commented 7 months ago

CartPole's state is (4,), i.e. a one-dimensional array of length 4. If the state is two-dimensional or higher, the recommendation is to flatten it into one dimension; that is the usual practice, and with image inputs you likewise flatten after the convolutional layers. Your (7, 4) state flattens to (28,). The cleanest place to make the change is in the train and test functions (I'm guessing the code you refer to is in CartPole(DDQN+PER+DUEL)): flatten the state right after `state, _ = env.reset(seed=cfg.seed)` and after `next_state, reward, terminated, truncated, _ = env.step(action)`. That way the DQN internals don't need to change at all.
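
Roughly like this (just a sketch; `cfg.train_eps` and the `agent.memory.push` call here are illustrative and the actual loop may differ, the two flatten calls are the only point):

    import numpy as np

    def train(cfg, env, agent):
        for ep in range(cfg.train_eps):
            state, _ = env.reset(seed=cfg.seed)
            state = np.asarray(state, dtype=np.float32).flatten()      # (7, 4) -> (28,)
            done = False
            while not done:
                action = agent.choose_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                next_state = np.asarray(next_state, dtype=np.float32).flatten()  # flatten here too
                agent.memory.push((state, action, reward, next_state, terminated))
                agent.update()
                state = next_state
                done = terminated or truncated
        return agent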

LemonPasserby commented 6 months ago

Got it, that makes sense. Thank you for your answer!