Summary:
Add Triton kernel implementation to jagged_sum operator in TritonBench. This Triton kernel performs a sum along the ragged dimension of a nested tensor of logical dimensions (B, *, M), where * is the ragged dimension. It loads in blocks of the values tensor along its last dimension M, reduces each block of variable length along its first dimension *, and stores each of B reductions in an output tensor of shape (B, M).
This Triton kernel is benchmarked against two PyTorch implementations, one that does not pad blocks of variable length and one that does.
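A minimal NumPy sketch of the reduction described above, assuming the common values/offsets layout for jagged tensors (the `values` and `offsets` names and the exact layout are assumptions for illustration, not the operator's actual API): all rows of the (B, *, M) nested tensor are stacked into a dense `values` array of shape (total_rows, M), and `offsets[b]:offsets[b+1]` delimits batch b's variable-length slice along the ragged dimension.

```python
import numpy as np

def jagged_sum(values, offsets):
    # values: (total_rows, M) dense storage of the nested tensor's rows
    # offsets: (B + 1,) row boundaries; batch b owns rows offsets[b]:offsets[b+1]
    B = len(offsets) - 1
    M = values.shape[1]
    out = np.zeros((B, M), dtype=values.dtype)
    for b in range(B):
        # Reduce batch b's variable-length block along the ragged dimension;
        # the Triton kernel performs this reduction blockwise along M.
        out[b] = values[offsets[b]:offsets[b + 1]].sum(axis=0)
    return out

# Example: B=2 batches with ragged lengths 2 and 1, M=3 features
values = np.array([[1., 2., 3.],
                   [4., 5., 6.],
                   [7., 8., 9.]])
offsets = np.array([0, 2, 3])
print(jagged_sum(values, offsets))  # [[5. 7. 9.], [7. 8. 9.]]
```

The Triton kernel parallelizes this same computation, loading blocks of `values` along M and accumulating each batch's variable-length block before storing the (B, M) result.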
Reviewed By: davidberard98
Differential Revision: D58549297