ShannonAI / service-streamer

Boosting your Web Services of Deep Learning Applications.
Apache License 2.0

Can this be used with tornado? #81

Open tengmao opened 3 years ago

tengmao commented 3 years ago

Hello, I tried replacing flask with tornado, but the batch size is always 1. The log is as follows: INFO:service_streamer.service_streamer:[gpu worker 271141] run_once batch_size: 1 start_at: 1597397389.1931512 spend: 2.809465646743774

Does this mean it cannot be used together with tornado?

tengmao commented 3 years ago

The code is as follows:


import tornado.web
import tornado.ioloop
import tornado.httpserver
from tornado.escape import json_decode
import json
from tornado.options import define, options

import torch
from transformers import BertTokenizer,BertForSequenceClassification
import torch.nn as nn

from service_streamer import ThreadedStreamer

class BertgModel(object):
    def __init__(self,model_path):
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()
    def predict(self, batch_input):

        pt_batch = self.tokenizer(batch_input,
                                  padding="max_length",
                                  truncation="only_first",
                                  return_tensors="pt",
                                  max_length =128
                                )

        pt_batch = pt_batch.to(self.device)

        with torch.no_grad():
            outputs = self.model(**pt_batch)
            prob = torch.softmax(outputs[0], dim=1)
            predict = torch.max(prob.data, 1)[1].cpu().numpy()

        return predict

print("load model......")
model = BertgModel("../model/")
print("load success...")

streamer = ThreadedStreamer(model.predict, batch_size=64, max_latency=2)

class IndexHandler(tornado.web.RequestHandler):
    def post(self, *args, **kwargs):
        print("post......")
        params = json_decode(self.request.body)
        inputs = [params["query"]]
        print(inputs)
        outputs = streamer.predict(inputs)
        print(outputs)
        result = {}
        result["result"] = "result"
        result = json.dumps(result)
        self.write(result)

if __name__ == '__main__':
    print("start....")
    app = tornado.web.Application([
        (r"/", IndexHandler)
    ])

    define("port", default=1818, help="run on the given port", type=int)

    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)   
    tornado.ioloop.IOLoop.current().start()

yangtianyu92 commented 2 years ago

It seems async is not supported. Do you have a solution?

Meteorix commented 2 years ago

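It probably cannot be supported directly. tornado has its own event loop, which conflicts with the threads used here. You may need to put the requests that tornado receives into a task queue yourself, and then use the streamer for batching.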
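The task-queue idea above could look roughly like the sketch below. It is not part of service-streamer: `AsyncBatcher`, `fake_predict`, and `demo` are hypothetical names made up for illustration, and `fake_predict` stands in for a real batch function such as `model.predict`. The sketch uses only the standard library; requests are queued on the event loop, grouped until the batch is full or `max_latency` seconds have passed, and the blocking batch call is pushed to a thread so the loop stays free to accept more requests.

```python
import asyncio
import time


class AsyncBatcher:
    """Hypothetical helper (not a service_streamer API): groups single
    requests arriving on the event loop into batches."""

    def __init__(self, predict_fn, batch_size=64, max_latency=0.1):
        self.predict_fn = predict_fn    # takes a list, returns one output per input
        self.batch_size = batch_size
        self.max_latency = max_latency  # seconds to wait for a batch to fill
        self.queue = None
        self.worker = None

    def start(self):
        # Call this while the event loop is running.
        self.queue = asyncio.Queue()
        self.worker = asyncio.ensure_future(self._run())

    async def predict(self, item):
        # Enqueue one item with a future and wait without blocking the loop.
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def _run(self):
        loop = asyncio.get_running_loop()
        while True:
            # Wait for the first request, then start the latency budget.
            pending = [await self.queue.get()]
            deadline = loop.time() + self.max_latency
            while len(pending) < self.batch_size:
                timeout = deadline - loop.time()
                if timeout <= 0:
                    break
                try:
                    pending.append(
                        await asyncio.wait_for(self.queue.get(), timeout))
                except asyncio.TimeoutError:
                    break
            items = [item for item, _ in pending]
            try:
                # Run the blocking batch call in the default thread pool.
                outputs = await loop.run_in_executor(
                    None, self.predict_fn, items)
                for (_, future), output in zip(pending, outputs):
                    if not future.done():
                        future.set_result(output)
            except Exception as exc:
                # Report the failure to every request in the batch.
                for _, future in pending:
                    if not future.done():
                        future.set_exception(exc)


batch_sizes = []  # records the size of each batch, for inspection only


def fake_predict(batch):
    # Stand-in for a real model call; sleeps to imitate inference time.
    batch_sizes.append(len(batch))
    time.sleep(0.05)
    return [text.upper() for text in batch]


async def demo():
    batcher = AsyncBatcher(fake_predict, batch_size=8, max_latency=0.2)
    batcher.start()
    # Eight concurrent requests, as eight simultaneous HTTP calls would be.
    results = await asyncio.gather(
        *(batcher.predict(f"query {i}") for i in range(8)))
    batcher.worker.cancel()
    return results
```

Since Tornado 5 runs on the asyncio event loop under Python 3, a handler could presumably be written as `async def post(self)` and call `outputs = await batcher.predict(params["query"])`, with `batcher.start()` invoked once after the loop is running. The synchronous `post` in the original code blocks the loop for the whole prediction, so a second request is never read while the first one is waiting, which would explain the constant batch size of 1.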