jeshraghian closed this 1 year ago.
Fix:
def _reshape_input(self, input_):
    device = input_.device
    b, _ = input_.size()
    return torch.zeros(b, self.hidden_size).to(device)
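The point of the fix is that the zero hidden state is allocated on whatever device the input already lives on, so a later operation that combines the state with the input cannot hit a device mismatch. The sketch below drops the same `_reshape_input` logic into a small, hypothetical module (`ZeroStateDemo`, its `input_size` argument and its `nn.Linear` layer are illustrative assumptions, not part of snntorch's SLSTM) to show the state following the input from CPU to GPU.

```python
import torch
import torch.nn as nn


class ZeroStateDemo(nn.Module):
    """Hypothetical module (not snntorch's SLSTM) that reuses the fixed
    _reshape_input logic to build its initial hidden state."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(input_size, hidden_size)

    def _reshape_input(self, input_: torch.Tensor) -> torch.Tensor:
        # Same logic as the fix: take the device from the input tensor
        # instead of assuming one.
        device = input_.device
        b, _ = input_.size()
        return torch.zeros(b, self.hidden_size).to(device)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        h0 = self._reshape_input(input_)
        # h0 and the linear output share the input's device, so this
        # addition does not raise a device-mismatch error.
        return self.linear(input_) + h0


# Use a GPU when one is present; otherwise everything stays on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ZeroStateDemo(input_size=4, hidden_size=8).to(device)
x = torch.randn(3, 4, device=device)
out = model(x)
state = model._reshape_input(x)
print(out.shape, state.device)
```

Passing `device=device` directly to `torch.zeros` would avoid the temporary CPU allocation that `.to(device)` implies; the sketch keeps the `.to(device)` form only to mirror the fix as posted.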
To-do: Check if SConv2dLSTM has the same bug.
Whoops, I fixed this 3 weeks ago in the v0.6.1 update, but forgot to update it locally...
Only cuda and cpu are accounted for as device types in SLSTM. See line 295 of slstm.py:
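The line-295 snippet itself is not reproduced in this thread, so the following is only a hypothetical sketch of the general problem: a branch on `device.type` that handles just `"cuda"` and `"cpu"`, next to a device-agnostic alternative that inherits the device (and dtype) from the input. Both function names are illustrative assumptions. PyTorch's `"meta"` device stands in for any other device type (such as `"mps"`), because it is available in every build.

```python
import torch


def init_state_hardcoded(input_: torch.Tensor, hidden_size: int) -> torch.Tensor:
    """Hypothetical pattern that only knows two device types."""
    b = input_.size(0)
    if input_.device.type == "cuda":
        # Note: "cuda" without an index also ignores which GPU the input is on.
        return torch.zeros(b, hidden_size, device="cuda")
    if input_.device.type == "cpu":
        return torch.zeros(b, hidden_size)
    raise RuntimeError(f"unhandled device type: {input_.device.type!r}")


def init_state_agnostic(input_: torch.Tensor, hidden_size: int) -> torch.Tensor:
    """Device-agnostic alternative: reuse the input's device and dtype."""
    b = input_.size(0)
    return torch.zeros(b, hidden_size, device=input_.device, dtype=input_.dtype)


# A tensor on the "meta" device, i.e. neither cuda nor cpu.
x = torch.empty(3, 4, device="meta")
state = init_state_agnostic(x, hidden_size=8)
print(state.device, tuple(state.shape))

try:
    init_state_hardcoded(x, hidden_size=8)
except RuntimeError as err:
    print("hard-coded version failed:", err)
```

Reading the device straight off the input also preserves the GPU index (for example `cuda:1`), which a bare `"cuda"` string does not.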