Closed LemonPasserby closed 7 months ago
CartPole的state是(4,),也就是一维数组长度为4。如果state是二维或以上的话,建议是展开成一维,这是经验做法,如果有图像数据,经过卷积层后同样也是进行flatten操作,展平,你这里(7,4)展开后就是(28,),代码里面最好的做法是在train函数和test函数进行修改(我猜你提到的代码在CartPole(DDQN+PER+DUEL)里),即
state, _ = env.reset(seed=cfg.seed)
和next_state, reward, terminated, truncated, _ = env.step(action)
这两处展平state即可,这样DQN内部实现可以不用动。
好的,懂了,感谢您的回答!
作者你好,我自己创建了个env,state是(7, 4)变量,使用你这里的DDQN算法时,出现报错 "C:\ProgramData\Anaconda3\envs\stableBaseline\lib\site-packages\torch\nn\modules\linear.py", line 114, in forward return F.linear(input, self.weight, self.bias) RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x4 and 7x256) 我检查了下代码,是这里有问题
然后我做了如下改进:
cfg.n_states = env.observation_space.shape[0] * env.observation_space.shape[1]
和choose action函数以及update函数
代码运行几十次没有报错,但因为我对强化学习内部代码不是很熟,所以想请教你一下,我这样改是否有问题,也欢迎后面看到这个话题的能一起讨论下。期待你的回复!