pytorch-labs / tritonbench

Tritonbench is a collection of PyTorch custom operators with example inputs to measure their performance.
BSD 3-Clause "New" or "Revised" License

Update HSTU and use the OSS wrapper for non-persistent kernels #53

Closed. xuzhao9 closed this pull request 1 week ago.

xuzhao9 commented 1 week ago

Add the following features to ragged_hstu:

  1. Add the tflops metric
  2. Use _RaggedAttentionRelativeBiasFunction to wrap the Triton kernel (see the sketch after this list)
  3. Add the backward pass
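
A minimal sketch of the intent behind items 2 and 3, not the actual tritonbench or generative-recommenders code: the OSS `_RaggedAttentionRelativeBiasFunction` is a `torch.autograd.Function`, and wrapping the Triton kernel in such a Function is what makes `--mode bwd` benchmarkable, since autograd can then drive the backward kernel. The toy class and shapes below are illustrative stand-ins only.

```python
import torch

class RaggedAttentionWrapper(torch.autograd.Function):
    """Toy autograd.Function showing the wrapping pattern; the real operator
    launches the non-persistent Triton forward/backward kernels here."""

    @staticmethod
    def forward(ctx, q, k, v):
        # Stand-in for the Triton forward kernel: plain attention math.
        attn = torch.softmax(q @ k.transpose(-1, -2), dim=-1)
        out = attn @ v
        ctx.save_for_backward(q, k, v, attn)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Stand-in for the Triton backward kernel: reference gradients.
        q, k, v, attn = ctx.saved_tensors
        grad_v = attn.transpose(-1, -2) @ grad_out
        grad_attn = grad_out @ v.transpose(-1, -2)
        # softmax backward
        grad_scores = attn * (grad_attn - (grad_attn * attn).sum(-1, keepdim=True))
        grad_q = grad_scores @ k
        grad_k = grad_scores.transpose(-1, -2) @ q
        return grad_q, grad_k, grad_v

# Usage mirroring the benchmark's fwd and bwd modes (illustrative sizes):
q, k, v = (torch.randn(4, 128, 64, requires_grad=True) for _ in range(3))
out = RaggedAttentionWrapper.apply(q, k, v)
out.backward(torch.ones_like(out))
```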

Test plan:

$ python run.py --op ragged_attention --metrics latency,tflops --mode bwd

            x_val    hstu_triton_ragged_attention-tflops    hstu_triton_ragged_attention-latency
-----------------  -------------------------------------  --------------------------------------
(8, 4, 512, 2048)                               0.306747                                2.81939
(8, 4, 512, 2048)                               1.65614                                 0.867936
(8, 4, 512, 2048)                               2.00125                                 0.84768
(8, 4, 512, 2048)                               2.13756                                 0.991968
(8, 4, 512, 2048)                               1.96315                                 0.902976
(8, 4, 512, 2048)                               1.50214                                 0.836192
(8, 4, 512, 2048)                               1.34825                                 0.859936
(8, 4, 512, 2048)                               1.90546                                 0.97408
(8, 4, 512, 2048)                               1.72114                                 0.902368
(8, 4, 512, 2048)                               2.30999                                 1.01107
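
For context on the tflops column, a hedged sketch of how such a metric is commonly derived: estimated FLOPs divided by measured latency. This is not the exact formula tritonbench uses for ragged_attention (ragged sequence lengths and the relative-bias terms change the FLOP count); the 2.5x backward multiplier is a common convention, not a statement about this operator, and the argument names below are assumptions.

```python
def attention_tflops(batch: int, heads: int, seq_len: int, head_dim: int,
                     latency_ms: float, backward: bool = False) -> float:
    # Two matmuls per head in the forward pass (QK^T and P @ V),
    # each roughly 2 * seq_len^2 * head_dim FLOPs.
    flops = 4.0 * batch * heads * seq_len * seq_len * head_dim
    if backward:
        # Backward is commonly approximated as ~2.5x the forward FLOPs.
        flops *= 2.5
    return flops / (latency_ms * 1e-3) / 1e12

# Example usage (arbitrary illustrative sizes, not the benchmark's x_val):
print(attention_tflops(batch=8, heads=4, seq_len=512, head_dim=64,
                       latency_ms=0.9, backward=True))
```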
facebook-github-bot commented 1 week ago

@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 1 week ago

@xuzhao9 merged this pull request in pytorch-labs/tritonbench@c2ef66a1f568ef9b777d202f0e1058ed91ea9d0a.