pytorch / benchmark

TorchBench is a collection of open source benchmarks used to evaluate PyTorch performance.
BSD 3-Clause "New" or "Revised" License

Add jax pallas kernel examples for tritonbench #2328

Open xuzhao9 opened 3 months ago


Add the flash attention example from JAX (https://github.com/google/jax/blob/main/jax/experimental/mosaic/gpu/examples/flash_attention.py) to Tritonbench.
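For context on what the linked kernel computes, here is a minimal pure-JAX sketch of the online-softmax recurrence that flash attention is built on. It is checked against a naive softmax attention, which a Tritonbench operator would presumably also need as a correctness baseline. This is not the Mosaic GPU code from the linked file. The function names `reference_attention` and `blockwise_attention` and the `block_k` parameter are hypothetical. The sketch also assumes a single head, no causal masking, and a key length divisible by the block size.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    # Hypothetical baseline: softmax(q k^T / sqrt(d)) v for one head, shapes [seq, dim].
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale
    return jax.nn.softmax(scores, axis=-1) @ v


def blockwise_attention(q, k, v, block_k=64):
    # Hypothetical sketch: iterate over key/value blocks, keeping a running
    # row max (m), softmax denominator (l), and unnormalized output (acc),
    # so the full [seq_q, seq_k] score matrix is never materialized at once.
    seq_k, d = k.shape
    assert seq_k % block_k == 0, "sketch assumes seq_k is a multiple of block_k"
    scale = 1.0 / jnp.sqrt(d)
    m = jnp.full((q.shape[0],), -jnp.inf)
    l = jnp.zeros((q.shape[0],))
    acc = jnp.zeros((q.shape[0], v.shape[-1]))
    for start in range(0, seq_k, block_k):
        kb = k[start:start + block_k]
        vb = v[start:start + block_k]
        s = (q @ kb.T) * scale
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])
        # Rescale previous partial sums to the new running max.
        correction = jnp.exp(m - m_new)
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]
```

The two functions should agree up to floating-point error. A real kernel would fuse the loop body into on-chip tiles rather than run it as a Python loop.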