fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Error when running the DSQN project (neuron.py, line 830: tensor type mismatch for a variable in neuronal_charge()) #549

Open huzj206 opened 1 month ago

huzj206 commented 1 month ago

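When I run the DSQN project directly (without any modifications), I get the error in the title. I could manually force the value into a tensor myself, but I'm not sure whether this is a bug in an underlying function.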

Thank you for taking a look; looking forward to your reply.

SpikingJelly version

0.0.0.0.14

Description

spikingjelly/activation_based/neuron.py", line 830 RuntimeError: neuronal_charge_decay_input_reset0() Expected a value of type 'Tensor' for argument 'v' but instead found type 'float'. ...

Minimal code to reproduce the error/bug

    # Excerpt from spikingjelly/activation_based/neuron.py (LIFNode); line 830 is inside neuronal_charge()
    def neuronal_charge(self, x: torch.Tensor):
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau)
            else:
                self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau)

        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau)
            else:
                self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau)

    # TorchScript checks argument types at call time, so `v` must already be a Tensor here
    @staticmethod
    @torch.jit.script
    def neuronal_charge_decay_input_reset0(x: torch.Tensor, v: torch.Tensor, tau: float):
        v = v + (x - v) / tau
        return v
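The type mismatch can be reproduced with plain PyTorch, without DSQN. The sketch below declares a standalone scripted function with the same signature as the excerpt (`charge_decay_input_reset0` is an illustrative copy, not the library's own object) and calls it once with a Python float for `v`, as happens while the membrane potential still holds its float default, and once with a tensor shaped like `x`.

```python
import torch


@torch.jit.script
def charge_decay_input_reset0(x: torch.Tensor, v: torch.Tensor, tau: float) -> torch.Tensor:
    # Same update as the excerpt: v <- v + (x - v) / tau
    return v + (x - v) / tau


x = torch.ones(2, 3)

# A Python float for `v` fails TorchScript's argument type check at call time.
try:
    charge_decay_input_reset0(x, 0.0, 2.0)
    float_call_failed = False
except RuntimeError as err:
    float_call_failed = True
    print(f"float v rejected: {err}")

# Broadcasting the float to a tensor shaped like `x` first makes the call valid.
v = torch.full_like(x, 0.0)
v = charge_decay_input_reset0(x, v, 2.0)
print(v)  # a 2x3 tensor filled with 0.5
```

The first call should fail with the same "Expected a value of type 'Tensor' for argument 'v'" message as in the traceback; the annotation itself is fine, the caller just has to hand over a tensor.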
timswim commented 1 month ago

Hello! You wrote that you are using version 0.0.0.0.14, but as far as I understand, the DSQN project (in examples) works correctly only on the master branch, because some functions differ there.

huzj206 commented 1 month ago

Thanks for your reply! I tried version 0.0.0.0.15 (is that the master branch?), but the bug was not fixed. Do I have to cast self.v from float to Tensor myself?

--- Terminal output

    (env2024) z-hu@hirame:~/workplace/dqn_spike/DSQN$ python train.py --cuda --game breakout --T 8 --dec_type max-mem --seed 123

    Using GPU
    A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
    [Powered by Stella]
    ------------------------------------------------
    Frame: 0 / 20000000
    Max Frame Idx:  0 , Max Reward:  -inf
    ------------------------------------------------
    Traceback (most recent call last):
      File "train.py", line 179, in <module>
        buffer.populate(PLAY_STEPS)
      File "/home/mil/z-hu/workplace/dqn_spike/DSQN/ptan/experience.py", line 380, in populate
        entry = next(self.experience_source_iter)
      File "/home/mil/z-hu/workplace/dqn_spike/DSQN/ptan/experience.py", line 185, in __iter__
        for exp in super(ExperienceSourceFirstLast, self).__iter__():
      File "/home/mil/z-hu/workplace/dqn_spike/DSQN/ptan/experience.py", line 84, in __iter__
        states_actions, new_agent_states = self.agent(states_input, agent_states)
      File "/home/mil/z-hu/.pyenv/versions/anaconda3-2023.03/envs/env2024/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/mil/z-hu/workplace/dqn_spike/DSQN/ptan/agent.py", line 76, in __call__
        q_v = self.dqn_model(states) if 'dqn' in self.dqn_model.model_name else self.dqn_model.qvals(states)
      File "/home/mil/z-hu/.pyenv/versions/anaconda3-2023.03/envs/env2024/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/mil/z-hu/workplace/dqn_spike/DSQN/utils/model.py", line 71, in forward
        return self.network(x_seq)
      File "/home/mil/z-hu/.pyenv/versions/anaconda3-2023.03/envs/env2024/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/mil/z-hu/.pyenv/versions/anaconda3-2023.03/envs/env2024/lib/python3.8/site-packages/torch/nn/modules/container.py", line 204, in forward
        input = module(input)
      File "/home/mil/z-hu/.pyenv/versions/anaconda3-2023.03/envs/env2024/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/mil/z-hu/workplace/dqn_spike/DSQN/utils/model.py", line 10, in forward
        self.neuronal_charge(dv)
      File "/home/mil/z-hu/.pyenv/versions/anaconda3-2023.03/envs/env2024/lib/python3.8/site-packages/spikingjelly-0.0.0.0.15-py3.8.egg/spikingjelly/activation_based/neuron.py", line 830, in neuronal_charge
    RuntimeError: neuronal_charge_decay_input_reset0() Expected a value of type 'Tensor' for argument 'v' but instead found type 'float'.
    Position: 1
    Value: 0.0
    Declaration: neuronal_charge_decay_input_reset0(Tensor x, Tensor v, float tau) -> Tensor
    Cast error details: Unable to cast 0.0 to Tensor
huzj206 commented 1 month ago

I also checked the neuron documentation (https://spikingjelly.readthedocs.io/zh-cn/latest/activation_based_en/neuron.html) and ran the simple test that follows. It looks like the default type of self.v is float.

    import torch
    from spikingjelly.activation_based import neuron

    # Create a LIF neuron with default arguments and inspect its membrane
    # potential before any input has been fed through it.
    lif_layer = neuron.LIFNode()
    print(lif_layer.v)
    print(type(lif_layer.v))

--- Output

    0.0
    <class 'float'>

Maybe the parameter declaration of neuronal_charge_decay_input_reset0() (neuron.py, lines 840-844) is wrong? Looking forward to your reply :)

    @staticmethod
    @torch.jit.script
    def neuronal_charge_decay_input_reset0(x: torch.Tensor, v: torch.Tensor, tau: float):
        v = v + (x - v) / tau
        return v
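A sketch of one way to handle the float default, under assumptions: convert `v` to a tensor shaped like the input right before charging. In the SpikingJelly sources, `BaseNode.single_step_forward` appears to do this through `v_float_to_tensor` before it calls `neuronal_charge`, so a custom `forward` that calls `neuronal_charge` directly (as the traceback suggests happens at DSQN/utils/model.py line 10) would skip that step. The class below is a hypothetical plain-PyTorch stand-in (`ToyMemNode` and `charge_decay_input_reset0` are illustrative names, not SpikingJelly's API) that only demonstrates the pattern.

```python
import torch
import torch.nn as nn


@torch.jit.script
def charge_decay_input_reset0(x: torch.Tensor, v: torch.Tensor, tau: float) -> torch.Tensor:
    # v <- v + (x - v) / tau, with a Tensor-typed v as in the excerpt
    return v + (x - v) / tau


class ToyMemNode(nn.Module):
    """Hypothetical stand-in for a custom node whose forward() calls neuronal_charge() itself."""

    def __init__(self, tau: float = 2.0, v_reset: float = 0.0):
        super().__init__()
        self.tau = tau
        self.v_reset = v_reset
        self.v = v_reset  # starts as a Python float, like LIFNode().v == 0.0

    def v_float_to_tensor(self, x: torch.Tensor) -> None:
        # Replace the float default with a tensor matching x's shape, dtype and device.
        if isinstance(self.v, float):
            self.v = torch.full_like(x, self.v)

    def neuronal_charge(self, x: torch.Tensor) -> None:
        self.v = charge_decay_input_reset0(x, self.v, self.tau)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.v_float_to_tensor(x)  # without this line the first call raises RuntimeError
        self.neuronal_charge(x)
        return self.v  # output the membrane potential; no firing or reset in this sketch

    def reset(self) -> None:
        self.v = self.v_reset


node = ToyMemNode(tau=2.0)
x = torch.ones(4)
out1 = node(x)  # v: 0.0 -> 0.5
out2 = node(x)  # v: 0.5 -> 0.75
print(out1, out2)
```

Converting lazily on the first call keeps the float default cheap to reset between episodes while letting the tensor pick up the input's shape, dtype and device.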
timswim commented 1 month ago

Hello!

1) Regarding "I tried in version 0.0.0.0.15 (is it master branch?)": I work with the master branch directly as a project directory, i.e. I don't install spikingjelly into the environment.

2) Running the same snippet, I get the same output here: 0.0 and <class 'float'>.

3) Following your traceback, I tried to find self.neuronal_charge(dv) at line 10 of DSQN/utils/model.py, but it isn't in my copy...

https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/examples/DSQN/utils/model.py
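The traceback resolves spikingjelly from `site-packages/spikingjelly-0.0.0.0.15-py3.8.egg`, while the DSQN example is meant to run against a master-branch checkout, so it may help to confirm which copy Python actually imports. A small stdlib-only sketch using `importlib.util.find_spec` (the helper name `module_origin` is hypothetical):

```python
import importlib.util
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """Return the file a top-level module would be loaded from, or None if it is not found."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    return spec.origin


if __name__ == "__main__":
    # Run from the DSQN directory to see whether an installed egg or a
    # local checkout of spikingjelly comes first on sys.path.
    print(module_origin("spikingjelly"))
```

If the printed path points into `site-packages/...egg`, the installed release is being used rather than the master-branch sources.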