dllllb / pytorch-lifestream

A library built on PyTorch for creating embeddings of discrete event sequences using self-supervision.

RnnSeqEncoder doesn't work with num_layers>1 #107

Closed: danielgafni closed this issue 1 year ago

danielgafni commented 1 year ago

Getting this error when increasing the number of RNN layers from 1 to 3:

File ~/.cache/pypoetry/virtualenvs/ar-resp-vArrGwjy-py3.10/lib/python3.10/site-packages/ptls/nn/seq_encoder/rnn_encoder.py:134, in RnnEncoder.forward(self, x, h_0)
    132     out, _ = self.rnn(x.payload)
    133 elif self.rnn_type == 'gru':
--> 134     out, _ = self.rnn(x.payload, h_0)
    135 else:
    136     raise Exception(f'wrong rnn type "{self.rnn_type}"')

File ~/.cache/pypoetry/virtualenvs/ar-resp-vArrGwjy-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/pypoetry/virtualenvs/ar-resp-vArrGwjy-py3.10/lib/python3.10/site-packages/torch/nn/modules/rnn.py:953, in GRU.forward(self, input, hx)
    948 else:
    949     # Each batch of the hidden state should match the input sequence that
    950     # the user believes he/she is passing in.
    951     hx = self.permute_hidden(hx, sorted_indices)
--> 953 self.check_forward_args(input, hx, batch_sizes)
    954 if batch_sizes is None:
    955     result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
    956                      self.dropout, self.training, self.bidirectional, self.batch_first)

File ~/.cache/pypoetry/virtualenvs/ar-resp-vArrGwjy-py3.10/lib/python3.10/site-packages/torch/nn/modules/rnn.py:237, in RNNBase.check_forward_args(self, input, hidden, batch_sizes)
    234 self.check_input(input, batch_sizes)
    235 expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
--> 237 self.check_hidden_size(hidden, expected_hidden_size)

File ~/.cache/pypoetry/virtualenvs/ar-resp-vArrGwjy-py3.10/lib/python3.10/site-packages/torch/nn/modules/rnn.py:231, in RNNBase.check_hidden_size(self, hx, expected_hidden_size, msg)
    228 def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
    229                       msg: str = 'Expected hidden size {}, got {}') -> None:
    230     if hx.size() != expected_hidden_size:
--> 231         raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

RuntimeError: Expected hidden size (3, 128, 32), got [1, 128, 32]
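
For context, the mismatch can be reproduced with plain torch.nn.GRU: when num_layers=3, the initial hidden state h_0 must have shape (num_layers, batch, hidden), while a single-layer state of shape (1, batch, hidden) trips exactly this check. A minimal standalone sketch follows; the layer count, batch size, and hidden size come from the error message above, while the input size and sequence length are arbitrary assumptions for illustration.

import torch
from torch import nn

gru = nn.GRU(input_size=8, hidden_size=32, num_layers=3, batch_first=True)
x = torch.randn(128, 10, 8)       # (batch, seq_len, input_size); sizes assumed for the sketch

h_0 = torch.zeros(1, 128, 32)     # single-layer starter state, as in the traceback
# gru(x, h_0)                     # RuntimeError: Expected hidden size (3, 128, 32), got [1, 128, 32]

h_0 = torch.zeros(3, 128, 32)     # one hidden-state slice per layer
out, h_n = gru(x, h_0)            # passes check_hidden_size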
ivkireev86 commented 1 year ago

Fixed in #109
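
For anyone pinned to a version without the fix, one way to make a single-layer starter hidden state compatible with num_layers > 1 is to repeat it along the layer dimension before passing it to the RNN. This is only a sketch of the general idea, not necessarily the approach taken in #109; the sizes are taken from the error above.

import torch

num_layers, batch_size, hidden_size = 3, 128, 32
starter_h = torch.zeros(1, batch_size, hidden_size)        # hidden state defined for a single layer
h_0 = starter_h.expand(num_layers, -1, -1).contiguous()    # repeat across the layer dimension
print(h_0.shape)                                           # torch.Size([3, 128, 32]), the shape nn.GRU expects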