lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License
7.67k stars, 668 forks

speed up with flash attn in A6000? #47

Closed: wac81 closed this issue 1 year ago

wac81 commented 1 year ago

Please check this: https://www.reddit.com/r/StableDiffusion/comments/xmr3ic/speed_up_stable_diffusion_by_50_using_flash/

However, in my case, enabling the flash attention parameter on the PaLM model does not give any speedup on an A6000.

conceptofmind commented 1 year ago

PyTorch 2.0 Flash Attention requires an SM80 architecture. The A6000 has an SM86 architecture, which is not currently supported. And just to clarify again, you cannot use a dim_head above 128.
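A quick way to see which side of this rule a given GPU falls on is to compare its compute capability and the model's `dim_head` against the two constraints stated above. The sketch below is a literal reading of that comment (exactly SM80, `dim_head` at most 128), not PyTorch's actual kernel-selection logic, which differs between releases. The helper names `flash_attn_expected` and `current_gpu_capability` are hypothetical; only `torch.cuda.is_available()` and `torch.cuda.get_device_capability()` are real PyTorch calls.

```python
from typing import Optional, Tuple

# Constraints as stated in the comment above (PyTorch 2.0 era). These are
# assumptions for illustration; supported GPUs and head sizes vary by release.
REQUIRED_SM = (8, 0)   # SM80, e.g. A100
MAX_DIM_HEAD = 128


def flash_attn_expected(capability: Tuple[int, int], dim_head: int) -> bool:
    """Hypothetical check: True if the stated rule allows flash attention."""
    return tuple(capability) == REQUIRED_SM and dim_head <= MAX_DIM_HEAD


def current_gpu_capability() -> Optional[Tuple[int, int]]:
    """Return the local GPU's (major, minor) compute capability, or None."""
    try:
        import torch  # optional dependency; only needed to query real hardware
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


if __name__ == "__main__":
    print(flash_attn_expected((8, 6), 64))   # A6000 (SM86): rule says no
    print(flash_attn_expected((8, 0), 128))  # A100 (SM80), dim_head 128: allowed
    print(flash_attn_expected((8, 0), 256))  # dim_head above 128: not allowed
    print(current_gpu_capability())          # None without torch or a CUDA GPU
```

On a machine where this returns False, turning on the flash attention parameter would be expected to fall back to a non-flash attention path, which is consistent with seeing no speedup on an A6000.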

wac81 commented 1 year ago

Thanks a lot!