The line

ula, (h_out, _) = self.lstm(x, (h_0, c_0))

throws an exception on my machine, so I replaced it with

ula, (h_out, _) = self.lstm(x.view(len(x), self.seq_length, -1), (h_0, c_0))

The input x has to be 3-D, but here it has only two dimensions. The view call reshapes it to (batch, seq_length, features), which is the layout the LSTM expects when it is created with batch_first=True.
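Here is a minimal, self-contained sketch of the fix. It assumes PyTorch and an LSTM built with batch_first=True. The sizes (batch_size, seq_length, input_size, hidden_size, num_layers) are made up for illustration and are not taken from the original model. The 2-D call fails with a RuntimeError whose exact message depends on the PyTorch version. The reshaped 3-D call succeeds.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only; the original model's values are not given.
batch_size, seq_length, input_size, hidden_size, num_layers = 4, 3, 2, 8, 1

# Assumption: batch_first=True, matching the (batch, seq_len, features) layout
# produced by x.view(len(x), seq_length, -1).
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers, batch_first=True)

# 2-D input: each row holds seq_length * input_size values flattened together.
x = torch.randn(batch_size, seq_length * input_size)

# Hidden/cell states are (num_layers, batch, hidden_size) even with batch_first=True.
h_0 = torch.zeros(num_layers, batch_size, hidden_size)
c_0 = torch.zeros(num_layers, batch_size, hidden_size)

# Passing the 2-D tensor directly raises an error.
failed_2d = False
try:
    lstm(x, (h_0, c_0))
except RuntimeError as err:
    failed_2d = True
    print("2-D input rejected:", err)

# Reshape to 3-D: (batch, seq_length, features) -> (4, 3, 2) here.
x_3d = x.view(len(x), seq_length, -1)
ula, (h_out, _) = lstm(x_3d, (h_0, c_0))

print(x_3d.shape)   # torch.Size([4, 3, 2])
print(ula.shape)    # torch.Size([4, 3, 8])
print(h_out.shape)  # torch.Size([1, 4, 8])
```

The -1 in view lets PyTorch infer the feature dimension, so it only works if the number of values in each row is divisible by seq_length.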