harvardnlp / seq2seq-attn

Sequence-to-sequence model with LSTM encoder/decoders and attention
http://nlp.seas.harvard.edu/code
MIT License
1.26k stars 278 forks

Memory optimization #45

Closed jsenellart closed 8 years ago

jsenellart commented 8 years ago

The optimization is very simple. During training we have as many decoder clones as the maximal number of target words, so that we can back-propagate through time. Also, with nngraph, the abstraction layer that builds the computation graph, each intermediate calculation node is an object whose full structure we keep in memory, in particular the output field (the result of the operation) and gradInput (the gradient computed on the input). The latter is especially problematic for large operations: in the attn function, for instance, we multiply the context with the attention vector. The context has size batch_l x source_l x rnn_size, and the gradInput on the context has the same size.
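As a concrete illustration, here is a minimal self-contained sketch of that kind of graph; the standalone module and the variable names are illustrative, not the project's actual decoder, but the shapes follow the numbers below:

```lua
require 'nn'
require 'nngraph'

-- Toy graph reproducing the shape of the attention mixing step.
local attn    = nn.Identity()()            -- attention weights: batch_l x 1 x source_l
local context = nn.Identity()()            -- encoder states: batch_l x source_l x rnn_size
local mix     = nn.MM()({attn, context})   -- batch matrix product: batch_l x 1 x rnn_size
local g = nn.gModule({attn, context}, {mix})

local batch_l, source_l, rnn_size = 64, 52, 600
local a = torch.randn(batch_l, 1, source_l)
local c = torch.randn(batch_l, source_l, rnn_size)
local out = g:forward({a, c})
g:backward({a, c}, out:clone())
-- After backward(), the MM node retains a gradInput on the context of the full
-- batch_l x source_l x rnn_size size, and every cloned decoder keeps its own copy.
```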

So for batch_l=64, rnn_size=600, max_sent_l=52 (50+2), this takes up to 64 x 52 x 600 x sizeof(float) x number of decoders (52) x 2 (output and gradInput), about 830MB of static memory for one single intermediate node. This is significant on a GPU.
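A quick back-of-the-envelope check of that figure (counting, as above, both the output and the gradInput buffers of the node):

```lua
-- Rough check of the 830MB figure, at 4 bytes per float entry.
local batch_l, max_sent_l, rnn_size = 64, 52, 600
local per_tensor = batch_l * max_sent_l * rnn_size * 4  -- one buffer: ~8MB
local per_node   = per_tensor * 2                       -- output + gradInput
local total      = per_node * max_sent_l                -- one copy per decoder clone
print(string.format('%.0fMB', total / 1e6))             -- prints 831MB
```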

Some optimization already exists, namely the :reuseMem() method on some operations, but what it does is fairly limited: it reuses the same tensor for output and gradInput, which saves 50% of the space but only works for operators whose output and gradInput have the same dimensions.
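Roughly, the mechanism amounts to something like the following sketch of the idea (not necessarily the repository's exact code):

```lua
-- Sketch of the idea behind :reuseMem(): mark a module, then alias gradInput
-- to the output tensor so the node keeps one buffer instead of two. Only
-- valid when both fields have the same dimensions and the output is no longer
-- needed once the gradient has been computed.
function nn.Module:reuseMem()
  self.reuse = true
  return self
end

function nn.Module:setReuse()
  if self.reuse then
    self.gradInput = self.output  -- share one tensor between the two fields
  end
  return self
end
```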

The patch pre-allocates the memory for these objects (gradInput or output can be specified), so that all the cloned decoders share the same tensor, giving a reduction factor of max_sent_l_src. This is safe because no intermediate variable in the nngraph has a lifecycle beyond the :forward() or :backward() call it is involved in.
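A minimal sketch of the mechanism, with a hypothetical usePrealloc name and registry standing in for the patch's actual API:

```lua
-- Illustrative sketch of the preallocation: all modules registered under the
-- same name share one buffer for the chosen field ('output' or 'gradInput').
local preallocTable = {}

local function usePrealloc(module, name, field)
  preallocTable[name] = preallocTable[name] or torch.Tensor()
  module[field] = preallocTable[name]  -- every clone registered under `name`
  return module                        -- now writes into this single buffer
end
```

Each clone then resizes and overwrites the shared tensor in place during its own :forward() or :backward() call, which is safe precisely because these values are never needed across time steps.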

This preallocation on the two main MM operations of the attention decoder gives a memory reduction of 1.6GB, and we can save a little more by applying the same approach to other operations. In the nvidia-smi snapshot below, the two runs differ by roughly that amount:

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID  Type  Process name                               Usage      |
|=============================================================================|
|    1     17141    C   /home/shared/lib/torch/f95379d/bin/luajit     5243MiB |
|    2     16830    C   /home/shared/lib/torch/f95379d/bin/luajit     6877MiB |
+-----------------------------------------------------------------------------+
yoonkim commented 8 years ago

this is really cool. merged!