philschmid / deep-learning-pytorch-huggingface

MIT License
580 stars · 138 forks

Flash attention #22

Closed philschmid closed 11 months ago

tmm1 commented 11 months ago

Hi, thanks for documenting this.

I'm curious whether you were able to train successfully with peft + flash attention?

I keep seeing loss spikes after a few iterations.
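For context, a minimal sketch of the kind of setup being asked about: loading a causal LM with FlashAttention-2 enabled and wrapping it with a LoRA adapter via peft. This assumes transformers >= 4.36 (which exposes `attn_implementation="flash_attention_2"`) and a GPU with bf16 support; the model id and LoRA hyperparameters below are illustrative, not the author's actual config.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load the base model with FlashAttention-2 enabled.
# FlashAttention requires fp16 or bf16 weights.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative model id (assumption)
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Attach a LoRA adapter via peft (hyperparameters are assumptions).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

With this setup, only the LoRA matrices are trained while the frozen base model runs its attention through the flash-attn kernels.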

tmm1 commented 11 months ago

I tried on another machine and it works fine there, so something else is going on. I will continue investigating in https://github.com/Dao-AILab/flash-attention/issues/311

philschmid commented 11 months ago

@tmm1 which machine were you using where it was not working?

tmm1 commented 11 months ago

It is a local machine with a 3090. I got it working by starting from a fresh CUDA 11.8 conda environment; the flash-attn tests were failing in the broken env.
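A sketch of the recovery steps described above, assuming conda is installed and the driver supports CUDA 11.8; the environment name and Python/PyTorch versions are illustrative:

```shell
# Create a fresh environment pinned to a known-good toolchain
# (name and Python version are assumptions)
conda create -n flash-attn-cu118 python=3.10 -y
conda activate flash-attn-cu118

# Install a PyTorch build compiled against CUDA 11.8
pip install torch --index-url https://download.pytorch.org/whl/cu118

# Build and install flash-attn against the new toolchain
pip install flash-attn --no-build-isolation

# Verify the environment by running the flash-attn test suite
# (requires a clone of https://github.com/Dao-AILab/flash-attention)
pytest tests/test_flash_attn.py
```

The key point from the thread is that a mismatched CUDA/PyTorch toolchain can produce silently broken kernels, so running the library's own tests in the fresh env is the sanity check.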