Hi @mdmustafizurrahman,
I think you might not be profiling the code correctly. The bottleneck should be getting the predictions from the estimator, rather than the post-processing code.
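One subtlety worth checking (a rough sketch, reusing the loop from run_task_main.py): `estimator.predict` returns a lazy generator, so the model only actually runs when you iterate over `result`. If you time the post-processing loop directly, you are also timing the prediction itself. Splitting the two would look roughly like this:

```python
import time

result = estimator.predict(input_fn=predict_input_fn)

start = time.time()
predictions = list(result)  # forces the model to run; predict() itself is lazy
print("prediction took %.1fs" % (time.time() - start))

start = time.time()
for prediction in predictions:
    question_id = _get_question_id(prediction)
    max_width = prediction["column_ids"].max()
    max_height = prediction["row_ids"].max()
    # ... rest of the post-processing ...
print("post-processing took %.1fs" % (time.time() - start))
```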
I have created the code in such a way that the estimator is loaded only once, but for each query I call the estimator to predict. The queries arrive in an online setup, so every time a query comes in, I convert it to the TAPAS input format and feed it to the estimator. So it is not as if I am loading the estimator every time a new query arrives. I have reduced the execution time a lot by doing this, but on average it still takes 15 seconds. On top of that, for some queries I keep getting a "Sequence too long" error.
15s sounds about right if you are running prediction on CPU for a BERT-large model, plus TF example conversion, for a single example.
Hey, I have a question regarding the execution time again. Your run_task_main.py contains `estimator.predict(input_fn=predict_input_fn)`. My understanding is that `predict` reloads the graph weights every time I ask for a prediction on a single example [I am using batch size = 1 because the queries come in an online fashion], and that is why it takes almost 15 seconds per example. Am I correct?
Is there any way we can reuse the loaded graph for all the subsequent queries that come in the online setup?
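One common workaround (a sketch, not part of the TAPAS codebase; the feature names, dtypes, and shapes below are placeholders that would have to match what the real predict_input_fn yields): since `predict` returns a lazy generator, a single long-lived `predict` call can serve many queries if its `input_fn` reads from a queue. The graph and checkpoint are then loaded only once, on the first `next()`:

```python
import queue

import tensorflow.compat.v1 as tf

_request_queue = queue.Queue()

def _request_generator():
    # Blocks until a new query's features arrive; a None sentinel ends the stream.
    while True:
        features = _request_queue.get()
        if features is None:
            return
        yield features

def streaming_input_fn(params):
    del params  # unused in this sketch
    # Placeholder feature spec -- must match the model's real input features.
    output_types = {"input_ids": tf.int32, "input_mask": tf.int32}
    output_shapes = {"input_ids": tf.TensorShape([512]),
                     "input_mask": tf.TensorShape([512])}
    dataset = tf.data.Dataset.from_generator(
        _request_generator,
        output_types=output_types,
        output_shapes=output_shapes)
    return dataset.batch(1)

# Loads the graph and checkpoint once, lazily, on the first next() call.
prediction_stream = estimator.predict(input_fn=streaming_input_fn)

def predict_one(features):
    """Serves one online query without reloading the graph."""
    _request_queue.put(features)
    return next(prediction_stream)
```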
Hi,
I was trying to run your code with the SQA model. I ran the command without any fine-tuning (basically your notebook example):
```bash
python run_task_main.py \
  --task="SQA" \
  --output_dir="results" \
  --noloop_predict \
  --test_batch_size={len(queries)} \
  --tapas_verbosity="ERROR" \
  --compression_type= \
  --init_checkpoint="tapas_sqa_base/model.ckpt" \
  --bert_config_file="tapas_sqa_base/bert_config.json" \
  --mode="predict"
```
In your run_task_main.py there is a line `result = estimator.predict(input_fn=predict_input_fn)`, and then you do the following:
```python
for prediction in result:
    question_id = _get_question_id(prediction)
    max_width = prediction["column_ids"].max()
    max_height = prediction["row_ids"].max()
```
After analysing the code, I realized that this is post-processing of the `result` variable, and apparently it takes O(n^3) or more time to get the cell coordinates. Am I right? If so, it is not deployable for real-time online question answering, right? For one example, with 22 rows in my table, it took almost 100 seconds to get the cell predictions.
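For what it's worth, the coordinate extraction itself does not have to be cubic. A vectorized sketch (assuming, as in the loop above, that `row_ids` and `column_ids` are per-token arrays with 0 marking non-table tokens, and assuming a hypothetical per-token `probabilities` array of cell-selection scores) stays linear in the sequence length:

```python
import numpy as np

def answer_coordinates(prediction, threshold=0.5):
    row_ids = np.asarray(prediction["row_ids"])        # per-token row indices (0 = not a table token)
    column_ids = np.asarray(prediction["column_ids"])  # per-token column indices
    probs = np.asarray(prediction["probabilities"])    # assumed per-token selection scores
    # Keep tokens that belong to the table and are selected by the model.
    mask = (row_ids > 0) & (column_ids > 0) & (probs > threshold)
    # Unique (row, column) cell coordinates, converted to 0-based indexing.
    coords = {(r - 1, c - 1) for r, c in zip(row_ids[mask], column_ids[mask])}
    return sorted(coords)
```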