jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

SLSTM cuda/cpu/mps bug #198

Closed: jeshraghian closed this issue 1 year ago

jeshraghian commented 1 year ago

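Only cuda and cpu are accounted for as device types in SLSTM. An input on any other device, such as Apple's mps backend, therefore gets its hidden state allocated on the CPU, which no longer matches the device of the input.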

See line 295 of slstm.py:

```python
def _reshape_input(self, input_):
    # Bug: any device that is not CUDA (e.g. "mps") falls through to "cpu"
    if input_.is_cuda:
        device = "cuda"
    else:
        device = "cpu"
    b, _ = input_.size()
    return torch.zeros(b, self.hidden_size).to(device)
```
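The failure can be reproduced outside the class. The minimal sketch below uses two hypothetical helpers, buggy_device_for and fixed_device_for (not part of snnTorch), that mirror the is_cuda branch above and the input_.device approach respectively. On a CPU-only machine both agree; on a machine with the mps backend, the buggy helper still reports "cpu".

```python
import torch

def buggy_device_for(input_: torch.Tensor) -> str:
    # Mirrors the original branch: anything that is not CUDA is treated as CPU
    return "cuda" if input_.is_cuda else "cpu"

def fixed_device_for(input_: torch.Tensor) -> torch.device:
    # Reports the device the input actually lives on
    return input_.device

x = torch.randn(4, 8)
print(buggy_device_for(x), fixed_device_for(x))  # cpu cpu

# Only runs on machines where Apple's mps backend is available
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    x_mps = x.to("mps")
    # The buggy helper still says "cpu" even though the input is on mps
    print(buggy_device_for(x_mps), fixed_device_for(x_mps))
```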
jeshraghian commented 1 year ago

Fix:

```python
def _reshape_input(self, input_):
    # Use whatever device the input is already on (cpu, cuda, mps, ...)
    device = input_.device
    b, _ = input_.size()
    return torch.zeros(b, self.hidden_size).to(device)
```
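A slightly more idiomatic variant passes device= straight to torch.zeros, which allocates the state on the target device directly instead of creating a CPU tensor and then copying it. The sketch below also passes dtype=input_.dtype so that, for example, a float16 input gets a float16 state; that is a design choice of this sketch, not part of the fix above. HiddenInitDemo is a hypothetical stand-in, not the SLSTM class.

```python
import torch
import torch.nn as nn

class HiddenInitDemo(nn.Module):
    """Hypothetical minimal module; only the hidden-state init mirrors _reshape_input."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size

    def _reshape_input(self, input_: torch.Tensor) -> torch.Tensor:
        b, _ = input_.size()
        # Allocate zeros directly on the input's device, matching its dtype
        return torch.zeros(
            b, self.hidden_size, device=input_.device, dtype=input_.dtype
        )

demo = HiddenInitDemo(hidden_size=16)
x = torch.randn(4, 8)
h = demo._reshape_input(x)
print(h.shape, h.device)  # torch.Size([4, 16]) cpu
```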

To-do: Check if SConv2dLSTM has the same bug.
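One way to carry out this check is to run the hidden-state initialisation on every available device and assert that the state follows the input. The sketch below does not call SConv2dLSTM itself: conv_hidden_state and available_devices are hypothetical helpers, and the (batch, out_channels, height, width) state shape is an assumption about a padding-preserving convolutional LSTM, not taken from the snnTorch source.

```python
import torch

def conv_hidden_state(input_: torch.Tensor, out_channels: int) -> torch.Tensor:
    # Hypothetical: assumes "same" padding, so the state keeps the input's H and W
    b, _, h, w = input_.size()
    # Follow the input onto its device instead of branching on is_cuda
    return torch.zeros(b, out_channels, h, w, device=input_.device, dtype=input_.dtype)

def available_devices():
    # Always test on CPU; add CUDA and mps only when they are present
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        devices.append("mps")
    return devices

for dev in available_devices():
    x = torch.randn(2, 3, 8, 8, device=dev)
    state = conv_hidden_state(x, out_channels=5)
    # The check the to-do asks for: the state must live where the input lives
    assert state.device == x.device, (dev, state.device)
    print(dev, tuple(state.shape))
```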

jeshraghian commented 1 year ago

Whoops, I fixed this 3 weeks ago in the v0.6.1 update, but forgot to update it locally...