jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

AttributeError: '_SpikeTensor' object has no attribute 'is_mps' #202

Status: Closed (closed by xjtulyc 1 year ago)

xjtulyc commented 1 year ago

Description

When I ran snntorch/examples/tutorial_7_neuromorphic_datasets.ipynb, I got the following error while training the neural network:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_20075/2455384482.py in <module>
     13 
     14         net.train()
---> 15         spk_rec = forward_pass(net, data)
     16         loss_val = loss_fn(spk_rec, targets)
     17 

/tmp/ipykernel_20075/3960929974.py in forward_pass(net, data)
      6 
      7   for step in range(data.size(0)):  # data.size(0) = number of time steps
----> 8       spk_out, mem_out = net(data[step])
      9       spk_rec.append(spk_out)
     10 

~/anaconda3/envs/d/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/d/lib/python3.7/site-packages/torch/nn/modules/container.py in forward(self, input)
    139     def forward(self, input):
...
--> 462         elif arg.is_mps:
    463             arg = arg.to("cpu")
    464         arg = torch.Tensor(arg)  # wash away the SpikeTensor class

AttributeError: '_SpikeTensor' object has no attribute 'is_mps'

How can I solve this problem?

jeshraghian commented 1 year ago

For an immediate fix, upgrade PyTorch to version 1.13 or later (PyTorch>=1.13).

I'll try to release an alternative fix shortly that restores compatibility with older PyTorch versions.
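To see whether an environment is affected, the installed PyTorch version can be compared against the 1.13 minimum mentioned above. The following is a minimal sketch; the helper names `parse_major_minor` and `torch_needs_upgrade` are hypothetical and not part of PyTorch or snnTorch.

```python
import re

# Minimum PyTorch version suggested in this thread.
MIN_TORCH = (1, 13)


def parse_major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings such as '1.10.2' or '1.13.0+cu117'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def torch_needs_upgrade(version: str) -> bool:
    """Return True if the given version is older than MIN_TORCH."""
    return parse_major_minor(version) < MIN_TORCH


# Hypothetical usage in a notebook cell:
#   import torch
#   print(torch_needs_upgrade(torch.__version__))
```

If the check reports that an upgrade is needed, either upgrading PyTorch or moving to a newer snnTorch release should avoid the `is_mps` lookup failing.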

xjtulyc commented 1 year ago

Thank you very much for your quick reply.

jeshraghian commented 1 year ago

Fixed in snnTorch 0.6.2 and later (snntorch>=0.6.2). You do not need to upgrade PyTorch if you upgrade snnTorch. See #200 for more details.
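For readers curious how a check like `arg.is_mps` can be made tolerant of objects that lack the attribute, one common pattern is `getattr` with a default. This is only an illustrative sketch under that assumption; it is not taken from the snnTorch source, and the actual change in 0.6.2 may differ. The classes `LegacyTensor` and `MpsTensor` are hypothetical stand-ins so the example runs without PyTorch installed.

```python
def is_on_mps(obj) -> bool:
    """Return True only if obj reports that it lives on an MPS device.

    getattr with a default avoids AttributeError on objects that do not
    define is_mps at all.
    """
    return bool(getattr(obj, "is_mps", False))


class LegacyTensor:
    """Hypothetical stand-in for a tensor type without an is_mps attribute."""


class MpsTensor:
    """Hypothetical stand-in for a tensor that reports is_mps = True."""
    is_mps = True


print(is_on_mps(LegacyTensor()))  # False
print(is_on_mps(MpsTensor()))     # True
```

With this pattern, the branch that moves a tensor to the CPU is simply skipped when the attribute does not exist.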