Rathen121 opened this issue 5 years ago
You can sort of get around this by introducing your own newline symbol and using it in your input. The model will often pick up on the pattern and use the symbol itself. I would also like the ability to use newlines, though.
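Here is a minimal sketch of the newline-symbol workaround, assuming you build the prompt in Python. encode_newlines swaps real newlines for a placeholder before the text goes to the model, and decode_newlines turns the placeholder back into newlines in whatever the model writes. The helper names and the ¶ placeholder are hypothetical choices; any string that never occurs in your own text will do.

```python
# Hypothetical placeholder: pick any string that never appears in your real text
NEWLINE_SYMBOL = " ¶ "


def encode_newlines(text: str, symbol: str = NEWLINE_SYMBOL) -> str:
    """Replace real newlines with the placeholder before building the prompt."""
    return text.replace("\n", symbol)


def decode_newlines(text: str, symbol: str = NEWLINE_SYMBOL) -> str:
    """Turn the placeholder back into newlines in the model's output.

    The bare symbol (without its surrounding spaces) is also converted,
    since the model may not reproduce the spacing exactly.
    """
    text = text.replace(symbol, "\n")
    return text.replace(symbol.strip(), "\n")


prompt = encode_newlines("Shopping list:\n- eggs\n- milk")
print(prompt)  # Shopping list: ¶ - eggs ¶ - milk
```

Because the model only ever sees the placeholder, it can copy the pattern, and decode_newlines restores the line breaks afterwards.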
You can use \n in your code. For example, you can read in a text file with line endings and it just works. This may require slightly modifying the code, but if I can do it, anyone can!
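One common reason typing \n seems not to work: if the prompt is read with input() (as an interactive sampling script typically does; that is an assumption about your setup), it arrives as a literal backslash followed by n, not as a newline character. Below is a small sketch of two ways to get real newlines into the prompt string; both helper names are hypothetical.

```python
from pathlib import Path


def prompt_from_typed_line(raw: str) -> str:
    """Turn a literal backslash-n typed at an input() prompt into a real newline."""
    return raw.replace("\\n", "\n")


def prompt_from_file(path: str) -> str:
    """Read the whole file as the prompt.

    Text mode converts both Windows (CRLF) and Unix (LF) line endings to \\n.
    """
    return Path(path).read_text(encoding="utf-8")
```

Typing Dear Sam,\nThanks! at the prompt and passing the result through prompt_from_typed_line gives a two-line prompt. Keep in mind that it converts every backslash-n, including any you meant literally.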
Hey, I am trying to do the same thing as you. Could you please tell me how to modify the code so that it can take multiple lines of input? I know I am late to this, but I am currently working on it, I am really stuck, and I haven't found any help so far. Thanks in advance!
This is some code that I wrote based on the examples in this repository. Save it as generator.py in the same folder as the rest of the source (model.py, sample.py and encoder.py). You will likely want to change model_path, since I used a personal model that I fine-tuned for my purposes.
import json
import os

import numpy as np
import tensorflow as tf

import encoder
import model
import sample


class Generator:
    def __init__(self, sess, model_path='models/sanjeev-model-curated',
                 length=40, temperature=0.9, top_k=40, seed=None):
        batch_size = 1
        self.sess = sess
        # The '' tricks get_encoder, since the model name is already part of model_path
        self.enc = encoder.get_encoder(model_path, '')
        hparams = model.default_hparams()
        with open(os.path.join(model_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))
        self.context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        # Pass temperature and top_k through; otherwise the sampler's defaults are used
        self.output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=self.context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k,
        )
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(model_path)
        saver.restore(self.sess, ckpt)

    def generate(self, prompt):
        # Newlines in the prompt are encoded like any other character
        context_tokens = self.enc.encode(prompt)
        out = self.sess.run(self.output, feed_dict={
            self.context: [context_tokens]
        })[:, len(context_tokens):]  # drop the prompt tokens, keep only the continuation
        return self.enc.decode(out[0])
You just need to pass in a new session (created in a context manager because of a quirk with eager execution). Below is an example of how to use it in a simple Flask application served with waitress.
import re

import tensorflow as tf
from flask import Flask, request, make_response, jsonify
from waitress import serve

from generator import Generator

HAS_LETTER = re.compile('[a-zA-Z]')
generator = None  # created below, once the TensorFlow session is open


def PruneResult(text):
    # Return the first line if it contains letters; otherwise the first later
    # line that still has letters once its "Name:" prefix is removed
    text = text.split('\n')
    if HAS_LETTER.search(text[0]):
        return text[0]
    for t in text[1:]:
        mess = ':'.join(t.split(':')[1:])
        if HAS_LETTER.search(mess):
            return mess
    return 'Why are you so confusing, humans?'


def onMessage(data):
    print('Raw Input: ' + str(data))
    prompt = ""
    # Keys are turn numbers sent as strings; sort them numerically so "10" comes after "9"
    for key in sorted(data, key=int):
        name, message = data[key]
        prompt += f"{name}: {message}\n"
    prompt += "Damien: "
    print('Prompt: ' + prompt)
    result = generator.generate(prompt)
    print('Uncut result: ' + result)
    mess = PruneResult(result)
    print('Message: ' + mess)
    return mess


app = Flask(__name__)


@app.route('/', methods=['POST'])
def respond():
    # Change this key for security, and remember to send it in your Authorization header
    if request.headers.get('Authorization') != "Bearer CHANGE_THIS":
        return make_response(jsonify({}), 401)
    # Change these to the addresses you want to accept requests from
    if request.remote_addr not in ("10.0.0.1", "127.0.0.1"):
        return make_response(jsonify({}), 403)
    if not request.is_json:
        print(request)
        return make_response(jsonify({}), 400)
    return make_response(jsonify(onMessage(request.json)), 201)


if __name__ == "__main__":
    # The session must stay open for as long as the server is running
    with tf.Session(graph=tf.Graph()) as sess:
        generator = Generator(sess)
        print('Time to start')
        # Change the host to 0.0.0.0 to make it public or 127.0.0.1 so it only works on your machine
        serve(app, host='10.0.0.2', port=5000)
An example of a curl command that would trigger this:
curl -H "Authorization: Bearer CHANGE_THIS" -H "Content-Type: application/json" -X POST -d '{"1": ["James", "Hi Micheal, how are you?"], "2":["Micheal","Good thanks, just finished my project"],"3":["James","Thanks good, I hope you enjoyed it"]}' http://10.0.0.2:5000/
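To see exactly what text the model receives for that request, here is a standalone sketch of the two text-handling steps from the Flask example, with the model left out. build_prompt mirrors the prompt-building loop in onMessage and prune_result mirrors PruneResult; both names are hypothetical, and prune_result uses the standard-library re module in place of the third-party regex package.

```python
import re

HAS_LETTER = re.compile("[a-zA-Z]")


def build_prompt(data: dict, bot_name: str = "Damien") -> str:
    """Sort turns by numeric key, write one 'Name: message' line per turn,
    then leave the bot's name open for the model to complete."""
    prompt = ""
    for key in sorted(data, key=int):
        name, message = data[key]
        prompt += f"{name}: {message}\n"
    return prompt + f"{bot_name}: "


def prune_result(text: str) -> str:
    """Keep the first line if it has letters, else the first later line
    with letters after its 'Name:' prefix is dropped."""
    lines = text.split("\n")
    if HAS_LETTER.search(lines[0]):
        return lines[0]
    for line in lines[1:]:
        mess = ":".join(line.split(":")[1:])  # everything after the first colon
        if HAS_LETTER.search(mess):
            return mess
    return "Why are you so confusing, humans?"


payload = {
    "1": ["James", "Hi Micheal, how are you?"],
    "2": ["Micheal", "Good thanks, just finished my project"],
    "3": ["James", "Thanks good, I hope you enjoyed it"],
}
print(build_prompt(payload))
```

Sorting with key=int matters once a conversation passes nine turns: compared as strings, "10" would sort before "2". Like the original, prune_result keeps the space that follows the colon.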
I worked around it by wrapping the text in triple quotes (multiline string notation):
prompt = ('"""' + "\n" + multiline_text + "\n" +'"""')
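That line assumes multiline_text already contains real newlines. If the text comes from the terminal, one way to collect it is to keep reading lines until a sentinel line or end of input, so blank lines (and therefore paragraphs and lists) survive. This is a sketch under assumptions: read_multiline_text and wrap_in_triple_quotes are hypothetical helper names, and the END sentinel is an arbitrary choice.

```python
import sys
from typing import Optional, TextIO


def read_multiline_text(stream: Optional[TextIO] = None, end_marker: str = "END") -> str:
    """Read lines until one equals end_marker or input ends, keeping blank
    lines, and join them with real newline characters."""
    stream = stream if stream is not None else sys.stdin
    lines = []
    for line in stream:
        line = line.rstrip("\r\n")
        if line == end_marker:
            break
        lines.append(line)
    return "\n".join(lines)


def wrap_in_triple_quotes(multiline_text: str) -> str:
    """Same trick as the one-liner: put the text between lines of triple quotes."""
    return '"""' + "\n" + multiline_text + "\n" + '"""'
```

In an interactive script, prompt = wrap_in_triple_quotes(read_multiline_text()) can stand in for a single input() call: type the paragraphs, then END on its own line. A sentinel is used rather than a blank line so that empty lines between paragraphs are kept, and rather than end of input so that later reads from stdin still work.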
Currently I have found no way to enter multiple paragraphs or a list. Pressing Enter, like every other newline method I've tried, does not work.