FabbyD opened this issue 7 years ago
I just found out that this actually seems to happen for a few other methods that iterate through clones with AbstractRecurrent:includingSharedClones(f). In order to reproduce this issue, all you need is to wrap an nn.Container containing any nn.AbstractRecurrent module with nn.Recursor.
Actually, why is clearState not removing all clones? After all, clones are mostly used to keep those extra output and gradInput buffers. Calling clearState IMO means: "Remove all buffers from memory. I will reallocate afterwards if necessary."
Example:
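Roughly something like this (a minimal sketch; the sizes, the number of steps, and the choice of two nn.LSTM layers are arbitrary, just for illustration):

```lua
require 'rnn'

-- two LSTMs inside a Sequential, wrapped in a Recursor
local seq = nn.Sequential()
   :add(nn.LSTM(100, 100))
   :add(nn.LSTM(100, 100))
local rnn = nn.Recursor(seq, 10)

-- forward 10 time steps so that 10 clones of the Sequential get created
for step = 1, 10 do
   rnn:forward(torch.randn(100))
end

rnn:clearState()
```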
The last line will clear all 10 clones of nn.Sequential. Every one of these clones will also clear all the nn.LSTM clones. Since the nn.LSTM module takes care of its clones internally, we end up clearing the same 10 clones 10 times. By adding a few prints in AbstractRecurrent:clearState() you can watch the same clones being cleared over and over.

This might not seem like a big deal, but it means #clones x #clones x 2 (there are 2 LSTMs in this example) calls to clearState. When dealing with longer sequences like documents, this can take a very long time to finish. I sometimes have sequences of 10k inputs (I'm experimenting with stuff...), which means 10k*10k calls at ~0.0002 seconds each, roughly 5.5 hours just to do clearState() before saving the model to disk.
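In case it helps, this is roughly how I instrumented it (just a monkey-patch sketch, not part of rnn):

```lua
require 'rnn'

-- count and print every call to AbstractRecurrent:clearState()
local origClearState = nn.AbstractRecurrent.clearState
local nCalls = 0
function nn.AbstractRecurrent:clearState()
   nCalls = nCalls + 1
   print(string.format('clearState call #%d on %s', nCalls, torch.type(self)))
   return origClearState(self)
end
```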
Because nn.LSTM manages its clones internally and is contained inside the nn.Sequential, the same clones are being cleared again and again, as I explained at the beginning. Is there a way I could clear those LSTMs only once, effectively reducing the number of calls from O(n^2) to O(n)?
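The kind of thing I have in mind is a de-duplicated traversal that clears every distinct module object at most once. This is only a sketch: it bypasses Container:clearState() on purpose, assumes the internal clone table is called sharedClones, and doesn't clear module-specific buffers like an LSTM's cell states, so it's not a drop-in replacement:

```lua
require 'rnn'

-- clear output/gradInput of every distinct module at most once,
-- no matter how many containers or clone tables reference it
local function clearStateOnce(module, seen)
   seen = seen or {}
   if seen[module] then return end
   seen[module] = true

   -- recurse into container children
   if module.modules then
      for _, child in ipairs(module.modules) do
         clearStateOnce(child, seen)
      end
   end
   -- recurse into the step clones kept by AbstractRecurrent
   -- (assumption: the internal table is named sharedClones)
   if module.sharedClones then
      for _, clone in pairs(module.sharedClones) do
         clearStateOnce(clone, seen)
      end
   end

   -- clear only this module's own buffers, without recursing again
   nn.utils.clear(module, 'output', 'gradInput')
end

-- usage: clearStateOnce(rnn) instead of rnn:clearState()
```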