We are interested in adding a few JAX Pallas operator implementations to tritonbench, starting with flash_attention.
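A Pallas flash_attention operator needs something to be checked against. The sketch below is plain, non-fused softmax attention written in `jax.numpy`. It is *not* flash attention and not tritonbench's actual code. It assumes the `[batch, seq_len, num_heads, head_dim]` layout and a `1/sqrt(head_dim)` scale, and the name `reference_attention` is hypothetical.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, causal=False):
    """Plain (non-fused) softmax attention, a hypothetical correctness baseline.

    Assumed layout for q, k, v: [batch, seq_len, num_heads, head_dim].
    """
    head_dim = q.shape[-1]
    sm_scale = 1.0 / (head_dim ** 0.5)  # assumed default scaling
    # Attention scores: [batch, num_heads, q_len, kv_len]
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * sm_scale
    if causal:
        q_len, kv_len = q.shape[1], k.shape[1]
        # Query i may only attend to keys 0..i
        mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        logits = jnp.where(mask, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of values, back to [batch, q_len, num_heads, head_dim]
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


# Usage: random inputs, causal attention
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 128, 4, 64)
q = jax.random.normal(key_q, shape, dtype=jnp.float32)
k = jax.random.normal(key_k, shape, dtype=jnp.float32)
v = jax.random.normal(key_v, shape, dtype=jnp.float32)
out = reference_attention(q, k, v, causal=True)
print(out.shape)  # (2, 128, 4, 64)
```

A Pallas kernel's output could then be compared to this baseline with `jnp.allclose` under a dtype-appropriate tolerance. Under the causal mask, the first query attends only to the first key. Its output therefore equals `v[:, 0]`, which gives a cheap sanity check.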
To install JAX:
python install.py --userbenchmark triton --jax
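After installing, a minimal smoke test can confirm that JAX and Pallas are importable. The sketch below runs an elementwise-add Pallas kernel. It sets `interpret=True` so the kernel also runs on CPU, and the names `add_kernel` and `pallas_add` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Read the input blocks, add them, and write the result block
    o_ref[...] = x_ref[...] + y_ref[...]


def pallas_add(x, y):
    # No grid is given, so the kernel sees the whole arrays as one block
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode: no GPU required
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
result = pallas_add(x, x)
print(jax.__version__, jax.devices())
```

Dropping `interpret=True` compiles the kernel for the available accelerator instead, which is what a benchmark would measure.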
Working on https://github.com/pytorch/benchmark/issues/2328
@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@xuzhao9 merged this pull request in pytorch/benchmark@4720ea4f2088e0612789f21f5c930c2a31745769.