mlverse / torch

R Interface to Torch
https://torch.mlverse.org

RNN with packed sequences where enforce_sorted=TRUE gives an error. #1156

Open gavril0 opened 4 months ago

gavril0 commented 4 months ago

Using an RNN with packed sequences created with enforce_sorted=TRUE gives an error.

Let's define a 3D tensor with dimensions (batch_size, max_len, input_size) that represents two embedded sequences of lengths 4 and 2, respectively.

# padded input tensor
batch_size <- 2
input_size <- 3
seq_len <- c(4, 2) # sequence lengths
padded <- torch_randn(batch_size, max(seq_len), input_size)
padded[2,3:4,] <- 0 # padding
padded
torch_tensor
(1,.,.) = 
 -1.0758 -0.5305  1.6832
 -0.1549  2.0737  0.4338
  1.4333  0.5613 -0.5021
  1.2121  0.1815  0.2522

(2,.,.) = 
 -1.3125  0.4738  0.4393
  0.6843 -1.1598  0.2858
  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000
[ CPUFloatType{2,4,3} ]

The sequence lengths are in decreasing order (4, then 2), as enforce_sorted=TRUE requires, and the second sequence is padded with zeros.
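The only requirement that enforce_sorted=TRUE places on the lengths is that they are non-increasing. A minimal base-R sketch of that precondition, using a hypothetical helper is_sorted_for_packing() that is not part of torch:

```r
# Hypothetical helper (not a torch function): TRUE when lengths are
# non-increasing, i.e. valid input for enforce_sorted = TRUE.
is_sorted_for_packing <- function(lengths) {
  all(diff(lengths) <= 0)
}

is_sorted_for_packing(c(4, 2))  # lengths used in this issue
is_sorted_for_packing(c(2, 4))  # would need enforce_sorted = FALSE
```

With c(4, 2) the check passes, so the input in this issue satisfies the documented precondition.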

When the padded sequences are packed with enforce_sorted=TRUE, calling the RNN module on the packed input raises an error:

# define rnn module
hidden_size <- 3
rnn <- nn_rnn(input_size, hidden_size, batch_first=TRUE)
# pack padded input
packed <- nn_utils_rnn_pack_padded_sequence(padded, torch_tensor(seq_len), 
    batch_first=TRUE, enforce_sorted=TRUE)
# RNN
out <- rnn(packed)  
Error in (function (self, other, alpha) : 
  Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument #0 'self'

When the padded sequences are packed with enforce_sorted=FALSE, the RNN processes them without error.

# pack padded input
packed <- nn_utils_rnn_pack_padded_sequence(padded, torch_tensor(seq_len), 
    batch_first=TRUE, enforce_sorted=FALSE)
# RNN
out <- rnn(packed)  
out
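Because the lengths here are already sorted, packing with enforce_sorted=FALSE should yield the same packed data and batch sizes as enforce_sorted=TRUE (the sort permutation is the identity), which suggests it is a safe workaround for now. A sketch comparing the two, assuming the PackedSequence fields are exposed as $data and $batch_sizes as in PyTorch; the seed is arbitrary:

```r
library(torch)

torch_manual_seed(1)
lengths <- torch_tensor(c(4, 2), dtype = torch_long())
padded <- torch_randn(2, 4, 3)
padded[2, 3:4, ] <- 0  # zero-pad the shorter sequence

# Pack the same, already sorted, batch both ways
p_sorted <- nn_utils_rnn_pack_padded_sequence(
  padded, lengths, batch_first = TRUE, enforce_sorted = TRUE
)
p_unsorted <- nn_utils_rnn_pack_padded_sequence(
  padded, lengths, batch_first = TRUE, enforce_sorted = FALSE
)

# Assumption: $data and $batch_sizes are available on PackedSequence
same_data <- torch_equal(p_sorted$data, p_unsorted$data)
same_sizes <- torch_equal(p_sorted$batch_sizes, p_unsorted$batch_sizes)
```

If both comparisons are TRUE, the two packings differ only in their permutation indices, so the failure with enforce_sorted=TRUE would lie in how the RNN handles the packed input rather than in the packed data itself.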

To check that the output is correct, the first element of the RNN output (the packed output sequence) needs to be unpacked:

nn_utils_rnn_pad_packed_sequence(out[[1]], batch_first = TRUE, padding_value = 0)
[[1]]
torch_tensor
(1,.,.) = 
  0.4470  0.9314 -0.0264
 -0.8938  0.1120  0.7772
 -0.9523  0.4127  0.0077
 -0.9207  0.6005  0.2798

(2,.,.) = 
  0.1397  0.6618 -0.1076
 -0.2141  0.7277 -0.2181
  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000
[ CPUFloatType{2,4,3} ][ grad_fn = <IndexSelectBackward0> ]

[[2]]
torch_tensor
 4
 2
[ CPULongType{2} ]

As expected, the output of the second (shorter) sequence is zero after its last valid step, i.e. at steps 3 and 4.
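A further consistency check, sketched under the same setup (seed and variable names are illustrative): with packed input, the final hidden state returned in out[[2]] should equal each sequence's output at its last valid step (step 4 for the first sequence, step 2 for the second), and the padded steps of the unpacked output should be exactly 0.

```r
library(torch)

torch_manual_seed(1)
batch_size <- 2
input_size <- 3
hidden_size <- 3
lengths <- c(4, 2)

padded <- torch_randn(batch_size, max(lengths), input_size)
padded[2, 3:4, ] <- 0  # zero-pad the shorter sequence

rnn <- nn_rnn(input_size, hidden_size, batch_first = TRUE)
packed <- nn_utils_rnn_pack_padded_sequence(
  padded, torch_tensor(lengths, dtype = torch_long()),
  batch_first = TRUE, enforce_sorted = FALSE
)
out <- rnn(packed)

# Unpacked output: (batch_size, max_len, hidden_size)
output <- nn_utils_rnn_pad_packed_sequence(out[[1]], batch_first = TRUE)[[1]]
# Final hidden state: (num_layers = 1, batch_size, hidden_size)
h_n <- out[[2]]

# Padded steps are filled with padding_value (0 by default)
pad_sum <- output[2, 3:4, ]$abs()$sum()$item()

# h_n should match the output at each sequence's last valid step
max_diff <- max(sapply(seq_along(lengths), function(b) {
  (output[b, lengths[b], ] - h_n[1, b, ])$abs()$max()$item()
}))
```

Comparing against step lengths[b] rather than max(lengths) is the point of the check: it confirms the RNN stopped updating the second sequence after step 2 instead of running over the padding.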

The error also occurs with batch_first=FALSE:

# new padded tensor with dimensions (max_len, batch_size, input_size)
padded <- torch_randn(max(seq_len), batch_size, input_size)
padded[3:4, 2, ] <- 0 # padding of the second sequence (dim 2 is the batch)
# rnn module with batch_first=FALSE
rnn <- nn_rnn(input_size, hidden_size, batch_first=FALSE)
# pack padded input
packed <- nn_utils_rnn_pack_padded_sequence(padded, torch_tensor(seq_len), 
    batch_first=FALSE, enforce_sorted=TRUE)
out <- rnn(packed)  
sessionInfo()
R version 4.3.1 (2023-06-16 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19045)

other attached packages:
[1] torch_0.11.0