thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

atari_ppo.py with frames_stack=1 fails to run #1006

Open luminous-123 opened 11 months ago

luminous-123 commented 11 months ago

I tried setting args.frames_stack=1, but running the script then fails with the following error:

Traceback (most recent call last):
  File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_ppo.py", line 290, in <module>
    test_ppo(get_args())
  File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_ppo.py", line 283, in test_ppo
    ).run()
      ^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/base.py", line 441, in run
    deque(self, maxlen=0)  # feed the entire iterator into a zero-length deque
    ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/base.py", line 299, in __next__
    self.policy_update_fn(data, result)
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/onpolicy.py", line 132, in policy_update_fn
    losses = self.policy.update(
             ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/base.py", line 277, in update
    batch = self.process_fn(batch, buffer, indices)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/modelfree/ppo.py", line 92, in process_fn
    batch = self._compute_returns(batch, buffer, indices)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/modelfree/a2c.py", line 89, in _compute_returns
    v_s.append(self.critic(minibatch.obs))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/utils/net/discrete.py", line 120, in forward
    logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_network.py", line 27, in forward
    return super().forward(obs / denom, state, info)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_network.py", line 89, in forward
    return self.net(obs), state
           ^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Given groups=1, weight of size [32, 1, 8, 8], expected input[1, 256, 84, 84] to have 1 channels, but got 256 channels instead
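The channel mismatch in the last line can be reproduced without PyTorch. The sketch below is a simplified, hypothetical model of the shape rule behind the error, not PyTorch's or Tianshou's code: Conv2d accepts batched (N, C, H, W) input or unbatched (C, H, W) input, and a 3-D tensor is read as a single sample. It assumes, as the reported shape input[1, 256, 84, 84] suggests, that the critic received 256 frames shaped (256, 84, 84) with no channel axis, so the 256 frames were read as 256 channels of one image. The helper names effective_conv2d_input, check_channels and add_channel_axis are made up for illustration, and add_channel_axis only shows the shape the first convolution would accept, not a confirmed fix for the example script.

```python
# Simplified, hypothetical model of the channel check that raised the
# RuntimeError above. This is NOT PyTorch's implementation; it only encodes
# the documented rule that Conv2d takes (N, C, H, W) or unbatched (C, H, W).


def effective_conv2d_input(shape):
    """Return the 4-D (N, C, H, W) shape a Conv2d layer effectively sees."""
    if len(shape) == 4:
        return tuple(shape)
    if len(shape) == 3:
        # A 3-D input is read as one unbatched (C, H, W) sample.
        return (1, *shape)
    raise ValueError(f"expected 3-D or 4-D input, got {len(shape)}-D")


def check_channels(input_shape, weight_shape, groups=1):
    """Raise an error shaped like PyTorch's if input channels do not match."""
    batched = effective_conv2d_input(input_shape)
    # Conv2d weights are laid out as (out_channels, in_channels // groups, kH, kW).
    expected = weight_shape[1] * groups
    if batched[1] != expected:
        raise RuntimeError(
            f"Given groups={groups}, weight of size {list(weight_shape)}, "
            f"expected input{list(batched)} to have {expected} channels, "
            f"but got {batched[1]} channels instead"
        )
    return batched


def add_channel_axis(shape):
    """Hypothetical helper: turn (N, H, W) single frames into (N, 1, H, W)."""
    n, h, w = shape
    return (n, 1, h, w)


if __name__ == "__main__":
    weight = (32, 1, 8, 8)     # first conv layer in the traceback: 1 input channel
    obs_batch = (256, 84, 84)  # assumed: 256 frames with no channel axis

    try:
        check_channels(obs_batch, weight)
    except RuntimeError as err:
        # prints: Given groups=1, weight of size [32, 1, 8, 8], expected
        # input[1, 256, 84, 84] to have 1 channels, but got 256 channels instead
        print(err)

    # With an explicit channel axis of size 1 the shapes line up.
    print(check_channels(add_channel_axis(obs_batch), weight))  # prints (256, 1, 84, 84)
```

The model only reasons about shapes, so it runs without PyTorch installed. With PyTorch available, the same message should appear when a tensor of the assumed shape, such as torch.zeros(256, 84, 84), is passed to torch.nn.Conv2d(1, 32, kernel_size=8), whose weight has size [32, 1, 8, 8].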

MischaPanch commented 10 months ago

Same as for your other question: did you run this from the current master branch? If so, I will look into both issues.

luminous-123 commented 10 months ago

I installed from pip, not from the current master branch. I will try that later.