pytorch / benchmark

TorchBench is a collection of open source benchmarks used to evaluate PyTorch performance.
BSD 3-Clause "New" or "Revised" License

Add Triton kernels with fast variable-length loop to jagged_sum benchmark #2341

Closed · jananisriram closed this 5 days ago

jananisriram commented 5 days ago

Summary: Add a set of Triton kernels that implement a fast, variable-length loop on top of the simple fused Triton kernels created in D58549297. This diff loops from the beginning to the end of each variable-length subsection of the input nested tensor, eliminating the extra work the simple fused kernel does by looping over the entire range of the maximum sequence length. Specifically, the loop terminates at the nested tensor's jagged length instead of reading, reducing, and writing extra zeros beyond it. This diff also contains sum_then_buffer and buffer_then_sum implementations, as in the simple fused implementation from D58549297.
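For illustration, a minimal sketch of what such a variable-length loop can look like in Triton is below. This is not the kernel from the diff; all names, shapes, and block sizes are hypothetical, and it assumes a Triton version that supports data-dependent loop bounds. Each program sums one jagged row of a `(values, offsets)` layout, and the loop bound comes from the offsets array, so the kernel never touches the zero padding between a row's true length and the maximum sequence length.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def jagged_sum_kernel(values_ptr, offsets_ptr, out_ptr, M,
                      BLOCK_L: tl.constexpr, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)                    # one program per jagged row
    row_start = tl.load(offsets_ptr + pid)    # first flattened index of this row
    row_end = tl.load(offsets_ptr + pid + 1)  # one past the last index

    m_offsets = tl.arange(0, BLOCK_M)
    m_mask = m_offsets < M
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Variable-length loop: the bound is this row's true end, not max_seqlen,
    # so no padding is ever read, reduced, or written.
    for l in range(row_start, row_end, BLOCK_L):
        l_offsets = l + tl.arange(0, BLOCK_L)
        l_mask = l_offsets < row_end
        block = tl.load(
            values_ptr + l_offsets[:, None] * M + m_offsets[None, :],
            mask=l_mask[:, None] & m_mask[None, :], other=0.0,
        )
        acc += tl.sum(block, axis=0)          # reduce each block as it is loaded

    tl.store(out_ptr + pid * M + m_offsets, acc, mask=m_mask)


# Usage sketch: offsets [0, 3, 7] describe two rows of lengths 3 and 4.
values = torch.randn(7, 8, device="cuda")
offsets = torch.tensor([0, 3, 7], device="cuda", dtype=torch.int32)
out = torch.empty(2, 8, device="cuda")
jagged_sum_kernel[(2,)](values, offsets, out, 8, BLOCK_L=32, BLOCK_M=8)
```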

This diff draws from a similar binary search implementation found here.
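The linked implementation is not reproduced here, but the general idea of a binary search in this setting is easy to sketch: given a flattened value index, find the jagged row whose offset range contains it. A minimal Python illustration (the helper name is ours, not the diff's):

```python
import bisect

def row_of_index(offsets: list[int], i: int) -> int:
    # offsets = [0, 3, 7, 12] describes rows of lengths 3, 4, 5.
    # bisect_right finds the row r whose range [offsets[r], offsets[r+1])
    # contains flattened index i.
    return bisect.bisect_right(offsets, i) - 1

assert row_of_index([0, 3, 7, 12], 0) == 0
assert row_of_index([0, 3, 7, 12], 3) == 1
assert row_of_index([0, 3, 7, 12], 11) == 2
```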

The sum_then_buffer parameter defaults to 0; as shown in the Test Plan, the buffer_then_sum implementation outperforms the sum_then_buffer implementation.
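The two variants differ only in the order of accumulation and reduction. A rough PyTorch sketch of the distinction (not the Triton code from the diff):

```python
import torch

def sum_then_buffer(blocks):
    # Reduce each (BLOCK_L, M) block immediately, then add the (M,) partial
    # sum into a small running buffer.
    acc = None
    for b in blocks:
        partial = b.sum(dim=0)
        acc = partial if acc is None else acc + partial
    return acc

def buffer_then_sum(blocks):
    # Accumulate raw (BLOCK_L, M) blocks elementwise into a larger buffer,
    # then perform a single reduction at the end.
    buf = None
    for b in blocks:
        buf = b if buf is None else buf + b
    return buf.sum(dim=0)
```

Roughly, buffer_then_sum trades a larger accumulator for a single reduction at the end, while sum_then_buffer reduces on every iteration; which wins depends on block sizes and hardware, which is presumably why both variants are benchmarked.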

Note that the benchmark selects nested tensor inputs whose padded versions do not exceed an 8 GB size limit, since the PyTorch benchmarks are constrained by the size of the padded, rather than the unpadded, nested tensors.
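As a concrete illustration of that constraint, the padded form materializes every row at the maximum sequence length, so its size can be estimated independently of the actual jagged lengths. A hypothetical helper (not from the benchmark):

```python
GB = 1024 ** 3

def padded_size_bytes(batch_size: int, max_seqlen: int,
                      embed_dim: int, dtype_bytes: int = 4) -> int:
    # Padding expands every row to max_seqlen, regardless of its true length.
    return batch_size * max_seqlen * embed_dim * dtype_bytes

def fits_in_limit(batch_size: int, max_seqlen: int, embed_dim: int,
                  limit_bytes: int = 8 * GB) -> bool:
    return padded_size_bytes(batch_size, max_seqlen, embed_dim) < limit_bytes
```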

Reviewed By: davidberard98

Differential Revision: D59026460

facebook-github-bot commented 5 days ago

This pull request was exported from Phabricator. Differential Revision: D59026460

facebook-github-bot commented 5 days ago

This pull request has been merged in pytorch/benchmark@ec0909d8041945ad5d78412fc901dc7694aa54cc.