VachanVY closed this issue 3 months ago.
Did you install one of the released wheels with pip?
How can I do that? (I tried pip install flash_att_jax but got an error.)
I actually cloned the repo because I couldn't understand how to install it from the instructions below:
To install: For now, download the appropriate release from the releases page and install it with pip.
Can you please tell me the steps to install it properly? Thanks.
@nshepperd, could you please tell me how I can install the released wheels with pip?
Go here: https://github.com/nshepperd/flash_attn_jax/releases/tag/v0.1.0a3
Then find the file that matches your Python version and the CUDA version you're using with JAX, and install it with pip install followed by the path to the downloaded .whl file.
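If it's unclear which release asset to pick, the Python side can be read straight from the wheel filename, because wheels follow the PEP 427 naming scheme ({name}-{version}(-{build})?-{python tag}-{abi tag}-{platform tag}.whl). Below is a minimal sketch that filters a list of filenames by your interpreter's CPython tag (e.g. cp310). How the CUDA variant appears in the filename is an assumption here: the sketch guesses it is a local version label such as "+cu12", and the filenames in the example are hypothetical, so compare against the actual names on the releases page before relying on it.

```python
import re
import sys

# PEP 427 wheel filename:
# {name}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl
WHEEL_RE = re.compile(
    r"^(?P<name>[^-]+)-(?P<version>[^-]+)(?:-(?P<build>\d[^-]*))?"
    r"-(?P<py>[^-]+)-(?P<abi>[^-]+)-(?P<plat>[^-]+)\.whl$"
)


def python_tag() -> str:
    """CPython wheel tag for the running interpreter, e.g. 'cp310'."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def pick_wheel(filenames, py_tag, cuda_tag):
    """Return the wheel filenames built for py_tag and the given CUDA variant.

    ASSUMPTION (hypothetical): the CUDA variant is the local version label,
    e.g. '0.1.0a3+cu12' -> 'cu12'. Check the real release asset names.
    """
    matches = []
    for fn in filenames:
        m = WHEEL_RE.match(fn)
        if m is None:
            continue  # not a wheel filename, skip it
        local_label = m.group("version").partition("+")[2]
        # A python tag may be a dotted set such as 'py2.py3'.
        if py_tag in m.group("py").split(".") and local_label == cuda_tag:
            matches.append(fn)
    return matches


if __name__ == "__main__":
    # Hypothetical asset names, for illustration only.
    candidates = [
        "flash_attn_jax-0.1.0a3+cu12-cp310-cp310-manylinux_2_28_x86_64.whl",
        "flash_attn_jax-0.1.0a3+cu12-cp311-cp311-manylinux_2_28_x86_64.whl",
        "flash_attn_jax-0.1.0a3+cu11-cp310-cp310-manylinux_2_28_x86_64.whl",
    ]
    for wheel in pick_wheel(candidates, python_tag(), "cu12"):
        print(f"pip install {wheel}")
```

Once you have downloaded the matching file, pass its local path to pip install as described above.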
Thanks!
https://github.com/nshepperd/flash_attn_jax/blob/d485930c8adb3c0f34536e8bb9442b3d4b20123f/src/flash_attn_jax/flash_hlo.py#L24