tenstorrent / tt-metal

:metal: TT-NN operator library and TT-Metalium low-level kernel programming model.

Improve e2e perf for Flash Decode #10496

Closed by caixunshiren 2 weeks ago

caixunshiren commented 1 month ago

Description

Flash decode has been observed to have slow e2e performance due to the complicated update-runtime-args function; this issue tracks its optimization. As shown below, updating the runtime args can take as much as 110 µs.

[profiler capture: runtime-arg update taking up to ~110 µs]
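For context, a likely source of this overhead is the per-core loop in the runtime-args update: on each invocation the cached program's arg vector is fetched and rewritten for every core, even though only a few entries (e.g. buffer addresses) actually change. Below is a minimal, self-contained Python sketch contrasting the two access patterns; every class and function name here is a hypothetical stand-in for illustration, not the actual tt-metal C++ API.

```python
# Toy model of a cached program's per-core runtime args.
# All names are hypothetical stand-ins, not the tt-metal C++ API.

class ToyProgram:
    def __init__(self, num_cores, args_per_core):
        # one arg vector per core, as a cached program would hold
        self.runtime_args = {c: [0] * args_per_core for c in range(num_cores)}

    def get_runtime_args(self, core):
        return list(self.runtime_args[core])  # copy, like a fresh fetch

    def set_runtime_args(self, core, args):
        self.runtime_args[core] = args


def update_all_args(program, num_cores, new_addrs):
    # Slow pattern: fetch and write back the full arg vector for every
    # core on every invocation, although only a few entries change.
    for core in range(num_cores):
        args = program.get_runtime_args(core)
        args[0], args[1], args[2] = new_addrs
        program.set_runtime_args(core, args)


def update_changed_args_only(program, num_cores, new_addrs):
    # Faster pattern: mutate only the entries that actually change,
    # touching the stored args in place with no copy or write-back.
    for core in range(num_cores):
        args = program.runtime_args[core]
        args[0], args[1], args[2] = new_addrs


prog = ToyProgram(num_cores=64, args_per_core=32)
update_all_args(prog, 64, (0x1000, 0x2000, 0x3000))
update_changed_args_only(prog, 64, (0x4000, 0x5000, 0x6000))
```

The per-core fetch/write-back cost scales with core count and args per core, which is why a complicated update function shows up as tens of microseconds per op in the e2e trace.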

caixunshiren commented 1 month ago

Mixtral perf before SDPA:

torch_embed_initial: 0.071ms
prepare_inputs_for_inference_0: 2.407ms
python_dispatch_for_inference_0: 45.553ms
model_run_for_inference_0: 83.593ms
result_wait_for_inference_0: 38.042ms
torch_argmax_and_embed_0: 0.222ms
deallocate_tt_tensors_0: 0.032ms
prepare_inputs_for_inference_1: 2.553ms
python_dispatch_for_inference_1: 45.710ms
model_run_for_inference_1: 71.337ms
result_wait_for_inference_1: 25.630ms
torch_argmax_and_embed_1: 0.223ms
deallocate_tt_tensors_1: 0.031ms
prepare_inputs_for_inference_2: 2.557ms
python_dispatch_for_inference_2: 45.610ms
model_run_for_inference_2: 71.604ms
result_wait_for_inference_2: 25.996ms
torch_argmax_and_embed_2: 0.224ms
deallocate_tt_tensors_2: 0.033ms
prepare_inputs_for_inference_3: 2.576ms
python_dispatch_for_inference_3: 59.795ms
model_run_for_inference_3: 81.485ms
result_wait_for_inference_3: 21.691ms
torch_argmax_and_embed_3: 0.223ms
deallocate_tt_tensors_3: 0.032ms
prepare_inputs_for_inference_4: 2.554ms
python_dispatch_for_inference_4: 45.552ms
model_run_for_inference_4: 74.875ms
result_wait_for_inference_4: 29.326ms
torch_argmax_and_embed_4: 0.225ms
deallocate_tt_tensors_4: 0.035ms
prepare_inputs_for_inference_5: 2.569ms
python_dispatch_for_inference_5: 45.639ms
model_run_for_inference_5: 71.680ms
result_wait_for_inference_5: 26.042ms
torch_argmax_and_embed_5: 0.223ms
deallocate_tt_tensors_5: 0.032ms

Mixtral perf after SDPA:

torch_embed_initial: 0.073ms
prepare_inputs_for_inference_0: 1.559ms
python_dispatch_for_inference_0: 54.657ms
model_run_for_inference_0: 86.290ms
result_wait_for_inference_0: 31.634ms
torch_argmax_and_embed_0: 0.248ms
deallocate_tt_tensors_0: 0.037ms
prepare_inputs_for_inference_1: 2.704ms
python_dispatch_for_inference_1: 50.970ms
model_run_for_inference_1: 79.253ms
result_wait_for_inference_1: 28.285ms
torch_argmax_and_embed_1: 0.221ms
deallocate_tt_tensors_1: 0.037ms
prepare_inputs_for_inference_2: 1.596ms
python_dispatch_for_inference_2: 50.678ms
model_run_for_inference_2: 73.656ms
result_wait_for_inference_2: 22.978ms
torch_argmax_and_embed_2: 0.221ms
deallocate_tt_tensors_2: 0.034ms
prepare_inputs_for_inference_3: 1.775ms
python_dispatch_for_inference_3: 55.270ms
model_run_for_inference_3: 76.287ms
result_wait_for_inference_3: 20.998ms
torch_argmax_and_embed_3: 0.222ms
deallocate_tt_tensors_3: 0.039ms
prepare_inputs_for_inference_4: 2.258ms
python_dispatch_for_inference_4: 50.253ms
model_run_for_inference_4: 73.139ms
result_wait_for_inference_4: 22.887ms
torch_argmax_and_embed_4: 0.221ms
deallocate_tt_tensors_4: 0.033ms
prepare_inputs_for_inference_5: 1.553ms
python_dispatch_for_inference_5: 49.995ms
model_run_for_inference_5: 73.352ms
result_wait_for_inference_5: 23.358ms
torch_argmax_and_embed_5: 0.222ms
deallocate_tt_tensors_5: 0.033ms

Mixtral perf after SDPA, without the get-runtime-args function:

torch_embed_initial: 0.134ms
prepare_inputs_for_inference_0: 1.929ms
python_dispatch_for_inference_0: 48.250ms
model_run_for_inference_0: 73.519ms
result_wait_for_inference_0: 25.268ms
torch_argmax_and_embed_0: 0.219ms
deallocate_tt_tensors_0: 0.040ms
prepare_inputs_for_inference_1: 1.905ms
python_dispatch_for_inference_1: 48.170ms
model_run_for_inference_1: 73.635ms
result_wait_for_inference_1: 25.464ms
torch_argmax_and_embed_1: 0.218ms
deallocate_tt_tensors_1: 0.041ms
prepare_inputs_for_inference_2: 1.941ms
python_dispatch_for_inference_2: 48.161ms
model_run_for_inference_2: 73.128ms
result_wait_for_inference_2: 24.966ms
torch_argmax_and_embed_2: 0.221ms
deallocate_tt_tensors_2: 0.039ms
prepare_inputs_for_inference_3: 1.925ms
python_dispatch_for_inference_3: 48.255ms
model_run_for_inference_3: 73.167ms
result_wait_for_inference_3: 24.909ms
torch_argmax_and_embed_3: 0.224ms
deallocate_tt_tensors_3: 0.033ms
prepare_inputs_for_inference_4: 1.693ms
python_dispatch_for_inference_4: 61.292ms
model_run_for_inference_4: 79.821ms
result_wait_for_inference_4: 18.529ms
torch_argmax_and_embed_4: 0.222ms
deallocate_tt_tensors_4: 0.031ms
prepare_inputs_for_inference_5: 1.676ms
python_dispatch_for_inference_5: 48.215ms
model_run_for_inference_5: 72.538ms
result_wait_for_inference_5: 24.324ms
torch_argmax_and_embed_5: 0.221ms
deallocate_tt_tensors_5: 0.031ms

Mixtral perf after SDPA, without the get-runtime-args and override-runtime-args functions:

prepare_inputs_for_inference_0: 1.936ms
python_dispatch_for_inference_0: 49.222ms
model_run_for_inference_0: 73.891ms
result_wait_for_inference_0: 24.668ms
torch_argmax_and_embed_0: 0.221ms
deallocate_tt_tensors_0: 0.041ms
prepare_inputs_for_inference_1: 1.883ms
python_dispatch_for_inference_1: 48.120ms
model_run_for_inference_1: 73.093ms
result_wait_for_inference_1: 24.972ms
torch_argmax_and_embed_1: 0.219ms
deallocate_tt_tensors_1: 0.039ms
prepare_inputs_for_inference_2: 1.916ms
python_dispatch_for_inference_2: 66.920ms
model_run_for_inference_2: 89.662ms
result_wait_for_inference_2: 22.740ms
torch_argmax_and_embed_2: 0.219ms
deallocate_tt_tensors_2: 0.042ms
prepare_inputs_for_inference_3: 1.894ms
python_dispatch_for_inference_3: 48.450ms
model_run_for_inference_3: 73.895ms
result_wait_for_inference_3: 25.440ms
torch_argmax_and_embed_3: 0.221ms
deallocate_tt_tensors_3: 0.039ms
prepare_inputs_for_inference_4: 1.939ms
python_dispatch_for_inference_4: 48.689ms
model_run_for_inference_4: 73.826ms
result_wait_for_inference_4: 25.136ms
torch_argmax_and_embed_4: 0.221ms
deallocate_tt_tensors_4: 0.043ms
prepare_inputs_for_inference_5: 1.899ms
python_dispatch_for_inference_5: 67.343ms
model_run_for_inference_5: 89.821ms
result_wait_for_inference_5: 22.477ms
torch_argmax_and_embed_5: 0.220ms
deallocate_tt_tensors_5: 0.044ms
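To make the four logs easier to compare, the sketch below averages the steady-state model_run_for_inference times (iterations 1 through 5, skipping the warm-up iteration 0) for each configuration. The numbers are copied verbatim from the logs above; run-to-run variance is clearly visible, so treat the averages as indicative only.

```python
from statistics import mean

# model_run_for_inference times (ms), iterations 1-5, from the logs above
runs = {
    "before sdpa":             [71.337, 71.604, 81.485, 74.875, 71.680],
    "after sdpa":              [79.253, 73.656, 76.287, 73.139, 73.352],
    "w/o get runtime args":    [73.635, 73.128, 73.167, 79.821, 72.538],
    "w/o get + override args": [73.093, 89.662, 73.895, 73.826, 89.821],
}

for name, times in runs.items():
    print(f"{name}: {mean(times):.3f} ms average over iterations 1-5")
```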