datalogue / keras-attention

Visualizing RNNs using the attention mechanism
https://medium.com/datalogue/attention-in-keras-1892773a4f22
GNU Affero General Public License v3.0

Attention model can not learn simple behaviour #31

Closed asmekal closed 6 years ago

asmekal commented 6 years ago

For a long time I've tried to adapt your model to an OCR problem. At some point I found out that even with frozen encoder features, received from a CTC model (which performs well), I could not reproduce the same performance. So I made the simplest problem that any reasonable classifier should solve, a kind of autoencoder:

    import keras
    from keras.layers import Input, LSTM, Embedding, Dense
    from keras.models import Model
    # AttentionDecoder from this repo (models/custom_recurrents.py)
    from models.custom_recurrents import AttentionDecoder

    import numpy as np
    n, t = 10000, 5
    n_labels = 10
    y = np.random.randint(0, n_labels, size=(n,t))

    inp = Input(shape=(t,), dtype='int64')
    emb = Embedding(n_labels, 10)(inp)
    # outp = Dense(n_labels, activation='softmax')(emb)  # baseline: fits the task immediately
    outp = AttentionDecoder(10, n_labels)(emb)

    model = Model(inputs=[inp], outputs=[outp])
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()
    model.fit(y, np.expand_dims(y, -1), epochs=15)  # target is the input sequence itself

Unsurprisingly, if Dense is used instead of AttentionDecoder, the model reaches accuracy = 1 immediately after the first epoch. But with AttentionDecoder the model stalls at around accuracy = 0.5 with no further progress at all.

It seems to work well only if t <= 2, maybe because the initial state is computed from the first timestep (s0 = activations.tanh(K.dot(inputs[:, 0], self.W_s))), so the first label is essentially given and the attention overfits to the second timestep. But even with t = 3 accuracy does not exceed 0.7, which is close to guessing two labels and returning the last one at random.
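
As a small numpy sketch of that point (assumed toy shapes, not the layer's actual code): the initial state is a function of the first encoder timestep only, so the first output label comes essentially for free.

    import numpy as np
    batch, t, features, units = 32, 5, 10, 16     # assumed toy shapes
    inputs = np.random.randn(batch, t, features)  # sequence fed to the decoder
    W_s = np.random.randn(features, units)        # learned projection for the initial state
    s0 = np.tanh(inputs[:, 0] @ W_s)              # uses inputs[:, 0] only, i.e. the first timestep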

Any ideas?

asmekal commented 6 years ago

It seems that there are 2 ways to solve the problem and help the model learn:

  1. Use some RNN before the decoder:
    ...
    emb = Embedding(n_labels, 10)(inp)
    x = LSTM(units, return_sequences=True)(emb)  # keep the full sequence for the decoder
    x = AttentionDecoder(10, n_labels)(x)
    ...
  2. Add information about positions (e.g. a position encoding; see the sketch after this list):
    ...
    emb = Embedding(n_labels, 10)(inp)
    pos_emb = PositionEncoding(...)(emb)  # placeholder for any positional encoding layer
    x = concatenate([emb, pos_emb], axis=-1)
    x = AttentionDecoder(10, n_labels)(x)
    ...
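
A minimal runnable sketch of option 2, under stated assumptions: the AttentionDecoder import path is taken from this repo, and a learned position Embedding stands in for the PositionEncoding placeholder above.

    import numpy as np
    from keras.layers import Input, Embedding, concatenate
    from keras.models import Model
    # AttentionDecoder from this repo (path assumed)
    from models.custom_recurrents import AttentionDecoder

    n, t, n_labels = 10000, 5, 10
    y = np.random.randint(0, n_labels, size=(n, t))
    positions = np.tile(np.arange(t), (n, 1))      # 0..t-1 for every sample

    inp = Input(shape=(t,), dtype='int64')
    pos = Input(shape=(t,), dtype='int64')
    emb = Embedding(n_labels, 10)(inp)             # token embedding
    pos_emb = Embedding(t, 10)(pos)                # learned position embedding
    x = concatenate([emb, pos_emb], axis=-1)       # token identity + position
    outp = AttentionDecoder(10, n_labels)(x)

    model = Model(inputs=[inp, pos], outputs=[outp])
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    model.fit([y, positions], np.expand_dims(y, -1), epochs=15)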

Actually these 2 ways amount to the same thing: both give the model some positional information, which is otherwise unknown (at least in this toy example). Without that positional information the model is unable to fit.

asmekal commented 6 years ago

So the implemented attention is able to fit, although I can't see why it is so hard to simply shift the attention by one step at each timestep... But anyway, the problem is solved, so I'll close the issue.

zafarali commented 6 years ago

I think your intuition is correct: there needs to be some positional information, otherwise the attention decoder cannot make sense of what it is receiving (just a bunch of vectors representing independent entries). It then has to look at this input sequence and try to figure out how to decode it, but it has no information about how each entry relates to the others. Neat experiment, thank you!