keitakurita / Better_LSTM_PyTorch

An LSTM in PyTorch with best practices (weight dropout, forget bias, etc.) built-in. Fully compatible with PyTorch LSTM.
MIT License
133 stars 20 forks

a little problem when batch_first=False. #6

Open hunterbobo opened 4 years ago

hunterbobo commented 4 years ago

To fix the error and support batch_first=False, change this:

    if is_packed:
        x, batch_sizes = x
        max_batch_size = int(batch_sizes[0])
    else:
        batch_sizes = None
        max_batch_size = x.size(0)

to this:

    if is_packed:
        x, batch_sizes = x
        if self.batch_first:
            max_batch_size = int(batch_sizes[0])
        else:
            max_batch_size = int(batch_sizes[1])
    else:
        batch_sizes = None
        if self.batch_first:
            max_batch_size = x.size(0)
        else:
            max_batch_size = x.size(1)
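
For anyone trying this out, here is a minimal standalone sketch of the batch-size lookup. It is not the repository's code: `infer_max_batch_size` is a hypothetical helper name, and the shapes and lengths are made up for illustration. One thing it suggests is that the packed branch may not need the `batch_first` check at all, since `batch_sizes` in a `PackedSequence` is a 1-D tensor whose first entry is the number of sequences present at the first time step, which is the full batch size. Only the padded-tensor branch depends on `batch_first`.

```python
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence


def infer_max_batch_size(x, batch_first=False):
    """Hypothetical helper: return (data, batch_sizes, max_batch_size)."""
    if isinstance(x, PackedSequence):
        # batch_sizes is 1-D: entry t is the number of sequences still
        # active at time step t, so entry 0 is the full batch size,
        # independent of batch_first.
        data, batch_sizes = x.data, x.batch_sizes
        return data, batch_sizes, int(batch_sizes[0])
    # Padded tensor: (batch, seq, feature) if batch_first,
    # otherwise (seq, batch, feature).
    batch_dim = 0 if batch_first else 1
    return x, None, x.size(batch_dim)


# Time-major input: 5 time steps, batch of 3, 4 features.
x = torch.zeros(5, 3, 4)
_, _, n_padded = infer_max_batch_size(x, batch_first=False)

# The same batch packed, with (assumed) sequence lengths 5, 1 and 1.
packed = pack_padded_sequence(x, lengths=[5, 1, 1], batch_first=False)
_, sizes, n_packed = infer_max_batch_size(packed, batch_first=False)
```

With these lengths, `sizes[1]` is the number of sequences still running at the second time step, so reading index 1 in the packed branch would under-count the batch whenever some sequences have length 1.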