jihunchoi / recurrent-batch-normalization-pytorch

PyTorch implementation of recurrent batch normalization

'<' not supported between instances of 'int' and 'tuple' #9

Closed dtchuink closed 5 years ago

dtchuink commented 6 years ago

Hi thanks for the implementation.

When I run your code, I get the error: TypeError: '<' not supported between instances of 'int' and 'tuple'

at the line mask = (time < length).float().unsqueeze(1).expand_as(h_next) in the _forward_rnn function of the LSTM class.

Do you have any idea what might cause it?

Thanks.

jihunchoi commented 6 years ago

Please make sure that length is a torch.LongTensor. If length is a torch.LongTensor object, then time < length returns a torch.ByteTensor whose values indicate elementwise whether the condition holds.
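To illustrate, here is a minimal sketch of that masking step in isolation (the tensor sizes and values are made up for the example; only the comparison pattern mirrors the repo's code):

```python
import torch

batch_size, hidden_size = 3, 4
# length must be a LongTensor with one entry per sequence in the batch,
# NOT a plain int or a tuple -- a tuple is what triggers the TypeError.
length = torch.LongTensor([5, 2, 7])
time = 3  # current time step

# Comparing an int against a LongTensor broadcasts elementwise and yields
# a 0/1 mask (ByteTensor in old PyTorch, BoolTensor in recent releases).
mask = (time < length).float().unsqueeze(1)

h_next = torch.ones(batch_size, hidden_size)
mask = mask.expand_as(h_next)  # shape (batch_size, hidden_size)
print(mask[:, 0])  # 1.0 where time < length, else 0.0
```

If length were a tuple such as (5, 2, 7), the comparison time < length would fall back to Python's int-vs-tuple ordering and raise exactly the error reported above.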

dtchuink commented 6 years ago

Thanks for the answer.

However, length is not initialized in train_mnist.py. The only argument we pass when calling bnlstm is max_length. How can I make sure that length is a torch.LongTensor?

Thanks.

jihunchoi commented 6 years ago

length is defined in the following lines (though not optimized for the latest PyTorch release): https://github.com/jihunchoi/recurrent-batch-normalization-pytorch/blob/61736ecd2547bdb43e193ac6aa28545e3918ff9b/bnlstm.py#L275-L279 In the MNIST experiments, the length is always fixed to 576; however, when variable-length sequences are given, you can explicitly pass the length argument to the forward method.
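The default construction works roughly like the sketch below (hedged: the function and argument names follow the linked bnlstm.py lines, but the exact signature and tensor layout here are assumptions for illustration):

```python
import torch

def make_default_length(input_, length=None):
    """If no length is given, assume every sequence spans all time steps."""
    max_time, batch_size = input_.size(0), input_.size(1)
    if length is None:
        # One entry per batch element, all equal to max_time
        # (576 in the flattened-MNIST experiment).
        length = torch.LongTensor([max_time] * batch_size)
    return length

x = torch.zeros(576, 8, 1)  # (time, batch, features)
length = make_default_length(x)
print(length)  # LongTensor of size 8, filled with 576
```

So for fixed-size inputs like MNIST you never need to pass length yourself; it only matters when sequences in a batch have different true lengths.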

Also, since the current implementation is based on an old PyTorch release, I will update the code to work with the latest one (hopefully very soon).