lucko515 / chatbot-startkit

This repository holds files for a simple chatbot written in TensorFlow 1.4, with an attention mechanism and bucketing.

How to test the Saved model. #4

jainsanmati opened 6 years ago

jainsanmati commented 6 years ago

I have completed the training part and saved the model (checkpoints). It would be nice if you can update the code for testing the model. Thanks

lok63 commented 6 years ago

Has anyone written code for testing the model and reusing the checkpoints?

caballeto commented 6 years ago

Here is some rough code for setting up a chat, if that is what you mean by 'testing'. There may be some mismatched function names relative to this repo, but in general this code works. data_utils.py is the file with the data preprocessing (it was renamed from cornell_data ...). The clean function creates clean questions and answers, using clean_text, which cleans with regexps. np.ones creates a fake batch with lengths [decoder_length, 1, 1, 1, ...]; it has length batch_size. The for bucket in config.BUCKETS: loop finds the bucket the question belongs to.

import tensorflow as tf
import numpy as np
import config
from data_utils import *
from model import Chatbot

def str_to_int(question, word_to_id):
    question = clean_text(question)
    return [word_to_id.get(word, word_to_id['<OUT>']) for word in question.split()]

if __name__ == "__main__":
    print('1. Creating vocabulary.')
    clean_questions, clean_answers = clean()
    vocab, word_to_id, id_to_word = create_vocabulary(clean_questions, clean_answers)
    checkpoint = 'checkpoint/chatbot_1'
    print('2. Building model.')
    model = Chatbot(config.LEARNING_RATE,
                    config.BATCH_SIZE,
                    config.ENCODING_EMBED_SIZE,
                    config.DECODING_EMBED_SIZE,
                    config.RNN_SIZE,
                    config.NUM_LAYERS,
                    len(vocab),
                    word_to_id,
                    config.CLIP_RATE)

    print('3. Starting session.')
    session = tf.Session()
    saver = tf.train.Saver()
    saver.restore(session, checkpoint)

    while True:
        question = input("Question: ")
        if question == 'exit' or question == 'Exit':
            break
        length = len(question.split())  # bucket by word count, not character count
        if length > config.BUCKETS[-1][0]:
            print('Question is too long for the largest bucket, try a shorter one.')
            continue
        for bucket in config.BUCKETS:
            if length <= bucket[0] and length <= bucket[1]:
                enc_len = np.ones(config.BATCH_SIZE, dtype=np.int32)
                dec_len = np.ones(config.BATCH_SIZE, dtype=np.int32)
                enc_len[0] = bucket[0]
                dec_len[0] = bucket[1]
                length = bucket[0]
                break

        question = str_to_int(question, word_to_id)
        question += [word_to_id['<PAD>']] * (length - len(question))
        fake_batch = np.zeros((config.BATCH_SIZE, length))
        fake_batch[0] = question
        # keep_probs is set to 1.0 so no dropout is applied at inference time
        prediction = session.run(model.predictions, feed_dict={model.inputs: fake_batch, model.keep_probs: 1.0, model.encoder_seq_len: enc_len, model.decoder_seq_len: dec_len})[0]
        answer = ''
        for x in prediction:
            if id_to_word[x] == 'i':
                token = " I"
            elif id_to_word[x] == '<EOS>':
                token = '.'
            elif id_to_word[x] == '<OUT>':
                token = 'out'
            else:
                token = ' ' + id_to_word[x]
            answer += token
            if token == '.':
                break
        print("Chatbot : " + answer)

P. S. I know this code is total junk, but it works for me. So if you can make use of it, great. Hope it helps.
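The bucket selection and padding the script performs can be tried in isolation. Below is a minimal sketch with made-up BUCKETS, BATCH_SIZE, and token ids (these are placeholders, not the repo's actual config values):

```python
import numpy as np

# Hypothetical bucket sizes: (max encoder length, max decoder length)
BUCKETS = [(5, 10), (10, 15), (20, 25)]
BATCH_SIZE = 4
PAD_ID = 0

def pick_bucket(num_tokens, buckets):
    """Return the first bucket whose encoder length fits the question."""
    for enc, dec in buckets:
        if num_tokens <= enc:
            return enc, dec
    return buckets[-1]  # fall back to the largest bucket

def make_fake_batch(token_ids, batch_size, buckets, pad_id):
    """Pad one question to its bucket length and place it in row 0 of a batch."""
    enc, dec = pick_bucket(len(token_ids), buckets)
    padded = token_ids + [pad_id] * (enc - len(token_ids))
    batch = np.zeros((batch_size, enc), dtype=np.int32)
    batch[0] = padded
    # Sequence-length vectors: real lengths in row 0, dummy 1s elsewhere,
    # mirroring the np.ones trick described above
    enc_len = np.ones(batch_size, dtype=np.int32)
    dec_len = np.ones(batch_size, dtype=np.int32)
    enc_len[0], dec_len[0] = enc, dec
    return batch, enc_len, dec_len

batch, enc_len, dec_len = make_fake_batch([7, 3, 9], BATCH_SIZE, BUCKETS, PAD_ID)
print(batch.shape)  # (4, 5) -- the 3-token question falls into the (5, 10) bucket
```

Only row 0 of the batch carries a real question; the rest is zero padding, which is why the script reads `[0]` off the prediction.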

ShabbirMarfatiya commented 3 years ago

Hi @caballeto, I have changed clean() to clean_data() and create_vocabulary() to create_vocab() according to the function names in cornell_data_utils.py, but I didn't understand clean_text(). Can you please tell me about it?

caballeto commented 3 years ago

> Hi @caballeto, I have changed clean() to clean_data() and create_vocabulary() to create_vocab() according to the function names in cornell_data_utils.py, but I didn't understand clean_text(). Can you please tell me about it?

That was 2 years ago, but I would guess that this function uses re.sub to remove punctuation and probably transforms the sentence into lower case.
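A plausible sketch of such a function (a guess based on the description above, not the repo's actual clean_text):

```python
import re

def clean_text(text):
    """Lowercase the text and strip everything except word characters and spaces."""
    text = text.lower()
    return re.sub(r"[^\w\s]", "", text)

print(clean_text("What's up, Doc?"))  # whats up doc
```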

ShabbirMarfatiya commented 3 years ago

> Hi @caballeto, I have changed clean() to clean_data() and create_vocabulary() to create_vocab() according to the function names in cornell_data_utils.py, but I didn't understand clean_text(). Can you please tell me about it?

> That was 2 years ago, but I would guess that this function uses re.sub to remove punctuation and probably transforms the sentence into lower case.

@caballeto I have tried cornell_tokenizer() but it gives me a KeyError. Can you recall what might cause this?

caballeto commented 3 years ago

Try something like this

import re

def clean_data(s):
  return re.sub(r'[^\w\s]','',s.lower())

ShabbirMarfatiya commented 3 years ago
def clean_data(s):
  return re.sub(r'[^\w\s]','',s.lower())

Thanks, @caballeto, for helping me. I have figured out the error and solved it. Your code works perfectly.

khritish29 commented 3 years ago

I am getting the error 'clean' is not defined. Can someone share the updated testing code, please?

ShabbirMarfatiya commented 3 years ago

> I am getting the error 'clean' is not defined. Can someone share the updated testing code, please?

import tensorflow as tf
import numpy as np
import config
from cornell_data_utils import *
from model_utils import Chatbot

def str_to_int(question, word_to_id):
    question = cornell_tokenizer(question)
    return [word_to_id.get(word, word_to_id['<UNK>']) for word in question.split()]

if __name__ == "__main__":
    print('1. Creating vocabulary.')
    clean_questions, clean_answers = clean_data()
    vocab, word_to_id, id_to_word = create_vocab(clean_questions, clean_answers)
    checkpoint = 'chatbot_1.ckpt'
    print('2. Building model.')
    model = Chatbot(config.LEARNING_RATE,
                    config.BATCH_SIZE,
                    config.ENCODING_EMBED_SIZE,
                    config.DECODING_EMBED_SIZE,
                    config.RNN_SIZE,
                    config.NUM_LAYERS,
                    len(vocab),
                    word_to_id,
                    config.CLIP_RATE)

    print('3. Starting session.')
    session = tf.Session()
    saver = tf.train.Saver()
    saver.restore(session, checkpoint)

    while True:
        question = input("You: ")
        if question == 'exit' or question == 'Exit':
            break
        length = len(question.split())  # bucket by word count, not character count
        if length > config.BUCKETS[-1][0]:
            print('Question is too long for the largest bucket, try a shorter one.')
            continue
        for bucket in config.BUCKETS:
            if length <= bucket[0] and length <= bucket[1]:
                enc_len = np.ones(config.BATCH_SIZE, dtype=np.int32)
                dec_len = np.ones(config.BATCH_SIZE, dtype=np.int32)
                enc_len[0] = bucket[0]
                dec_len[0] = bucket[1]
                length = bucket[0]
                break

        question = str_to_int(question, word_to_id)
        question += [word_to_id['<PAD>']] * (length - len(question))
        fake_batch = np.zeros((config.BATCH_SIZE, length))
        fake_batch[0] = question
        # keep_probs is set to 1.0 so no dropout is applied at inference time
        prediction = session.run(model.predictions, feed_dict={model.inputs: fake_batch, model.keep_probs: 1.0, model.encoder_seq_len: enc_len, model.decoder_seq_len: dec_len})[0]
        answer = ''
        for x in prediction:
            if id_to_word[x] == 'i':
                token = " I"
            elif id_to_word[x] == '<EOS>':
                token = '.'
            elif id_to_word[x] == '<UNK>':
                token = 'out'
            else:
                token = ' ' + id_to_word[x]
            answer += token
            if token == '.':
                break
        print("Chatbot : " + answer)

khritish29 commented 3 years ago

Thanks for the quick reply. Now I am getting this error:

    question = cornell_tokenizer(question)
    ^
    IndentationError: expected an indented block

ShabbirMarfatiya commented 3 years ago

> Thanks for the quick reply. Now I am getting this error: question = cornell_tokenizer(question) ^ IndentationError: expected an indented block

It's an indentation error. Check caballeto's code and correct the indentation.
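For reference, the def just needs its body indented one level. A runnable stand-in is below (cornell_tokenizer is replaced here with a simple lowercase split, since it is only defined inside the repo):

```python
def str_to_int(question, word_to_id):
    # These body lines must be indented one level under the def line --
    # the IndentationError means they were flush with the def.
    question = question.lower()  # stand-in for cornell_tokenizer
    return [word_to_id.get(word, word) for word in question.split()]

print(str_to_int("Hello World", {}))  # ['hello', 'world']
```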

khritish29 commented 3 years ago

Thanks a lot, the indentation errors have been solved. Now I am getting an error about the checkpoint location; I have tried changing the path many times. Can you please share the folder directly? Thanks.

NotFoundError: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for .ipynb_checkpoints/.ipynb_checkpoints/chatbot-checkpoint-checkpoint
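This error usually means the path passed to saver.restore does not match the checkpoint prefix on disk (the `.ipynb_checkpoints` in the path suggests a Jupyter autosave directory was picked up rather than the training output). In TF 1.x a checkpoint named `chatbot_1` is actually a set of files (`chatbot_1.index`, `chatbot_1.data-00000-of-00001`, ...) plus a plain-text `checkpoint` file listing the latest prefix. A framework-free sketch of reading that file to find the right prefix (assuming the standard TF 1.x layout):

```python
import os
import re
import tempfile

def latest_checkpoint_prefix(checkpoint_dir):
    """Mimic tf.train.latest_checkpoint by parsing the 'checkpoint' index file."""
    index_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(index_file):
        return None
    with open(index_file) as f:
        first_line = f.readline()
    # The first line looks like: model_checkpoint_path: "chatbot_1"
    match = re.search(r'model_checkpoint_path:\s*"([^"]+)"', first_line)
    if match is None:
        return None
    return os.path.join(checkpoint_dir, match.group(1))

# Demo with a throwaway directory standing in for the training output
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "checkpoint"), "w") as f:
    f.write('model_checkpoint_path: "chatbot_1"\n')
prefix = latest_checkpoint_prefix(demo_dir)
print(os.path.basename(prefix))  # chatbot_1
```

In practice you would just call `tf.train.latest_checkpoint(checkpoint_dir)` and pass the returned prefix to `saver.restore(session, prefix)`.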

khritish29 commented 3 years ago

I just need the file checkpoint/chatbot_1.

khritish29 commented 3 years ago

Please tell me how many hours it took to fully train the model.