jiayihu / gmail-smart-compose

A study implementation of Gmail Smart Compose trained with Keras and used in browser with Tensorflow.js
MIT License

Adding an attention layer to existing infra #10

Open Tlazypanda opened 3 years ago

Tlazypanda commented 3 years ago

Hello @jiayihu 👋 I am an undergrad student looking to improve this infra by adding an attention layer, but I am facing some difficulties with the code.

Here is what I have so far; can you help me out?

Model

# GRU Encoder
encoder_in_layer = keras.layers.Input(shape=(max_length_in,))
encoder_embedding = keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)
encoder_bi_gru = keras.layers.Bidirectional(keras.layers.GRU(units=latent_dim, return_sequences=True, return_state=True))

# Keep the encoder output sequence for attention and collect the hidden states (h)
# of the forward (f) and backward (b) GRU layers
encoder_out, fstate_h, bstate_h = encoder_bi_gru(encoder_embedding(encoder_in_layer))
state_h = keras.layers.Concatenate()([fstate_h, bstate_h])

# GRU Decoder
decoder_in_layer = keras.layers.Input(shape=(None,))
decoder_embedding = keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)
decoder_gru = keras.layers.GRU(units=latent_dim * 2, return_sequences=True, return_state=True)

# Discard internal states in training, keep only the output sequence
decoder_gru_out, _ = decoder_gru(decoder_embedding(decoder_in_layer), initial_state=state_h)

# Additive attention over the encoder outputs and the decoder outputs
attn_layer = keras.layers.AdditiveAttention(name='attention_layer')
attn_out = attn_layer([encoder_out, decoder_gru_out])

# Concat attention input and decoder GRU output
decoder_concat_input = keras.layers.Concatenate(axis=-1, name='concat_layer')([decoder_gru_out, attn_out])

# Dense layer
dense = keras.layers.Dense(vocab_size, activation='softmax', name='softmax_layer')
dense_time = keras.layers.TimeDistributed(dense, name='time_distributed_layer')
decoder_pred = dense_time(decoder_concat_input)

model = keras.models.Model([encoder_in_layer, decoder_in_layer], decoder_pred)

def perplexity(y_true, y_pred):
    return keras.backend.exp(keras.backend.mean(keras.backend.sparse_categorical_crossentropy(y_true, y_pred)))

model.compile(optimizer='adam', loss="sparse_categorical_crossentropy", metrics=[perplexity])
model.summary() 
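
As a side note on the attention call above: keras.layers.AdditiveAttention expects its inputs as [query, value], and its output takes the query's time dimension. In a decoder-attends-to-encoder setup the query is usually the decoder output, so a variant like the sketch below would keep the decoder's time dimension; this is only a suggestion, not code from the repo.

# Sketch of the more common argument order (query = decoder outputs, value = encoder outputs);
# attn_out_alt then has the decoder's time dimension instead of the encoder's.
attn_out_alt = attn_layer([decoder_gru_out, encoder_out])
decoder_concat_alt = keras.layers.Concatenate(axis=-1)([decoder_gru_out, attn_out_alt])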

Training

I have changed the dimensions of the decoder inputs and targets since they need to match (see the check after this block).

epochs = 10
print(encoder_input_train.shape)
print(decoder_input_train.shape)
# Extend the decoder inputs and targets by repeating the last token once so their length matches
decoder_input_train_x = np.hstack((decoder_input_train, np.tile(decoder_input_train[:, [-1]], 1)))
decoder_target_train_x = np.hstack((decoder_target_train, np.tile(decoder_target_train[:, [-1]], 1)))
print(decoder_input_train_x.shape)
print(decoder_target_train_x.shape)
history = model.fit([encoder_input_train, decoder_input_train_x], decoder_target_train_x,
                 batch_size=batch_size,
                 epochs=epochs,
                 validation_split=0.2)

def plot_history(history):
  plt.plot(history.history['loss'], label="Training loss")
  plt.plot(history.history['val_loss'], label="Validation loss")
  plt.legend()

plot_history(history)
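
To make the dimension note above concrete: the attention output inherits the time dimension of its query, which here is encoder_out, so the Concatenate on the last axis only works when the decoder sequences have the same number of timesteps as the encoder inputs. A quick check along those lines, assuming the padded arrays above:

# Illustrative sanity check (not from the repo): after padding, the decoder inputs
# and targets should have the same number of timesteps as the encoder inputs.
assert decoder_input_train_x.shape[1] == encoder_input_train.shape[1]
assert decoder_target_train_x.shape[1] == encoder_input_train.shape[1]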

Inference model

""" Encoder (Inference) model """
encoder_inf_inputs = keras.layers.Input(shape=(max_length_in,))
encoder_inf_out, encoder_inf_fwd_state, encoder_inf_back_state = encoder_bi_gru(encoder_embedding(encoder_inf_inputs))
encoder_model = keras.Model(inputs=encoder_inf_inputs, outputs=[encoder_inf_out, encoder_inf_fwd_state, encoder_inf_back_state])

""" Decoder (Inference) model """
decoder_inf_inputs = keras.layers.Input(shape=(None,))
#state_h = keras.layers.Concatenate()([fstate_h, bstate_h])
state_input_h = keras.layers.Input(shape=(latent_dim * 2,))
print(state_h.shape)
print(encoder_inf_out.shape)
print(decoder_embedding(decoder_inf_inputs).shape)
encoder_inf_states = keras.layers.Input(shape=(max_length_in, latent_dim*2,), name='encoder_inf_states')
#decoder_init_state = Input(batch_shape=(batch_size, 2*hidden_size), name='decoder_init')

decoder_inf_out, decoder_inf_state = decoder_gru(decoder_embedding(decoder_inf_inputs), initial_state=state_input_h)
attn_inf_out = attn_layer([encoder_inf_states, decoder_inf_out])
decoder_inf_concat = keras.layers.Concatenate(axis=-1, name='concat')([decoder_inf_out, attn_inf_out])
decoder_inf_pred = keras.layers.TimeDistributed(dense)(decoder_inf_concat)
inf_model = keras.models.Model([ state_input_h, encoder_inf_states, decoder_inf_inputs], [decoder_inf_pred, decoder_inf_state])
#inf_model = keras.Model(inputs=[encoder_inf_states, decoder_init_state, decoder_inf_inputs],
#                          outputs=[decoder_inf_pred, attn_inf_out, decoder_inf_state])
inf_model.summary()
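
A quick sanity check that can help at this point, assuming encoder_input_train from the training section is still in scope: run one example through the inference encoder and confirm that the shapes line up with the Input layers defined above.

# Illustrative shape check for the inference encoder (not part of the repo):
enc_out, f_h, b_h = encoder_model.predict(encoder_input_train[:1])
print(enc_out.shape)                              # expected (1, max_length_in, latent_dim * 2)
print(np.concatenate([f_h, b_h], axis=-1).shape)  # expected (1, latent_dim * 2)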

Final output

def decode_sequence(input_tensor):
    # Encode the input as state vectors.
    encoder_inf_out, f_state, b_state = encoder_model.predict(input_tensor)
    dec_state = np.concatenate([f_state,b_state], axis=-1)
    print(dec_state.shape)

    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = tokenizer.word_index['<start>']
    curr_word = "<start>"
    decoded_sentence = ''
    print(decoder_embedding(target_seq).shape)

    i = 0
    while curr_word != "<end>" and i < (max_length_out - 1):
        output_tokens, h = inf_model.predict([dec_state, encoder_inf_out, target_seq ])
        print("Aqui")

        curr_token = np.argmax(output_tokens[0, 0])

        if curr_token == 0:
            break

        curr_word = index_to_word[curr_token]

        decoded_sentence += ' ' + curr_word
        target_seq[0, 0] = curr_token
        dec_state = h  # feed the updated decoder state into the next prediction step
        i += 1

    return decoded_sentence
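
For completeness, a hypothetical call to decode_sequence, assuming the notebook's tokenizer and max_length_in (the prompt text and the 'post' padding are made up for illustration):

# Hypothetical usage: tokenize a prompt, pad it to the encoder length, and decode.
seq = tokenizer.texts_to_sequences(['how are'])
seq = keras.preprocessing.sequence.pad_sequences(seq, maxlen=max_length_in, padding='post')
print(decode_sequence(seq))
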
Tlazypanda commented 3 years ago

Hey @jiayihu, this code is failing at inf_model.predict for the GRU layer with an error saying that only 1 input tensor was passed instead of the expected 4. I would really appreciate it if you could have a look and let me know whether my approach is correct. 😅 Thank you!! 💯

jiayihu commented 3 years ago

Hi @Tlazypanda, unfortunately I haven't been working with Keras for a while. I would suggest asking on StackOverflow, although input-size errors are pretty common :) Maybe try to log the dimensions of the layers to debug the issue. Good luck with your studies!
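
For reference, one way to do the dimension logging suggested here (a debugging sketch, not code from the repo) is to print the shapes the inference model expects on its inputs and returns on its outputs, and compare them with the arrays passed to predict:

# Print the symbolic shapes of the inference model's inputs and outputs:
for t in inf_model.inputs:
    print('expects:', t.name, t.shape)
for t in inf_model.outputs:
    print('returns:', t.name, t.shape)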

Tlazypanda commented 3 years ago

Hola @jiayihu! Thank you for your quick response 😄 Appreciate it ✌️ I will try to post it on other platforms.