Replace the model setting self_attn_type with a RunningConfig setting self_attn_backend, which accepts "flash" or "pytorch".
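For illustration, a minimal sketch of the new setting (showing RunningConfig as a plain dataclass; the real class layout and default value in the codebase may differ):

```python
from dataclasses import dataclass

@dataclass
class RunningConfig:
    # Replaces the old model-level self_attn_type.
    # Accepted values: "flash" or "pytorch".
    self_attn_backend: str = "flash"

config = RunningConfig(self_attn_backend="pytorch")
```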
In fact:
At training, when using Rotary or Legacy position encoding, flash and pytorch sdpa perform almost identically; with alibi or max_relative_positions, the "manual" matmul path is used anyway.
At inference, using "flash" instead of "pytorch" triggers flash_attn_with_kvcache, which is much faster and has no equivalent in pytorch 2.3 yet.
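As a rough sketch of the inference-time dispatch this implies, assuming flash-attn's flash_attn_with_kvcache and pytorch's scaled_dot_product_attention; the wrapper function below, its name, and the cache handling are illustrative only, not the project's actual code:

```python
import torch
import torch.nn.functional as F

try:
    # flash-attn's fused kernel: updates the KV cache and attends in one call
    from flash_attn import flash_attn_with_kvcache
    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False


def decode_step_attention(query, key, value, k_cache, v_cache, cache_len,
                          self_attn_backend="flash"):
    """Hypothetical single-step decoding attention.

    query/key/value: [batch, 1, heads, head_dim]   (flash-attn layout)
    k_cache/v_cache: [batch, max_len, heads, head_dim]
    cache_len:       number of cached positions already filled (same for the whole batch here)
    """
    if self_attn_backend == "flash" and HAS_FLASH and query.is_cuda:
        # Fast path: fused KV-cache update + attention, which pytorch 2.3 does not offer
        return flash_attn_with_kvcache(query, k_cache, v_cache, k=key, v=value,
                                       cache_seqlens=cache_len, causal=True)
    # "pytorch" backend fallback: update the cache manually, then call pytorch SDPA
    k_cache[:, cache_len:cache_len + 1] = key
    v_cache[:, cache_len:cache_len + 1] = value
    keys = k_cache[:, :cache_len + 1]
    values = v_cache[:, :cache_len + 1]
    # SDPA expects [batch, heads, seq, head_dim]; flash-attn uses [batch, seq, heads, head_dim]
    out = F.scaled_dot_product_attention(query.transpose(1, 2),
                                         keys.transpose(1, 2),
                                         values.transpose(1, 2))
    return out.transpose(1, 2)
```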