datalogue / keras-attention

Visualizing RNNs using the attention mechanism
https://medium.com/datalogue/attention-in-keras-1892773a4f22
GNU Affero General Public License v3.0
747 stars, 243 forks

AttentionDecoder with output_dim in the tens of thousands #39

Open KushalDave opened 5 years ago

KushalDave commented 5 years ago

Hi Zafarali,

I am trying to use your attention network for seq2seq machine translation. My source-language vocabulary has 32,000 words and my target-language vocabulary has 34,000. The following step blows up RAM usage while building the model (understandably, since it allocates a 34K x 34K float matrix):

    self.W_o = self.add_weight(shape=(self.output_dim, self.output_dim),
                               name='W_o',
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)
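A rough estimate of why this allocation fails: `W_o` has shape `(output_dim, output_dim)`, so its size grows with the square of the target vocabulary. The sketch below assumes float32 weights (4 bytes per element, the Keras default `floatx`) and counts only this one matrix, not gradients or optimizer state. `dense_matrix_bytes` is a hypothetical helper, not part of keras-attention.

```python
# Hypothetical helper (not part of keras-attention): storage for a dense weight matrix.
def dense_matrix_bytes(rows: int, cols: int, bytes_per_element: int = 4) -> int:
    """Return the size in bytes of a dense rows x cols matrix (float32 = 4 bytes by default)."""
    return rows * cols * bytes_per_element


# W_o is (output_dim, output_dim), so its memory grows quadratically with output_dim.
for output_dim in (10_000, 34_000):
    size_gb = dense_matrix_bytes(output_dim, output_dim) / 1e9
    print(f"output_dim={output_dim:>6}: W_o needs {size_gb:.2f} GB as float32")
```

For output_dim=34,000 this comes to about 4.62 GB for `W_o` alone, versus about 0.40 GB at 10,000.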

Here is my model (n_units=128, src_vocab_size=32000, tar_vocab_size=34000, src_max_length=11, tar_max_length=11):

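    from keras.models import Sequential
    from keras.layers import Embedding, LSTM
    from models.custom_recurrents import AttentionDecoder  # assumed location in this repo

    def define_model(n_units, src_vocab_size, tar_vocab_size, src_max_length, tar_max_length):
        model = Sequential()
        # Encoder: embed source token ids (id 0 is padding and is masked), then run an LSTM over them
        model.add(Embedding(src_vocab_size, n_units, input_length=src_max_length, mask_zero=True))
        model.add(LSTM(n_units, return_sequences=True))
        # Decoder: attention over the encoder states; output_dim is the target vocabulary size
        model.add(AttentionDecoder(n_units, tar_vocab_size))
        return model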

Is there any fix for this?

KushalDave commented 5 years ago

I have tried several things but can't get it working. Adding this weight pushes memory usage above 2 GB and the code crashes.

zafarali commented 5 years ago

You could try changing the dtype of the weights to tf.float16, or another lower-precision type, to save memory.
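To put numbers on this suggestion, here is a back-of-the-envelope sketch comparing the storage for the 34,000 x 34,000 `W_o` matrix in float32 and float16, at 4 and 2 bytes per element respectively. It covers the weight matrix only; `w_o_size_gb` and `BYTES_PER_ELEMENT` are hypothetical names used for illustration.

```python
# Bytes per element for the two dtypes being compared (IEEE 754 single and half precision).
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2}


def w_o_size_gb(output_dim: int, dtype: str) -> float:
    """Size in GB (1e9 bytes) of a square (output_dim, output_dim) weight matrix."""
    return output_dim * output_dim * BYTES_PER_ELEMENT[dtype] / 1e9


for dtype in ("float32", "float16"):
    print(f"{dtype}: {w_o_size_gb(34_000, dtype):.2f} GB")
```

Halving the precision halves the footprint, from about 4.62 GB to about 2.31 GB; by this estimate the float16 matrix alone would still be above the roughly 2 GB at which the crash was reported.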