Yeah, the idea for this file was to be as simple as possible, so people can understand what computation is being done if they want to re-implement it in any framework they want.
It's a bit easier to understand and parse than the CUDA kernels in xformers or FlashAttention ;)
Lines 129-143 in `one_file_ref.py` multiply the complete query-key matrices with each other when prefilling the key-value cache; the sliding-window mask is applied only after this multiplication. This seems inefficient for prompt sizes larger than the sliding-window length, and could be improved by using the attention implementation in `mistral/model.py` directly (which uses xformers' `memory_efficient_attention`). A sketch of both approaches is below.
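For concreteness, here's a minimal sketch of the two approaches. Function names and shapes are illustrative, not the ones in the repo: the naive version materializes the full score matrix and masks it afterwards, while the xformers version passes a block-diagonal causal local-attention bias (the pattern `mistral/model.py` uses) so the fused kernel can skip out-of-window blocks instead of computing and then discarding them.

```python
import math

import torch
from xformers.ops import fmha, memory_efficient_attention


def naive_prefill_attention(q, k, v, sliding_window):
    # Illustrative sketch; q, k, v: (batch, n_heads, seqlen, head_dim).
    # The full (seqlen x seqlen) score matrix is materialized first...
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    # ...and the causal + sliding-window mask is applied only afterwards,
    # so every out-of-window score was computed just to be thrown away.
    pos = torch.arange(q.shape[-2], device=q.device)
    keep = (pos[None, :] <= pos[:, None]) & (
        pos[:, None] - pos[None, :] < sliding_window
    )
    scores = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


def windowed_prefill_attention(q, k, v, seqlens, sliding_window):
    # q, k, v: (1, total_tokens, n_heads, head_dim), with all sequences
    # packed along dim 1 (the layout mistral/model.py uses). The bias below
    # tells the fused kernel which blocks to compute at all, rather than
    # masking scores after the fact.
    attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
        seqlens
    ).make_local_attention(sliding_window)
    return memory_efficient_attention(q, k, v, attn_bias)
```

For prompts much longer than the sliding window, the second version avoids the O(seqlen²) score computation entirely, which is exactly the inefficiency described above.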