google-research / bert

TensorFlow code and pre-trained models for BERT
https://arxiv.org/abs/1810.04805
Apache License 2.0

How can "Estimator.predict" do real-time prediction? #790

Open ApexPredator1 opened 5 years ago

ApexPredator1 commented 5 years ago

I want to use BERT for a sentiment classification task. I fine-tuned BERT on a dataset and got a working model, but predicting a single sample is very slow. Someone said the reason is that Estimator.predict reloads the graph on every call, which takes almost five seconds!

Some people say that using tf.data.Dataset.from_generator() solves this, but I still can't get it to work after many attempts.

Here is part of my code:

    def input_gen(self):
        while True:
            text = "这是一个测试"  # input()
            examples = [InputExample(guid=uuid.uuid4(), text_a=text, text_b=None, label="0")]
            features = BertSentiment.examples_to_features(examples, self.label_list, self.max_seq_length, self.tokenizer)

            all_input_ids = []
            all_input_mask = []
            all_segment_ids = []
            all_label_ids = []

            for feature in features:
                all_input_ids.append(feature.input_ids)
                all_input_mask.append(feature.input_mask)
                all_segment_ids.append(feature.segment_ids)
                all_label_ids.append(feature.label_id)

            num_examples = len(features)

            input_ids = tf.constant(all_input_ids, shape=[num_examples, self.max_seq_length], dtype=tf.int32)
            input_mask = tf.constant(all_input_mask, shape=[num_examples, self.max_seq_length], dtype=tf.int32)
            segment_ids = tf.constant(all_segment_ids, shape=[num_examples, self.max_seq_length], dtype=tf.int32)
            label_ids = tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32)
            print("here1:{}".format(input_ids.shape))

            yield input_ids, input_mask, segment_ids, label_ids

    def input_fn(self, params):
        """The actual input function."""
        d = tf.data.Dataset.from_generator(self.input_gen, output_types=(tf.int32, tf.int32, tf.int32, tf.int32))
        d = d.batch(batch_size=params["batch_size"], drop_remainder=False)
        iterator = d.make_one_shot_iterator()
        input_ids, input_mask, segment_ids, label_ids = iterator.get_next()
        return {'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids, 'label_ids': label_ids}

    def predict(self):
        result = self.estimator.predict(input_fn=self.input_fn)
        result = list(result)
        print(result)
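For reference, here is a minimal sketch of how the from_generator approach is usually wired up, assuming TF 1.x and the feature names from BERT's run_classifier.py. The snippet above has three problems: the Python generator yields tf.constant tensors (from_generator expects plain Python or NumPy values), it yields an already batched [num_examples, max_seq_length] array and then calls batch() again, and predict() calls list() on a generator that never ends. The idea is to create the estimator.predict generator once and pull one result with next() per request, so the graph is built only on the first call. `text_to_feature` is a hypothetical helper (for example, a wrapper around examples_to_features for a single text), and `feature_generator`, `make_input_fn` and `classify` are illustrative names, not part of the BERT repo.

```python
import queue


def feature_generator(requests, text_to_feature):
    """Yield one unbatched example per request as plain Python lists.

    requests: queue.Queue of input texts; get() blocks until a request arrives.
    text_to_feature: hypothetical helper returning (input_ids, input_mask,
        segment_ids), each a list of length max_seq_length, for one text.
    """
    while True:
        text = requests.get()  # wait for the next real-time request
        input_ids, input_mask, segment_ids = text_to_feature(text)
        yield {
            "input_ids": input_ids,
            "input_mask": input_mask,
            "segment_ids": segment_ids,
            "label_ids": 0,  # dummy label, ignored in PREDICT mode
        }


def make_input_fn(requests, text_to_feature, max_seq_length):
    """Build an Estimator input_fn backed by the request queue (TF 1.x assumed)."""

    def input_fn(params):
        import tensorflow as tf  # imported here so the helpers above run without TF

        names = ("input_ids", "input_mask", "segment_ids")
        types = {name: tf.int32 for name in names}
        types["label_ids"] = tf.int32
        shapes = {name: tf.TensorShape([max_seq_length]) for name in names}
        shapes["label_ids"] = tf.TensorShape([])
        d = tf.data.Dataset.from_generator(
            lambda: feature_generator(requests, text_to_feature),
            output_types=types,
            output_shapes=shapes,
        )
        # Batch size 1: a larger batch would wait until that many requests arrive.
        return d.batch(1)

    return input_fn


def classify(text, requests, results):
    """Push one text and pull exactly one prediction from a long-lived generator."""
    requests.put(text)
    return next(results)
```

Usage, assuming `estimator` is the fine-tuned Estimator: at startup, create `requests = queue.Queue()` and `results = estimator.predict(input_fn=make_input_fn(requests, text_to_feature, max_seq_length))` once, then call `classify(text, requests, results)` for each request. Only the first call pays the graph-building and checkpoint-loading cost. This sketch assumes one caller at a time; concurrent callers would need a lock around `classify`.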
ApexPredator1 commented 5 years ago

Well, after a day's hard work, I found the answer: https://hanxiao.github.io/2019/01/02/Serving-Google-BERT-in-Production-using-Tensorflow-and-ZeroMQ/

Thank you very much Dr. Han Xiao

sarnikowski commented 5 years ago

Here is a simple alternative: instead of using Estimator.predict, export your model and load it with predictor from tensorflow.contrib. This way you avoid loading the graph every time. It roughly works like this:

First, export your model after training. export_saved_model writes the SavedModel into a timestamped subdirectory of the directory you pass and returns the path of that subdirectory:

    export_dir = estimator.export_saved_model(model_dir, serving_input_receiver_fn)

Then, when you want to run inference, load the exported model once with predictor and call it as often as you need:

    from tensorflow.contrib import predictor
    predict_fn = predictor.from_saved_model(export_dir)
    result = predict_fn(...)

This should speed up your inference significantly.
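The key point is that the expensive step (loading the SavedModel) should happen once per process, not once per request. Below is a minimal sketch of that pattern, assuming `load_fn` is something like `lambda: predictor.from_saved_model(export_dir)`, where export_dir is the timestamped directory produced by the export. `ModelHolder` is a hypothetical name, and the lock only matters if several threads (for example, web server workers) share the model.

```python
import threading


class ModelHolder:
    """Load a model lazily on first use, then reuse it for every request."""

    def __init__(self, load_fn):
        self._load_fn = load_fn  # e.g. lambda: predictor.from_saved_model(export_dir)
        self._model = None
        self._lock = threading.Lock()

    def get(self):
        if self._model is None:
            with self._lock:
                if self._model is None:  # re-check so only one thread loads
                    self._model = self._load_fn()
        return self._model

    def predict(self, feed):
        # The loaded predictor is a callable that takes a dict of inputs.
        return self.get()(feed)
```

Creating one `ModelHolder` at startup and calling `holder.predict(feed)` per request keeps the slow load out of the request path after the first call.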

XuJianzhi commented 5 years ago

Hello, what is serving_input_receiver_fn? Something like this? https://github.com/bigboNed3/bert_serving

sarnikowski commented 5 years ago

@XuJianzhi I wrote an example on how to create this object: https://github.com/sarnikowski/bert_in_a_flask/blob/b981e1eca064ccac09e0c91406f45ff9517c4c35/src/utils/utils_bert.py#L47
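In short, a serving_input_receiver_fn tells export_saved_model which placeholders the exported model accepts and how they map to the features dict that model_fn expects. Below is a minimal sketch, assuming TF 1.x and the feature names used by BERT's run_classifier.py (input_ids, input_mask, segment_ids, label_ids). `make_serving_input_receiver_fn` and `make_feed` are illustrative names, not the code from the linked repository.

```python
def make_serving_input_receiver_fn(max_seq_length):
    """Return a serving_input_receiver_fn with one int32 placeholder per feature."""

    def serving_input_receiver_fn():
        import tensorflow as tf  # TF 1.x API (tf.placeholder) assumed

        features = {
            name: tf.placeholder(tf.int32, [None, max_seq_length], name=name)
            for name in ("input_ids", "input_mask", "segment_ids")
        }
        # run_classifier's model_fn also reads label_ids; any value works when predicting.
        features["label_ids"] = tf.placeholder(tf.int32, [None], name="label_ids")
        # Expose the placeholders directly as the exported signature's inputs.
        return tf.estimator.export.ServingInputReceiver(features, features)

    return serving_input_receiver_fn


def make_feed(input_ids, input_mask, segment_ids):
    """Wrap one example's lists into a batch-of-1 dict keyed by the input names."""
    return {
        "input_ids": [input_ids],
        "input_mask": [input_mask],
        "segment_ids": [segment_ids],
        "label_ids": [0],
    }
```

With a model exported this way, `predict_fn(make_feed(ids, mask, seg))` should return a dict keyed by whatever the model_fn puts in its predictions (in run_classifier.py, "probabilities").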

XuJianzhi commented 5 years ago

Thank you, I got it working. BUT tf.contrib.predictor.from_saved_model takes longer than the traditional estimator.predict! Is that expected, or did I do something wrong?

sarnikowski commented 5 years ago

I am not sure I understand your question. The following line loads the model (which is slow):

    predictor = tf.contrib.predictor.from_saved_model(somedir)

However, this line is used for inference, which is much faster:

    predictor(X)

Did you compare estimator.predict with predictor?

XuJianzhi commented 5 years ago

    estimator.export_savedmodel(export_dir_base=FLAGS.output_dir, serving_input_receiver_fn=serving_input_fn)
    predict_fn = tf.contrib.predictor.from_saved_model(os.path.join(FLAGS.output_dir, '1572514511'))

    ...

    estimator.predict(xxx)
    predict_fn(yyy)

I mean the average prediction time per sample: predict_fn(yyy) takes twice as long as estimator.predict(xxx). Is that right?
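When comparing the two, it helps to time the one-off loading separately from steady-state inference and to use the same batch size on both sides. estimator.predict over a whole input_fn pays its loading cost once per call and may run several samples per batch, while calling predict_fn once per sample runs batches of 1, so per-sample averages can be misleading, especially on CPU. Below is a minimal timing sketch; `benchmark`, `load` and `infer` are hypothetical names, with `load` standing in for something like `predictor.from_saved_model(...)` and `infer` for one call on one sample.

```python
import time


def benchmark(load, infer, samples, warmup=1):
    """Return (load_seconds, mean_seconds_per_sample) measured separately."""
    if not samples:
        raise ValueError("need at least one sample")
    t0 = time.perf_counter()
    model = load()  # one-off cost, e.g. restoring the SavedModel
    load_s = time.perf_counter() - t0

    for s in samples[:warmup]:
        infer(model, s)  # first calls can include lazy setup; keep them out of the average

    t0 = time.perf_counter()
    for s in samples:
        infer(model, s)
    per_sample_s = (time.perf_counter() - t0) / len(samples)
    return load_s, per_sample_s
```

Running the same harness with estimator.predict (creating the predict generator inside `load`) and with predict_fn, at equal batch sizes, gives numbers that can be compared directly.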

XuJianzhi commented 5 years ago

By the way, I am running on a CPU. Could that be the reason?

sayak1711 commented 4 years ago

@ApexPredator1 Can you please provide a code example of how you used bert-as-service for your classification task? I am currently using google-research's code, which uses TPUEstimator, and running it on CPU. It is slow for me. I would like to know whether bert-as-service can do the same thing faster.

ApexPredator1 commented 4 years ago

@sayak1711 I found the answer here: https://hanxiao.github.io/2019/01/02/Serving-Google-BERT-in-Production-using-Tensorflow-and-ZeroMQ/

But I did not use Dr. Han Xiao's bert-as-service project directly. Instead, I rewrote BERT's code based on the ideas in Dr. Han Xiao's code from the link above. For example, the following link is a simple demo of my real-time text sentiment analysis model. The model file is too large, so I didn't upload it, but you should be able to get the main idea from the code: https://github.com/ApexPredator1/temporary-project/blob/master/bert_sentiment_analysis.py

In addition, I think using Dr. Han Xiao's bert-as-service project directly is convenient.

jageshmaharjan commented 4 years ago

Maybe check this out: BERT Inference https://medium.com/analytics-vidhya/bert-rest-inference-from-the-fine-tuned-model-4b1f31151f97

sayak1711 commented 4 years ago

It is also worth converting your BERT model to ONNX and, if you have a GPU, using onnxruntime-gpu for inference.
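A rough sketch of the onnxruntime side follows. It assumes the model has already been converted to ONNX (one common route is tf2onnx) and that onnxruntime or onnxruntime-gpu is installed. The provider names are onnxruntime's standard ones. `pick_providers`, `make_session` and `predict` are illustrative helpers, and the keys in `feeds` must match the input names that `session.get_inputs()` reports for your converted model.

```python
def pick_providers(available):
    """Prefer the CUDA provider when present, falling back to CPU."""
    preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    chosen = [p for p in preferred if p in available]
    return chosen or list(available)


def make_session(model_path):
    """Create the InferenceSession once and reuse it for every request."""
    import onnxruntime as ort  # installed as onnxruntime or onnxruntime-gpu

    return ort.InferenceSession(
        model_path, providers=pick_providers(ort.get_available_providers())
    )


def predict(session, feeds):
    """Run one batch; `feeds` maps input names (see session.get_inputs()) to arrays."""
    return session.run(None, feeds)  # None means: return all model outputs
```

As with the TensorFlow approaches above, the session should be created once at startup; only `predict` belongs in the request path.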