simplysameer333 / MachineLearning

1 star · 1 fork · source link

generate_summary #9

Open simplysameer333 opened 5 years ago

simplysameer333 commented 5 years ago

vocab_to_int_pickle_filename = "vocab_to_int.pickle"
int_to_vocab_pickle_filename = "int_to_vocab.pickle"

# NOTE(review): "covert" looks like a typo for "convert", but the helper is
# defined elsewhere — keep the name so the call still resolves.
vocab_to_int, int_to_vocab = covert_vocab_to_int(word_counts, embeddings_index)

# Persist vocab_to_int for use in the generate stage.
with open(config.base_path + config.vocab_to_int_pickle_filename, 'wb') as handle:
    pickle.dump(vocab_to_int, handle, protocol=pickle.HIGHEST_PROTOCOL)

# BUG FIX: this second dump previously reused vocab_to_int_pickle_filename,
# overwriting the first pickle and never persisting int_to_vocab at all.
with open(config.base_path + config.int_to_vocab_pickle_filename, 'wb') as handle:
    pickle.dump(int_to_vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)

===================

import pickle import numpy as np import tensorflow as tf import data_processing as dp import config

def text_to_cleanedup(text, vocab_to_int):
    '''Clean raw text and encode it as a list of vocabulary integer ids.

    Each whitespace-separated word of the cleaned text is looked up in
    vocab_to_int; unknown words fall back to the <UNK> token's id.
    '''
    cleaned = dp.clean_text(text)
    unk_id = vocab_to_int['<UNK>']
    return [vocab_to_int.get(word, unk_id) for word in cleaned.split()]

def inference_stage(input_cleaned_test, checkpoint="./best_model.ckpt"):
    '''Run the saved seq2seq graph on one integer-encoded article.

    Restores the model from `checkpoint` (previously hard-coded; now a
    backward-compatible parameter), feeds the encoded article repeated
    batch_size times, and returns the predicted token ids for the first
    batch element.
    '''
    loaded_graph = tf.Graph()
    with tf.Session(graph=loaded_graph) as sess:
        # Load the saved model graph and its trained weights.
        loader = tf.train.import_meta_graph(checkpoint + '.meta')
        loader.restore(sess, checkpoint)

        # Look up the graph's input/output tensors by the names they were
        # exported under during training.
        input_data = loaded_graph.get_tensor_by_name('input:0')
        logits = loaded_graph.get_tensor_by_name('predictions:0')
        text_length = loaded_graph.get_tensor_by_name('text_length:0')
        summary_length = loaded_graph.get_tensor_by_name('summary_length:0')
        keep_prob = loaded_graph.get_tensor_by_name('keep_prob:0')

        # Repeat the single example batch_size times to match the model's
        # fixed batch input shape; keep_prob=1.0 disables dropout at inference.
        generated_logits = sess.run(
            logits,
            {input_data: [input_cleaned_test] * config.batch_size,
             # NOTE(review): summary_length is fed as a single element while
             # input_data/text_length are repeated batch_size times — confirm
             # the graph broadcasts this, otherwise it likely also needs
             # `* config.batch_size`.
             summary_length: [np.random.randint(5, 8)],
             text_length: [len(input_cleaned_test)] * config.batch_size,
             keep_prob: 1.0})[0]

    return generated_logits

def main():
    '''Spot-check the model: pick a random cleaned article, run inference,
    and print the original text alongside the generated summary.'''
    # Load vocab_to_int & int_to_vocab persisted during vectorization.
    with open(config.base_path + config.vocab_to_int_pickle_filename, 'rb') as handle:
        vocab_to_int = pickle.load(handle)

    with open(config.base_path + config.int_to_vocab_pickle_filename, 'rb') as handle:
        int_to_vocab = pickle.load(handle)

    with open(config.base_path + config.articles_pickle_filename, 'rb') as handle:
        clean_articles = pickle.load(handle)

    # Sample one article at random for the spot-check.
    sample_index = np.random.randint(0, len(clean_articles))
    input_sentence = clean_articles[sample_index]
    cleaned_text = text_to_cleanedup(input_sentence, vocab_to_int)

    generated_logits = inference_stage(cleaned_text)

    # <PAD> ids are stripped from the generated output before display.
    pad = vocab_to_int["<PAD>"]

    print('Original Text: ', input_sentence)

    print('Text info ')
    print('Word Ints:    {} '.format(list(cleaned_text)))
    print('Input Clean Words: {} '.format(" ".join([int_to_vocab[i] for i in cleaned_text])))

    print('Generated Summary')
    print('Word Ints:       {} '.format([i for i in generated_logits if i != pad]))
    print('Generated Words: {} '.format(" ".join([int_to_vocab[i] for i in generated_logits if i != pad])))

# Standard entry-point guard so importing this module does not trigger a run.
if __name__ == '__main__':
    main()