nshepperd / flash_attn_jax

JAX bindings for Flash Attention v2
BSD 3-Clause "New" or "Revised" License

ModuleNotFoundError: No module named `flash_attn_jax.flash_api` #2

Closed: VachanVY closed this issue 3 months ago

VachanVY commented 3 months ago

https://github.com/nshepperd/flash_attn_jax/blob/d485930c8adb3c0f34536e8bb9442b3d4b20123f/src/flash_attn_jax/flash_hlo.py#L24

nshepperd commented 3 months ago

Did you install one of the released wheels with pip?

VachanVY commented 3 months ago

How can I do that? (I tried `pip install flash_att_jax` but got an error.)

VachanVY commented 3 months ago

I actually cloned the repo, because I couldn't understand how to install it from the instructions below:

To install: For now, download the appropriate release from the releases page and install it with pip.

Can you please tell me the steps to install it properly? Thanks!

VachanVY commented 3 months ago

@nshepperd could you please tell me how I can install the released wheels with pip?

nshepperd commented 3 months ago

Go here: https://github.com/nshepperd/flash_attn_jax/releases/tag/v0.1.0a3 and find the wheel file that matches the Python version and CUDA version you're using with JAX. Install it with `pip install`.
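A minimal sketch (not part of flash_attn_jax) for the "match your Python version" half of that advice. It relies only on the standard wheel filename layout from PEP 427, `{distribution}-{version}(-{build})?-{python}-{abi}-{platform}.whl`, where the Python tag (e.g. `cp310`) is the third field from the end. The helper names and the example filenames are hypothetical; the sketch does not check the CUDA version, because how the release files encode it is not stated in this thread.

```python
import sys


def python_tag() -> str:
    """CPython wheel tag for the running interpreter, e.g. 'cp310' on Python 3.10."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def wheel_python_tag(filename: str) -> str:
    """Return the python tag field of a PEP 427 wheel filename."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel file: {filename}")
    parts = filename[: -len(".whl")].split("-")
    # 5 fields without a build tag, 6 with one; python tag is always third from the end.
    if len(parts) not in (5, 6):
        raise ValueError(f"unexpected wheel filename: {filename}")
    return parts[-3]


def matching_wheels(filenames, tag=None):
    """Keep wheels whose python tag includes `tag` (defaults to this interpreter)."""
    tag = tag or python_tag()
    # Compressed tags such as 'py2.py3' list several tags separated by dots.
    return [f for f in filenames if tag in wheel_python_tag(f).split(".")]


if __name__ == "__main__":
    # Hypothetical filenames, for illustration only.
    candidates = [
        "example_pkg-0.1.0-cp310-cp310-linux_x86_64.whl",
        "example_pkg-0.1.0-cp311-cp311-linux_x86_64.whl",
    ]
    print("this interpreter:", python_tag())
    print("matching:", matching_wheels(candidates))
```

Once the right file is identified, it is installed by passing its local path to `pip install`.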

VachanVY commented 3 months ago

Thanks!