pytorch / benchmark

TorchBench is a collection of open source benchmarks used to evaluate PyTorch performance.

Add simple fused Triton kernel for jagged_sum operator #2322

Closed jananisriram closed 1 week ago

jananisriram commented 1 week ago

Summary: Add a Triton kernel implementation to the jagged_sum operator in TritonBench. This Triton kernel performs a sum along the ragged dimension of a nested tensor of logical shape (B, *, M), where * is the ragged dimension. It loads blocks of the values tensor along the last dimension M, reduces each variable-length block along the ragged dimension *, and stores each of the B reductions in an output tensor of shape (B, M). A sketch of this approach appears below.
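A minimal sketch of the technique described above, not the exact kernel landed in this PR: it assumes a contiguous values tensor of shape (total_rows, M) and an offsets tensor of shape (B + 1,) delimiting each batch's rows, and uses hypothetical names such as `triton_jagged_sum_kernel` and `jagged_sum`. Block sizes and the grid layout here are illustrative only.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def triton_jagged_sum_kernel(
    values_ptr,              # (total_rows, M) contiguous values of the nested tensor
    offsets_ptr,             # (B + 1,) row offsets delimiting each ragged batch
    out_ptr,                 # (B, M) output buffer
    M,                       # size of the trailing (non-ragged) dimension
    BLOCK_M: tl.constexpr,   # columns handled per program
    BLOCK_R: tl.constexpr,   # ragged rows loaded per loop iteration
):
    pid_b = tl.program_id(0)  # which batch
    pid_m = tl.program_id(1)  # which block of columns

    row_start = tl.load(offsets_ptr + pid_b)
    row_end = tl.load(offsets_ptr + pid_b + 1)

    col_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    col_mask = col_offsets < M

    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Walk this batch's variable-length ragged extent in fixed-size row blocks,
    # masking out rows that fall past the end of the batch.
    for r in range(row_start, row_end, BLOCK_R):
        row_offsets = r + tl.arange(0, BLOCK_R)
        row_mask = row_offsets < row_end
        ptrs = values_ptr + row_offsets[:, None] * M + col_offsets[None, :]
        block = tl.load(ptrs, mask=row_mask[:, None] & col_mask[None, :], other=0.0)
        acc += tl.sum(block, axis=0)

    tl.store(out_ptr + pid_b * M + col_offsets, acc, mask=col_mask)


def jagged_sum(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Hypothetical host-side wrapper: one program per (batch, column-block) pair.
    B = offsets.numel() - 1
    M = values.shape[1]
    out = torch.empty((B, M), device=values.device, dtype=torch.float32)
    BLOCK_M, BLOCK_R = 64, 32
    grid = (B, triton.cdiv(M, BLOCK_M))
    triton_jagged_sum_kernel[grid](values, offsets, out, M, BLOCK_M=BLOCK_M, BLOCK_R=BLOCK_R)
    return out
```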

This Triton kernel is benchmarked against two PyTorch implementations: one that does not pad the variable-length blocks and one that does.
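A rough sketch of what these two kinds of PyTorch baselines can look like; the actual TritonBench implementations may be structured differently. It assumes a nested tensor `nt` of logical shape (B, *, M), and the function names are hypothetical.

```python
import torch


def jagged_sum_unpadded(nt: torch.Tensor) -> torch.Tensor:
    # Reduce each variable-length component separately; no padding is materialized.
    return torch.stack([component.sum(dim=0) for component in nt.unbind()])


def jagged_sum_padded(nt: torch.Tensor) -> torch.Tensor:
    # Zero-pad the ragged dimension to the max length, then reduce it;
    # the zero padding does not change the sum.
    return torch.nested.to_padded_tensor(nt, 0.0).sum(dim=1)


# Example input: a nested tensor with ragged row counts per batch.
components = [torch.randn(torch.randint(1, 10, ()).item(), 8) for _ in range(4)]
nt = torch.nested.nested_tensor(components)
assert torch.allclose(jagged_sum_unpadded(nt), jagged_sum_padded(nt))
```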

Reviewed By: davidberard98

Differential Revision: D58549297

facebook-github-bot commented 1 week ago

This pull request was exported from Phabricator. Differential Revision: D58549297

facebook-github-bot commented 1 week ago

This pull request has been merged in pytorch/benchmark@a3af77282f8be6e7271afa4ce9adbd8065f3627c.