Closed — geroale closed this issue 6 years ago
Hi Alessandro, generation for this type of model has to occur one character at a time, because each character is chosen based on all of the previous characters. You can disable beam search by typing "--beam_width 1" during chat, which will make generation faster (and much worse), but it will still choose one character at a time.
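To make the "one character at a time" point concrete, here is a minimal, hypothetical sketch of autoregressive character generation (it is not the repo's actual model; `predict_next_char_probs` is a stand-in for a real network, and the tiny vocabulary is purely illustrative). The key property is that each sampled character is appended to the context before the next one is predicted, so the output cannot be produced in a single step:

```python
import random

# Stand-in for a trained model: returns a probability for each candidate
# character given the context. Here it is just uniform over a toy vocabulary.
def predict_next_char_probs(context):
    vocab = list("abc\n")
    return {c: 1.0 / len(vocab) for c in vocab}

def generate(seed, max_length=100, stop_char="\n"):
    out = []
    context = seed
    for _ in range(max_length):
        probs = predict_next_char_probs(context)
        chars, weights = zip(*probs.items())
        next_char = random.choices(chars, weights=weights)[0]
        if next_char == stop_char:   # early termination, like early_term_token
            break
        out.append(next_char)
        context += next_char         # each new char conditions the next prediction
    return "".join(out)
```

Beam search adds further cost on top of this loop by exploring several candidate continuations in parallel, which is why `--beam_width 1` speeds things up at the expense of quality.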
@geroale just wanted to share something similar that I needed and how I did it. Although it does not return the output instantly, it does return the whole sentence at once, which is useful for building APIs or integrating the output into your own code:
```python
def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature, topn, input_text):
    states = initial_state_with_relevance_masking(net, sess, relevance)
    while True:
        # user_input = input('\n> ')
        user_input = input_text
        start_time = time.time()
        user_command_entered, reset, states, relevance, temperature, topn, beam_width = process_user_command(
            user_input, states, relevance, temperature, topn, beam_width)
        if reset: states = initial_state_with_relevance_masking(net, sess, relevance)
        if not user_command_entered:
            beam_width = 1
            states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "> " + user_input + "\n>"))
            computer_response_generator = beam_search_generator(sess=sess, net=net,
                initial_state=copy.deepcopy(states), initial_sample=vocab[' '],
                early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask,
                forward_args={'relevance': relevance, 'mask_reset_token': vocab['\n'], 'forbidden_token': vocab['>'],
                              'temperature': temperature, 'topn': topn})
            # print(chars)
            out_chars = []
            out = []
            for i, char_token in enumerate(computer_response_generator):
                out_chars.append(chars[char_token])
                out.append(possibly_escaped_char(out_chars))
                # print(possibly_escaped_char(out_chars), end='', flush=True)
                states = forward_text(net, sess, states, relevance, vocab, chars[char_token])
                if i >= max_length: break
            # print("".join(out))
            # print("--- %s seconds ---" % (time.time() - start_time))
            return "".join(out)
        states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "\n> "))
```
You just store the characters in a list and, once the loop ends, join them into a single string. PS: my modification returns the output; if you want to just print it instead, comment out the `return` line and uncomment the `print`.
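The accumulate-and-join pattern in isolation looks like this (a self-contained toy, with `char_generator` standing in for the repo's `computer_response_generator`):

```python
# Toy illustration of the pattern above: consume a character generator,
# collect the characters into a list, and join them once at the end
# instead of printing each one as it arrives.
def char_generator(text):
    for ch in text:
        yield ch

def collect_response(gen, max_length=1000):
    out = []
    for i, ch in enumerate(gen):
        out.append(ch)
        if i >= max_length:   # same cutoff style as the loop above
            break
    return "".join(out)

# collect_response(char_generator("hello")) → "hello"
```

Joining once at the end is also faster than repeated string concatenation, since Python strings are immutable.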
Hi @pender, first of all, thank you so much for your repo. I found it really helpful for learning about RNNs: the code is clear enough, the Reddit-trained model works well, and everything is cool.
I have just one question: is it possible to generate the answer from the model instantly, without the character-by-character generation? I think I understand that this effect comes from the beam search generator, but I can't fully grasp how it works. It would be great if the model could write the answer in one instant, or if there were a parameter letting the user decide which type of generation to use.
Thanks for your work and everything. Alessandro.