pender / chatbot-rnn

A toy chatbot powered by deep learning and trained on data from Reddit
MIT License

Remove the beam search generator #38

Closed geroale closed 6 years ago

geroale commented 6 years ago

Hi @pender , first of all, thank you so much for your repo. I found it really helpful for learning about RNNs: the code is clear enough, the Reddit-based model is well trained, and everything is cool.

I have only one question: is it possible to generate the answer instantly from the model, without the character-by-character generation? I believe this effect comes from the beam search generator, but I can't fully understand how it works. It would be great if the model could write the answer in one instant, or if there were a parameter letting the user choose the type of generation.

Thanks for your work and everything. Alessandro.

pender commented 6 years ago

Hi Alessandro, generation for this type of model has to occur one character at a time, because each character is chosen based on all of the previous characters. You can disable beam search by typing "--beam_width 1" during chat, which will make generation faster (and much worse), but it will still choose one character at a time.
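The dependency pender describes can be illustrated with a minimal sketch. This is not the repo's actual network; `toy_model` is a deterministic stand-in just to show the control flow, namely that each new character is conditioned on the full prefix generated so far:

```python
def sample_autoregressively(model_fn, seed, stop_char='\n', max_len=20):
    """Generate text one character at a time: each new character is
    chosen based on everything generated before it, so the characters
    cannot all be produced in a single step."""
    out = list(seed)
    for _ in range(max_len):
        c = model_fn(''.join(out))  # next char depends on the full prefix
        if c == stop_char:          # early-termination token, like '\n' in the repo
            break
        out.append(c)
    return ''.join(out)

# Toy stand-in "model": picks a char from the prefix length alone.
def toy_model(prefix):
    return 'ab\n'[len(prefix) % 3]

print(sample_autoregressively(toy_model, 'x'))  # prints "xb"
```

Beam search (`--beam_width` > 1) keeps several candidate prefixes alive at each step instead of one, which improves quality but still proceeds one character at a time.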

shubhank008 commented 6 years ago

@geroale I just wanted to share something similar I wanted and how I did it. Although it does not return the output instantly, it does return it as a whole sentence at once. Useful for creating APIs or integrating the output into your own code:

# Assumes the helpers from chatbot.py are in scope (initial_state_with_relevance_masking,
# process_user_command, forward_text, sanitize_text, beam_search_generator,
# forward_with_mask, possibly_escaped_char), plus `import copy`.
def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature, topn, input_text):
    states = initial_state_with_relevance_masking(net, sess, relevance)
    user_input = input_text  # take the text as an argument instead of input('\n> ')
    user_command_entered, reset, states, relevance, temperature, topn, beam_width = process_user_command(
        user_input, states, relevance, temperature, topn, beam_width)
    if reset: states = initial_state_with_relevance_masking(net, sess, relevance)
    if user_command_entered:
        return ""
    beam_width = 1  # disable beam search for speed; remove this line to keep the caller's width
    states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "> " + user_input + "\n>"))
    computer_response_generator = beam_search_generator(sess=sess, net=net,
        initial_state=copy.deepcopy(states), initial_sample=vocab[' '],
        early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask,
        forward_args={'relevance': relevance, 'mask_reset_token': vocab['\n'],
                      'forbidden_token': vocab['>'],
                      'temperature': temperature, 'topn': topn})
    out_chars = []
    out = []
    for i, char_token in enumerate(computer_response_generator):
        out_chars.append(chars[char_token])
        out.append(possibly_escaped_char(out_chars))
        # original code printed each char here:
        # print(possibly_escaped_char(out_chars), end='', flush=True)
        states = forward_text(net, sess, states, relevance, vocab, chars[char_token])
        if i >= max_length:
            break
    # return the whole response at once instead of printing it char by char
    return "".join(out)

You just store the characters in a list and, once the loop ends, join them into a single string. PS: my modification returns the output; if you want to print it instead, comment out the return line and restore the print.
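The collect-then-join pattern above can be shown in isolation. In this sketch, `collect_response` and `id_to_char` are illustrative names, and a plain generator stands in for `beam_search_generator`:

```python
def collect_response(token_generator, id_to_char, max_length=500):
    """Drain a character-token generator into one string, instead of
    printing each character as it arrives."""
    out = []
    for i, tok in enumerate(token_generator):
        out.append(id_to_char[tok])  # map token id back to its character
        if i >= max_length:
            break
    return ''.join(out)

# Toy stand-in for beam_search_generator: yields token ids one at a time.
chars = ['h', 'i', '!']
gen = (t for t in [0, 1, 2])
print(collect_response(gen, chars))  # prints "hi!"
```

Note that generation is still sequential under the hood; buffering the characters only changes when the caller sees them, not how fast they are produced.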