erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Flash attention #16

Closed: vwxyzjn closed this issue 1 year ago.

vwxyzjn commented 1 year ago

Thanks for the repo. I noticed the README says this repo supports flash attention (https://github.com/erfanzar/EasyDeL#available-models-are), which is implemented via https://github.com/google/flaxformer/blob/ee62754ebe5a5eeb111493622de5537133822e3e/flaxformer/components/attention/memory_efficient_attention.py#L60. I was wondering whether you had noticed any speed improvement or memory reduction compared with, say, https://github.com/jax-ml/jax-triton.
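For context, "memory-efficient" attention of the kind in the linked flaxformer file avoids materializing the full query-by-key score matrix. It processes keys and values in chunks and keeps a running (online) softmax. Below is a minimal plain-JAX sketch of that idea. It is not the flaxformer or EasyDeL code: the function names `naive_attention` and `chunked_attention` and the `chunk_size` parameter are hypothetical. It chunks only keys and values, for a single head without masking, and it assumes the key length is divisible by `chunk_size`.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    """Reference softmax attention; materializes the full [Lq, Lk] score matrix."""
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale
    return jax.nn.softmax(scores, axis=-1) @ v


def chunked_attention(q, k, v, chunk_size=64):
    """Hypothetical sketch: attention over key/value chunks with an online softmax.

    Only a [Lq, chunk_size] block of scores exists per step.
    Assumes k.shape[0] is divisible by chunk_size (no padding or masking).
    """
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    num_chunks = k.shape[0] // chunk_size
    k_chunks = k.reshape(num_chunks, chunk_size, k.shape[-1])
    v_chunks = v.reshape(num_chunks, chunk_size, v.shape[-1])

    def step(carry, kv):
        acc, row_max, row_sum = carry
        k_c, v_c = kv
        scores = (q @ k_c.T) * scale                          # [Lq, chunk_size]
        new_max = jnp.maximum(row_max, scores.max(axis=-1))   # running row max
        correction = jnp.exp(row_max - new_max)               # rescales earlier chunks
        p = jnp.exp(scores - new_max[:, None])                # stable exponentials
        acc = acc * correction[:, None] + p @ v_c             # unnormalized output
        row_sum = row_sum * correction + p.sum(axis=-1)       # softmax denominator
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((q.shape[0], v.shape[-1]), dtype=q.dtype),
        jnp.full((q.shape[0],), -jnp.inf, dtype=q.dtype),
        jnp.zeros((q.shape[0],), dtype=q.dtype),
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_chunks, v_chunks))
    return acc / row_sum[:, None]


if __name__ == "__main__":
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (128, 32))
    k = jax.random.normal(kk, (256, 32))
    v = jax.random.normal(kv, (256, 32))
    out = chunked_attention(q, k, v, chunk_size=64)
    # Compare against the full-matrix reference within a float32 tolerance.
    print(jnp.allclose(out, naive_attention(q, k, v), atol=1e-4))
```

On CPU in float32, the two functions should agree to within floating-point tolerance. Chunking like this lowers peak memory, but in plain JAX it does not by itself fuse the computation into a single kernel the way a Triton or Pallas kernel (as in jax-triton) can. Any speedup would therefore have to be measured rather than assumed.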

erfanzar commented 1 year ago

You're welcome, and thanks for using the repo. To answer your question: no, I didn't make any fancy changes. I've made some changes in the code, but I haven't tested them or run any benchmarks.