pytorch / benchmark

TorchBench is a collection of open source benchmarks used to evaluate PyTorch performance.
BSD 3-Clause "New" or "Revised" License

Add jax installation option to tritonbench #2331

Closed. xuzhao9 closed this pull request 1 week ago.

xuzhao9 commented 1 week ago

We are interested in adding a few JAX Pallas operator implementations to tritonbench, starting with flash_attention.

To install JAX:

python install.py --userbenchmark triton --jax

Working on https://github.com/pytorch/benchmark/issues/2328
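For readers who want a picture of how an opt-in `--jax` flag like the one above might be wired, here is a minimal sketch. It is not the code from this pull request: the function names `build_parser`, `jax_install_command`, and `plan` are hypothetical, and the choice of the `jax[cuda12]` pip extra is an assumption about the target environment. The sketch only builds the pip command and does not run it.

```python
import argparse
import sys


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser; the real install.py may define its options differently.
    parser = argparse.ArgumentParser(description="Installer sketch (illustrative only)")
    parser.add_argument("--userbenchmark", default=None,
                        help="name of the userbenchmark to install, e.g. 'triton'")
    parser.add_argument("--jax", action="store_true",
                        help="also install JAX (opt-in, off by default)")
    return parser


def jax_install_command(python: str = sys.executable) -> list[str]:
    # Assumption: a CUDA 12 build of JAX is wanted, so the "cuda12" pip extra is used.
    return [python, "-m", "pip", "install", "-U", "jax[cuda12]"]


def plan(argv: list[str]) -> list[list[str]]:
    # Return the list of commands that would be run, without executing them.
    args = build_parser().parse_args(argv)
    commands: list[list[str]] = []
    if args.userbenchmark == "triton" and args.jax:
        commands.append(jax_install_command())
    return commands
```

Calling `plan(["--userbenchmark", "triton", "--jax"])` yields a single pip command, while omitting `--jax` yields an empty plan. Keeping JAX opt-in keeps the default install unchanged for users who only need the Triton operators.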

facebook-github-bot commented 1 week ago

@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 1 week ago

@xuzhao9 merged this pull request in pytorch/benchmark@4720ea4f2088e0612789f21f5c930c2a31745769.