microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Performance issue with beam search in onnxruntime #12078

Open OriAlpha opened 2 years ago

OriAlpha commented 2 years ago

Describe the bug: I am trying the new beam search in onnxruntime, but the performance of the mT5 model is poor. The following setups were considered while converting the model.

System information

Expected behavior: I expected it to behave the same as the PyTorch model, but there is a big difference in the outputs.

Additional context: If you want to try the model, please use https://huggingface.co/google/mt5-large/tree/main (but note that mine is a fine-tuned model for my application).

@tianleiwu, maybe you can have a look here.

OriAlpha commented 2 years ago

Some more logs from when the model was converted:

Exporting ONNX model to ./baseline_large_encoder_decoder_init.onnx
batch_size=4 encode_sequence_length=11, max_diff=2.765655517578125e-05
batch_size=1 encode_sequence_length=2, max_diff=1.1444091796875e-05
batch_size=3 encode_sequence_length=1, max_diff=1.3113021850585938e-05
batch_size=8 encode_sequence_length=5, max_diff=2.5510787963867188e-05
PyTorch and OnnxRuntime results max difference = 2.765655517578125e-05
Exporting ONNX model to ./baseline_large_decoder.onnx
batch_size=4, encode_sequence_length=11, past_decode_sequence_length=3, max_diff=8.0108642578125e-05
batch_size=1, encode_sequence_length=2, past_decode_sequence_length=5, max_diff=0.00012254714965820312
batch_size=3, encode_sequence_length=1, past_decode_sequence_length=1, max_diff=9.524822235107422e-05
batch_size=8, encode_sequence_length=5, past_decode_sequence_length=2, max_diff=0.000392913818359375
PyTorch and OnnxRuntime results max difference = 0.000392913818359375
PyTorch and OnnxRuntime results are NOT close
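For each random input the log reports the largest absolute difference between the PyTorch and ONNX Runtime outputs, and the final verdict presumably comes from comparing that value against a tolerance. The sketch below illustrates that kind of check with NumPy. The helper name `compare_outputs` and the `atol=1e-4` default are hypothetical; the tolerance the export script actually uses is not shown in this thread.

```python
import numpy as np

def compare_outputs(torch_out, ort_out, atol=1e-4):
    """Return (max_diff, is_close) for two output arrays.

    Hypothetical helper: atol is an illustrative threshold,
    not the value used by the onnxruntime export script.
    """
    torch_out = np.asarray(torch_out, dtype=np.float32)
    ort_out = np.asarray(ort_out, dtype=np.float32)
    # Largest element-wise absolute difference, like the max_diff values logged above
    max_diff = float(np.max(np.abs(torch_out - ort_out)))
    return max_diff, max_diff <= atol

max_diff, is_close = compare_outputs([1.0, 2.0], [1.0, 2.0004])
print(max_diff, is_close)
```

With a threshold like this, a decoder difference of about 3.9e-4 is reported as "not close" even though it may be harmless for the final generated text.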
tianleiwu commented 2 years ago

@OriAlpha, could you use a set of inputs to evaluate the accuracy (for classification tasks) or other metrics for your task?

In the above example, an output difference of 4e-4 is probably acceptable for some scenarios. That means the outputs might not be exactly the same, but the results could be very close in many cases.
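For a generation task, one simple way to follow this suggestion is to run the same set of inputs through both models and measure how often the decoded outputs agree. A minimal sketch, assuming you already have one decoded string per input from PyTorch and from ONNX Runtime; `exact_match_rate` is a hypothetical helper, and a task-specific metric could replace it.

```python
def exact_match_rate(reference_outputs, candidate_outputs):
    """Fraction of inputs whose decoded outputs match exactly.

    Surrounding whitespace is ignored. Hypothetical helper for illustration.
    """
    if len(reference_outputs) != len(candidate_outputs):
        raise ValueError("both lists must have one output per input")
    if not reference_outputs:
        raise ValueError("need at least one input to evaluate")
    matches = sum(
        ref.strip() == cand.strip()
        for ref, cand in zip(reference_outputs, candidate_outputs)
    )
    return matches / len(reference_outputs)

torch_texts = ["a tiger", "a lion"]
ort_texts = ["a tiger ", "a leopard"]
print(exact_match_rate(torch_texts, ort_texts))
```

An agreement rate close to 1.0 over many inputs would support the view that small numeric differences are acceptable; a low rate points to a real divergence.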

OriAlpha commented 2 years ago

I actually did some evaluation, but I noticed one issue: when I pass a sentence to the model, it does not paraphrase the complete sentence, only about half of it. This is why I was seeing the performance issue. But maybe I am passing the inputs the wrong way. I was passing the input as a list of strings, i.e.:

import numpy as np

inputs_sent = [sentence1, sentence2, sentence3]
ort_decoded_sequences = []
for sentence in inputs_sent:
    # Tokenize one sentence at a time, padded/truncated to max_length
    encoded = tokenizer(sentence, max_length=max_length, padding='max_length', return_tensors="pt", truncation=True)
    input_ids = encoded["input_ids"]

    ort_inputs = {
        "input_ids": input_ids.cpu().numpy().astype(np.int32),
        "max_length": np.array([max_length], dtype=np.int32),
        "min_length": np.array([min_length], dtype=np.int32),
        "num_beams": np.array([num_beams], dtype=np.int32),
        "num_return_sequences": np.array([num_return_sequences], dtype=np.int32),
        "length_penalty": np.array([length_penalty], dtype=np.float32),
        "repetition_penalty": np.array([repetition_penalty], dtype=np.float32),
        "vocab_mask": vocab_mask,
    }
    result = ort_session.run(None, ort_inputs)
    # First output: generated token ids, shape (batch_size, num_return_sequences, max_length)
    sequences = result[0]
    (batch_size, num_sequences, _) = sequences.shape
    for i in range(batch_size):
        for j in range(num_sequences):
            decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
            ort_decoded_sequences.append(decoded_sequence)
            print(decoded_sequence)
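The same feed can be assembled by a small helper, which also makes it easy to pass a whole batch of tokenized sentences in one call instead of looping, assuming the exported model accepts a batch dimension on `input_ids` as the (batch_size, num_sequences, max_length) output shape suggests. A minimal sketch: `build_beam_search_feed` is hypothetical, the input names and dtypes mirror the dictionary above, and treating `vocab_mask` as an optional int32 array is an assumption.

```python
import numpy as np

def build_beam_search_feed(input_ids, max_length, min_length, num_beams,
                           num_return_sequences, length_penalty=1.0,
                           repetition_penalty=1.0, vocab_mask=None):
    """Build an ONNX Runtime input feed for a beam search model (hypothetical helper)."""
    input_ids = np.asarray(input_ids, dtype=np.int32)
    if input_ids.ndim != 2:
        raise ValueError("input_ids must have shape (batch_size, sequence_length)")
    if num_return_sequences > num_beams:
        raise ValueError("num_return_sequences cannot exceed num_beams")
    feed = {
        "input_ids": input_ids,
        # Scalar parameters are passed as 1-element arrays, as in the snippet above
        "max_length": np.array([max_length], dtype=np.int32),
        "min_length": np.array([min_length], dtype=np.int32),
        "num_beams": np.array([num_beams], dtype=np.int32),
        "num_return_sequences": np.array([num_return_sequences], dtype=np.int32),
        "length_penalty": np.array([length_penalty], dtype=np.float32),
        "repetition_penalty": np.array([repetition_penalty], dtype=np.float32),
    }
    if vocab_mask is not None:
        # Assumption: vocab_mask is an int32 array with one entry per vocabulary id
        feed["vocab_mask"] = np.asarray(vocab_mask, dtype=np.int32)
    return feed

feed = build_beam_search_feed([[259, 1], [260, 1]], max_length=50,
                              min_length=1, num_beams=4, num_return_sequences=1)
print({name: value.shape for name, value in feed.items()})
```

With a real tokenizer, `tokenizer(inputs_sent, ..., return_tensors="np")["input_ids"]` would supply the 2-D batch in one step.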
OriAlpha commented 2 years ago

So far, with my inputs, I see:

Torch and ORT result is  different
ORT {'test_times': 1, 'latency_variance': '0.00', 'latency_90_percentile': '5406.18', 'latency_95_percentile': '5406.18', 'latency_99_percentile': '5406.18', 'average_latency_ms': '5406.18', 'QPS': '0.18', 'parity': False}
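With `test_times` equal to 1, the variance is zero and every percentile equals the single measured latency, and the reported QPS matches 1000 / 5406.18 ≈ 0.18 runs per second. The sketch below reproduces these fields from a list of per-run latencies in seconds. It is a hypothetical helper, not the benchmark script's own code; the QPS formula is inferred from the reported numbers.

```python
import numpy as np

def summarize_latency(latencies_s):
    """Summarize per-run latencies (seconds) into the fields printed above.

    Hypothetical helper; QPS is taken as 1000 / average latency in ms,
    which matches the reported values for a single run.
    """
    latency_ms = np.asarray(latencies_s, dtype=np.float64) * 1000.0
    average_ms = float(np.mean(latency_ms))
    return {
        "test_times": len(latency_ms),
        "latency_variance": f"{np.var(latency_ms):.2f}",
        "latency_90_percentile": f"{np.percentile(latency_ms, 90):.2f}",
        "latency_95_percentile": f"{np.percentile(latency_ms, 95):.2f}",
        "latency_99_percentile": f"{np.percentile(latency_ms, 99):.2f}",
        "average_latency_ms": f"{average_ms:.2f}",
        "QPS": f"{1000.0 / average_ms:.2f}",
    }

print(summarize_latency([5.40618]))
```

A single run gives no information about variance; repeating the test several times would make the percentiles meaningful.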

Are there any parameters I could change before running the tests again? @tianleiwu

OriAlpha commented 2 years ago
start testing model...
--------------------------------------------------
Test PyTorch model and beam search with huggingface transformers...
input_ids tensor([[   259,  33644,    287,    259,  13439,    628,    807,  49285,    261,
            259,  59569,  24901,    263,    783,  38186,    344,    259,   6097,
            259, 120629,    304,    259,   1616,  55094,  12842,    305,    783,
           2101,   2121, 214630,   7670,    702,  17358,    305,   8253,   9610,
            261,    287,    259, 189663,    304,  28816,    305,  18623,    261,
            305,    281,   8057,    259,  44326,    304,    259, 129447,    305,
           6174,   9610,    305,   3182,    260,      1]])
huggingface transformers outputs:
sequences tensor([[     0, 250099,    260,    259,  33644,    287,    259,  13439,    628,
            807,  49285,    261,    259,  59569,  24901,    263,    783,  38186,
            259, 120629,    304,    259,   1616,  55094,  12842, 250098,    260,
            259,  33644,    287,    259,  13439,    628,    807,  49285,    261,
            259,  59569,  24901,    263,    783,  38186,    259, 120629,    304,
            259,   1616,  55094,  12842, 250097]])
0: <extra_id_0>. Since the early 20th century, tiger populations have lost 93% of their historic range <extra_id_1>. Since the early 20th century, tiger populations have lost 93% of their historic range <extra_id_2>
--------------------------------------------------
Testing beam search with onnxruntime...
ORT outputs:
sequences [[[     0 250099    783   2101   2121 175663    345    702    259 129447
      305   8253   9610    305   3182    260    259  33644    287    259
    13439    628    807  49285    261    259  59569  24901    263    783
    38186   1388    259 120629    304    259   1616  55094  12842 250098
      259  59569  24901    263    783   2101   2121 214630   7670    702]]]
batch 0 sequence 0: <extra_id_0> have been exterminated from Southeast and Central Asia and China. Since the early 20th century, tiger populations have lost about 93% of their historic range <extra_id_1> tiger populations have been extirpated from
--------------------------------------------------
Torch Sequences:
tensor([[[     0, 250099,    260,    259,  33644,    287,    259,  13439,
             628,    807,  49285,    261,    259,  59569,  24901,    263,
             783,  38186,    259, 120629,    304,    259,   1616,  55094,
           12842, 250098,    260,    259,  33644,    287,    259,  13439,
             628,    807,  49285,    261,    259,  59569,  24901,    263,
             783,  38186,    259, 120629,    304,    259,   1616,  55094,
           12842, 250097]]])
['<extra_id_0>. Since the early 20th century, tiger populations have lost 93% of their historic range <extra_id_1>. Since the early 20th century, tiger populations have lost 93% of their historic range <extra_id_2>']
--------------------------------------------------
ORT Sequences:
tensor([[[     0, 250099,    783,   2101,   2121, 175663,    345,    702,
             259, 129447,    305,   8253,   9610,    305,   3182,    260,
             259,  33644,    287,    259,  13439,    628,    807,  49285,
             261,    259,  59569,  24901,    263,    783,  38186,   1388,
             259, 120629,    304,    259,   1616,  55094,  12842, 250098,
             259,  59569,  24901,    263,    783,   2101,   2121, 214630,
            7670,    702]]])
['<extra_id_0> have been exterminated from Southeast and Central Asia and China. Since the early 20th century, tiger populations have lost about 93% of their historic range <extra_id_1> tiger populations have been extirpated from']
--------------------------------------------------
Torch and ORT result is  different
ORT {'test_times': 1, 'latency_variance': '0.00', 'latency_90_percentile': '7002.78', 'latency_95_percentile': '7002.78', 'latency_99_percentile': '7002.78', 'average_latency_ms': '7002.78', 'QPS': '0.14', 'parity': False}
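Since `parity` is False, a quick way to see where the two beam searches part ways is to find the first position at which the generated token sequences differ. A minimal sketch with a hypothetical `first_divergence` helper, applied to the first ten tokens of the Torch and ORT sequences logged above:

```python
def first_divergence(a, b):
    """Index of the first differing token, or None if the sequences are identical."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    if len(a) != len(b):
        # One sequence is a prefix of the other
        return min(len(a), len(b))
    return None

torch_tokens = [0, 250099, 260, 259, 33644, 287, 259, 13439, 628, 807]
ort_tokens = [0, 250099, 783, 2101, 2121, 175663, 345, 702, 259, 129447]
print(first_divergence(torch_tokens, ort_tokens))  # 2
```

Both sequences share the start token 0 and `<extra_id_0>` (250099), then diverge at index 2, where PyTorch emits 260 and ORT emits 783. Everything after that point follows a different beam, which is why the decoded texts differ so much.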

I have done more testing, and the paraphrases don't look good. Original sentence:

"Since the early 20th century, tiger populations have lost at least 93% of their historic range and have been extirpated from Western and Central Asia, the islands of Java and Bali, and in large areas of Southeast and South Asia and China."

I used the mT5-large model for testing.