loiccordone / object-detection-with-spiking-neural-networks

Repository code for the IJCNN 2022 paper "Object Detection with Spiking Neural Networks on Automotive Event Data"
MIT License

Bug in forward trained model #26

Closed: JNaranjo-Alcazar closed this issue 1 year ago

JNaranjo-Alcazar commented 1 year ago

Hi, I am trying to write an inference script (I hope to open a pull request for it).

The inference script contains the following loop:

for batch in inference_dataloader:
    features, head_outputs = module.forward(batch[0])

where batch[0] is a tensor of shape torch.Size([T, N, C, H, W]).

The error is raised at line 132 of spiking_vgg.py:

  File "/src/inference.py", line 107, in main
    features, head_outputs = module.forward(batch[0])
  File "/src/object_detection_module.py", line 73, in forward
    features = self.backbone(events)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/src/models/detection_backbone.py", line 42, in forward
    feature_maps = self.model(x, classify=False)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/src/models/spiking_vgg.py", line 132, in forward
    x_seq = functional.seq_to_ann_forward(x, self.features[0])
  File "/opt/conda/lib/python3.8/site-packages/spikingjelly/clock_driven/functional.py", line 568, in seq_to_ann_forward
    y = m(y)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/spikingjelly/clock_driven/neuron.py", line 1114, in forward
    spike_seq, self.v_seq = neuron_kernel.MultiStepParametricLIFNodePTT.apply(
  File "/opt/conda/lib/python3.8/site-packages/spikingjelly/clock_driven/neuron_kernel.py", line 1162, in forward
    cu_kernel_opt.wrap_args_to_raw_kernel(
  File "/opt/conda/lib/python3.8/site-packages/spikingjelly/clock_driven/cu_kernel_opt.py", line 64, in wrap_args_to_raw_kernel
    assert item.device.id == device
AssertionError
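One plausible reading of this traceback, offered as an assumption rather than a confirmed diagnosis: the failing line reads `item.device.id`, which is the CuPy attribute (`cupy.cuda.Device.id`); a PyTorch `torch.device` exposes `.index` instead, so the mismatching argument is probably a CuPy array. PyTorch's `Tensor.get_device()` returns -1 for CPU tensors, while a CuPy array always lives on a GPU (id 0 or higher). If SpikingJelly takes `device` from the input tensor (assumed here) and the model and batch are both still on the CPU, no CuPy argument can ever match. The sketch below is a hypothetical pure-Python model of that check; `FakeArg`, `check_devices`, and the argument names are illustrative, not SpikingJelly's code.

```python
from dataclasses import dataclass

CPU = -1  # torch.Tensor.get_device() reports -1 for CPU tensors


@dataclass
class FakeArg:
    """Hypothetical stand-in for a kernel argument; only its device index matters."""
    name: str
    device_id: int


def check_devices(device: int, *args: FakeArg) -> list:
    """Return the names of arguments whose device index differs from `device`.

    Assumed to mirror the spirit of the failing assertion; the real code
    raises AssertionError on the first mismatch instead of collecting them.
    """
    return [arg.name for arg in args if arg.device_id != device]


# Model and batch left on the CPU: `device` is taken from the input (-1),
# but a CuPy array is always allocated on a GPU, here GPU 0.
print(check_devices(CPU, FakeArg("input", CPU), FakeArg("weight", CPU), FakeArg("cupy_scalar", 0)))
# ['cupy_scalar']

# Model and batch both moved to cuda:0: every argument matches.
print(check_devices(0, FakeArg("input", 0), FakeArg("weight", 0), FakeArg("cupy_scalar", 0)))
# []
```

If this reading is right, it would also explain why the convolution inside `self.features[0]` ran without a device-mismatch error: model and input were on the same device, just not a CUDA one.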

It seems to be a SpikingJelly issue. Any ideas on how to solve this?

Thanks in advance.

liberary233 commented 10 months ago

I also encountered this problem. Has anyone solved it?
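A minimal sketch of a device-explicit version of the inference loop from the original report, offered under assumptions rather than as a confirmed fix. It assumes a CUDA device is available, that the CuPy-backed neurons seen in the traceback need the input and the model parameters on that same device, and that `module(...)` returns `(features, head_outputs)` as in the report. `run_inference` and its `device` parameter are hypothetical names. If the module is a PyTorch Lightning module (assumed here), its `Trainer` normally moves batches to the GPU; a hand-written loop does not, so both the model and the batch are moved explicitly.

```python
import torch
from torch import nn


def run_inference(module: nn.Module, dataloader, device: str = "cuda:0"):
    """Hypothetical helper: run the detector over a dataloader on one device.

    Assumes each batch is a sequence whose first element is the event tensor
    of shape [T, N, C, H, W] and that module(...) returns
    (features, head_outputs).
    """
    # Parameters must live on the same device as the input tensor.
    module = module.to(device).eval()
    results = []
    with torch.no_grad():  # no gradient tracking is needed for inference
        for batch in dataloader:
            # A plain loop must move each batch itself.
            events = batch[0].to(device)
            features, head_outputs = module(events)
            results.append(head_outputs)
    return results
```

For example, `outputs = run_inference(module, inference_dataloader)`. Calling `module(events)` rather than `module.forward(events)` runs the same forward pass while also triggering any registered hooks.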