Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

AbstractRecurrent:clearState() clearing too much #395

Open FabbyD opened 7 years ago

FabbyD commented 7 years ago

Example:

require 'rnn' -- loads torch, nn and the rnn modules (nn.Sequencer, nn.LSTM)

local net = nn.Sequencer(
  nn.Sequential()
    :add(nn.LSTM(100,100))
    :add(nn.Linear(100,100))
    :add(nn.LSTM(100,100))
  )

local inputs = {}
for i=1,10 do
  table.insert(inputs, torch.randn(100))
end

net:forward(inputs) -- This should create 10 clones of my network

net:clearState()

The last line clears all 10 clones of the nn.Sequential. Each of those clones in turn clears all the clones of its nn.LSTM modules. Since nn.LSTM manages its clones internally, we end up clearing the same 10 LSTM clones 10 times (once per nn.Sequential clone).
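The call pattern described above can be reproduced with a small pure-Lua toy model. This is a hypothetical sketch, not the rnn library's code: `newRecurrent`, `simulate` and the `clones` field are made-up stand-ins, and it assumes (as described above) that every container clone reaches the same LSTM-like modules. It simply counts how many per-clone clears happen.

```lua
-- Hypothetical toy model (NOT rnn's implementation) of the nested clearing:
-- a Recursor-like wrapper owns n clones of a container, and every container
-- clone references the same LSTM-like modules, each owning n internal clones.

local calls = 0  -- number of per-clone clear operations performed

-- Build an LSTM-like module that keeps n internal clones.
local function newRecurrent(n)
  local m = { clones = {} }
  for i = 1, n do
    m.clones[i] = { output = {}, gradInput = {} }  -- stand-in buffers
  end
  function m:clearState()
    for _, clone in ipairs(self.clones) do
      clone.output, clone.gradInput = {}, {}  -- drop the buffers
      calls = calls + 1
    end
  end
  return m
end

-- n = sequence length (number of clones), numLSTM = recurrent modules per container
local function simulate(n, numLSTM)
  calls = 0
  local lstms = {}
  for k = 1, numLSTM do lstms[k] = newRecurrent(n) end
  -- n container clones, all sharing the same LSTM-like modules
  local containerClones = {}
  for i = 1, n do containerClones[i] = { modules = lstms } end
  -- Recursor-like clearState: visit every container clone, which asks
  -- every LSTM-like module to clear all of its n clones again
  for _, c in ipairs(containerClones) do
    for _, m in ipairs(c.modules) do m:clearState() end
  end
  return calls
end

print(simulate(10, 2))  --> 200
```

With 10 steps and 2 LSTMs this gives 10 x 10 x 2 = 200 clears, matching the 20 LSTM clone clears per nn.Sequential clone in the log.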

By adding a few prints in AbstractRecurrent:clearState() we get something like this:

clearState nn.Recursor @ nn.Sequential {
  [input -> (1) -> (2) -> (3) -> output]
  (1): nn.LSTM(100 -> 100)
  (2): nn.Linear(100 -> 100)
  (3): nn.LSTM(100 -> 100)
}   
clearState nn.LSTM(100 -> 100)  
  cleared clone 1 in 0.00020408630371094    
  cleared clone 2 in 0.00022411346435547    
  cleared clone 3 in 0.0002291202545166 
  cleared clone 4 in 0.00021004676818848    
  cleared clone 5 in 0.00021100044250488    
  cleared clone 6 in 0.00018811225891113    
  cleared clone 7 in 0.00021791458129883    
  cleared clone 8 in 0.0002140998840332 
  cleared clone 9 in 0.00021195411682129    
  cleared clone 10 in 0.00020408630371094   
clearState nn.LSTM(100 -> 100)  
  cleared clone 1 in 0.0001978874206543 
  cleared clone 2 in 0.00060486793518066    
  cleared clone 3 in 0.00049901008605957    
  cleared clone 4 in 0.0002589225769043 
  cleared clone 5 in 0.00022697448730469    
  cleared clone 6 in 0.00019097328186035    
  cleared clone 7 in 0.00020694732666016    
  cleared clone 8 in 0.00022196769714355    
  cleared clone 9 in 0.00023078918457031    
  cleared clone 10 in 0.00024318695068359   
  cleared clone 1 in 0.0056848526000977    <-- The first nn.Sequential clone
... 
clearState nn.LSTM(100 -> 100)  
  cleared clone 1 in 0.00015807151794434    
  cleared clone 2 in 0.00019478797912598    
  cleared clone 3 in 0.00017786026000977    
  cleared clone 4 in 0.00020194053649902    
  cleared clone 5 in 0.00017094612121582    
  cleared clone 6 in 0.00017809867858887    
  cleared clone 7 in 0.00016403198242188    
  cleared clone 8 in 0.00015807151794434    
  cleared clone 9 in 0.00016117095947266    
  cleared clone 10 in 0.00016188621520996   
clearState nn.LSTM(100 -> 100)  
  cleared clone 1 in 0.00016117095947266    
  cleared clone 2 in 0.00016403198242188    
  cleared clone 3 in 0.00015997886657715    
  cleared clone 4 in 0.00016307830810547    
  cleared clone 5 in 0.00016498565673828    
  cleared clone 6 in 0.00015807151794434    
  cleared clone 7 in 0.00015902519226074    
  cleared clone 8 in 0.00016093254089355    
  cleared clone 9 in 0.00016188621520996    
  cleared clone 10 in 0.00016617774963379   
  cleared clone 10 in 0.0038068294525146   <-- The last nn.Sequential clone

This might not seem like a big deal, but it means #clones x #clones x 2 calls to clearState (there are 2 LSTMs in this example). When dealing with longer sequences like documents, this can take a very long time to finish. I sometimes have sequences of 10k inputs (I'm experimenting with stuff...), which means 10k x 10k calls per LSTM, each taking ~0.0002 seconds: roughly 5.5 hours per LSTM (about 11 hours for the two in this example) just to run clearState() before saving the model to disk.
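A back-of-the-envelope check of those numbers, using the figures from this issue (10k steps, 2 LSTMs, ~0.0002 s per clone clear taken from the log). The helper names are hypothetical; the point is only to contrast the current quadratic cost with the linear cost of clearing each LSTM's clones once.

```lua
-- Rough cost model of clearState(), assuming a fixed time t per clone clear.
-- n = sequence length (clones), numLSTM = recurrent modules in the container.

-- Current behaviour: each of the n container clones clears all n clones
-- of every LSTM.
local function quadraticCost(n, numLSTM, t)
  return n * n * numLSTM * t
end

-- Desired behaviour: each LSTM's n clones are cleared exactly once.
local function linearCost(n, numLSTM, t)
  return n * numLSTM * t
end

local n, numLSTM, t = 10000, 2, 0.0002
print(string.format("current: %.1f hours", quadraticCost(n, numLSTM, t) / 3600))
print(string.format("once per LSTM: %.1f seconds", linearCost(n, numLSTM, t)))
```

Under these assumptions the current behaviour costs about 11.1 hours, while clearing each LSTM once would take about 4 seconds.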

Because nn.LSTM manages its clones internally and is contained inside the nn.Sequential, the same clones are cleared again and again, as I explained at the beginning. Is there a way to clear those LSTMs only once, effectively reducing the number of calls from O(n^2) to O(n)?
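One possible way to get O(n) is to remember which clone sets have already been cleared during a single clearState() pass. The sketch below is a hypothetical pure-Lua toy, not a patch against rnn: `newRecurrent`, the `visited` argument and the shared `clones` tables are assumptions. It models each nn.Sequential clone holding its own wrapper objects that point at the same shared clone tables, and skips a table once it has been cleared.

```lua
-- Hypothetical sketch (NOT rnn's API): deduplicate clearing with a
-- per-call "visited" set keyed by the shared clone table.

local calls = 0  -- number of per-clone clear operations performed

-- LSTM-like wrapper around a (possibly shared) table of clones.
local function newRecurrent(clones)
  local m = { clones = clones }
  function m:clearState(visited)
    if visited[self.clones] then return end  -- this clone set is already clear
    visited[self.clones] = true
    for _, clone in ipairs(self.clones) do
      clone.output, clone.gradInput = {}, {}  -- drop the buffers
      calls = calls + 1
    end
  end
  return m
end

local function simulate(n, numLSTM)
  calls = 0
  -- one table of n clones per LSTM, shared by all container clones
  local cloneTables = {}
  for k = 1, numLSTM do
    local tbl = {}
    for i = 1, n do tbl[i] = { output = {}, gradInput = {} } end
    cloneTables[k] = tbl
  end
  -- n container clones, each with its own wrappers around the shared tables
  local containerClones = {}
  for i = 1, n do
    local modules = {}
    for k = 1, numLSTM do modules[k] = newRecurrent(cloneTables[k]) end
    containerClones[i] = { modules = modules }
  end
  -- a single visited set for the whole clearState pass
  local visited = {}
  for _, c in ipairs(containerClones) do
    for _, m in ipairs(c.modules) do m:clearState(visited) end
  end
  return calls
end

print(simulate(10, 2))  --> 20
```

Keying the set on the clone table rather than on the module object means the shortcut still works if each container clone holds a distinct wrapper around the same clones.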

FabbyD commented 7 years ago

I just found out that this also seems to happen for a few other methods that iterate through clones with AbstractRecurrent:includingSharedClones(f). To reproduce this issue, all you need is to wrap an nn.Container containing any nn.AbstractRecurrent module with nn.Recursor.

FabbyD commented 7 years ago

Actually, why doesn't clearState remove all clones? After all, clones are mostly used to hold those extra output and gradInput buffers. IMO, calling clearState means: "Remove all buffers from memory. I will reallocate afterwards if necessary."
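The behaviour proposed here can be illustrated with a hypothetical pure-Lua toy (not rnn's AbstractRecurrent; `Recurrent`, `getClone` and `created` are made-up names): clearState() forgets every clone in one step, and clones are rebuilt lazily the next time a time step needs one. A real implementation would also have to keep whatever module the clones share their parameters with; that is not modeled here.

```lua
-- Hypothetical toy of "clearState drops all clones, reallocate on demand".
local Recurrent = {}
Recurrent.__index = Recurrent

function Recurrent.new()
  return setmetatable({ clones = {}, created = 0 }, Recurrent)
end

-- Return the clone for a time step, creating it only when needed.
function Recurrent:getClone(step)
  local clone = self.clones[step]
  if not clone then
    clone = { output = {}, gradInput = {} }  -- stand-in for real buffers
    self.clones[step] = clone
    self.created = self.created + 1
  end
  return clone
end

-- Proposed clearState: forget all clones at once instead of visiting each.
function Recurrent:clearState()
  self.clones = {}
  return self
end

local r = Recurrent.new()
for step = 1, 10 do r:getClone(step) end
print(#r.clones)  --> 10
r:clearState()
print(#r.clones)  --> 0
r:getClone(1)     -- reallocated on demand
print(r.created)  --> 11
```

Dropping the table makes clearing O(1) no matter how many clones exist, and the garbage collector reclaims the old buffers.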