[X] I have searched through the issue tracker for duplicates
[X] I have mentioned version numbers, operating system and environment, where applicable:
I tried setting `frames_stack=1` in the args, but running the script raises an error. The output is below:
Traceback (most recent call last):
File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_ppo.py", line 290, in <module>
test_ppo(get_args())
File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_ppo.py", line 283, in test_ppo
).run()
^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/base.py", line 441, in run
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/base.py", line 299, in __next__
self.policy_update_fn(data, result)
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/onpolicy.py", line 132, in policy_update_fn
losses = self.policy.update(
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/base.py", line 277, in update
batch = self.process_fn(batch, buffer, indices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/modelfree/ppo.py", line 92, in process_fn
batch = self._compute_returns(batch, buffer, indices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/modelfree/a2c.py", line 89, in _compute_returns
v_s.append(self.critic(minibatch.obs))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/utils/net/discrete.py", line 120, in forward
logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_network.py", line 27, in forward
return super().forward(obs / denom, state, info)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_network.py", line 89, in forward
return self.net(obs), state
^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Given groups=1, weight of size [32, 1, 8, 8], expected input[1, 256, 84, 84] to have 1 channels, but got 256 channels instead
I tried setting `frames_stack=1` in the args, but running the script raises an error. The output is below:
Traceback (most recent call last):
File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_ppo.py", line 290, in <module>
test_ppo(get_args())
File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_ppo.py", line 283, in test_ppo
).run()
^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/base.py", line 441, in run
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/base.py", line 299, in __next__
self.policy_update_fn(data, result)
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/trainer/onpolicy.py", line 132, in policy_update_fn
losses = self.policy.update(
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/base.py", line 277, in update
batch = self.process_fn(batch, buffer, indices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/modelfree/ppo.py", line 92, in process_fn
batch = self._compute_returns(batch, buffer, indices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/policy/modelfree/a2c.py", line 89, in _compute_returns
v_s.append(self.critic(minibatch.obs))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/tianshou/utils/net/discrete.py", line 120, in forward
logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_network.py", line 27, in forward
return super().forward(obs / denom, state, info)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/code/Stolen-RL-Data/tianshou-origin/examples/atari/atari_network.py", line 89, in forward
return self.net(obs), state
^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/tianshou/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Given groups=1, weight of size [32, 1, 8, 8], expected input[1, 256, 84, 84] to have 1 channels, but got 256 channels instead