xuzhao9 closed this pull request 1 week ago.
Add the following features to ragged_hstu:
Test plan:
$ python run.py --op ragged_attention --metrics latency,tflops --mode bwd
x_val                hstu_triton_ragged_attention-tflops    hstu_triton_ragged_attention-latency
-----------------  -------------------------------------  --------------------------------------
(8, 4, 512, 2048)                               0.306747                                 2.81939
(8, 4, 512, 2048)                                1.65614                                0.867936
(8, 4, 512, 2048)                                2.00125                                 0.84768
(8, 4, 512, 2048)                                2.13756                                0.991968
(8, 4, 512, 2048)                                1.96315                                0.902976
(8, 4, 512, 2048)                                1.50214                                0.836192
(8, 4, 512, 2048)                                1.34825                                0.859936
(8, 4, 512, 2048)                                1.90546                                 0.97408
(8, 4, 512, 2048)                                1.72114                                0.902368
(8, 4, 512, 2048)                                2.30999                                 1.01107
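A reader comparing runs may want to condense the test-plan output above into a few numbers. The sketch below does that with only the standard library. It assumes (not stated in this PR) that latency is reported in milliseconds and that the tflops column is FLOPs divided by latency; under those assumptions, tflops × latency_ms gives the implied work per call in GFLOP. The `ROWS` list and `summarize` helper are hypothetical names for illustration, not part of tritonbench.

```python
import statistics

# Rows copied from the test-plan output: (x_val, tflops, latency).
# Assumption: latency is in milliseconds and tflops = FLOPs / latency.
ROWS = [
    ((8, 4, 512, 2048), 0.306747, 2.81939),
    ((8, 4, 512, 2048), 1.65614, 0.867936),
    ((8, 4, 512, 2048), 2.00125, 0.84768),
    ((8, 4, 512, 2048), 2.13756, 0.991968),
    ((8, 4, 512, 2048), 1.96315, 0.902976),
    ((8, 4, 512, 2048), 1.50214, 0.836192),
    ((8, 4, 512, 2048), 1.34825, 0.859936),
    ((8, 4, 512, 2048), 1.90546, 0.97408),
    ((8, 4, 512, 2048), 1.72114, 0.902368),
    ((8, 4, 512, 2048), 2.30999, 1.01107),
]


def summarize(rows):
    """Return median tflops, median latency, and implied GFLOP per row (hypothetical helper)."""
    tflops = [r[1] for r in rows]
    latency_ms = [r[2] for r in rows]
    # TFLOP/s * ms = 1e12 * 1e-3 FLOPs = 1e9 FLOPs, i.e. GFLOP per call.
    implied_gflop = [t * l for t, l in zip(tflops, latency_ms)]
    return {
        "median_tflops": statistics.median(tflops),
        "median_latency_ms": statistics.median(latency_ms),
        "implied_gflop": implied_gflop,
    }


if __name__ == "__main__":
    s = summarize(ROWS)
    print(f"median tflops: {s['median_tflops']:.4f}")
    print(f"median latency (ms): {s['median_latency_ms']:.4f}")
    for i, g in enumerate(s["implied_gflop"]):
        print(f"row {i}: ~{g:.3f} GFLOP")
```

The median is used rather than the mean because the first row's much higher latency would skew an average. The implied GFLOP values also differ from row to row even though every row reports the same x_val, so the rows should not be read as repeated measurements of an identical workload.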
@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@xuzhao9 merged this pull request in pytorch-labs/tritonbench@c2ef66a1f568ef9b777d202f0e1058ed91ea9d0a.