Closed: ThisisBillhe closed this 1 month ago
cc @fmassa
Hi,

There seems to be an issue with the tutorial indeed. Your example has two issues:

- you are passing `mask` as an argument to `ScaledDotProduct.forward`, while the expected argument is `att_mask` (yeah, I know, this should have been a loud error)
- if you wanted to use sparsity, you'd have to convert the `mask` to a `SparseCS` tensor, which enables sparse attention computation (a corrected call is sketched after the numbers below)

By fixing those two issues, you get the expected results:
```
Sparse - Peak memory use: 142MB
Dense - Peak memory use: 323MB
Sparse - Peak memory use: 577MB
```
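For reference, here is a minimal sketch of the corrected call with both fixes applied. This is not the tutorial snippet verbatim: the shapes are illustrative, and the `SparseCS` import path (`xformers.components.attention.core`) is an assumption that may differ across versions.

```python
import torch
from xformers.components.attention import ScaledDotProduct
# Import path assumed; adjust to where SparseCS lives in your xformers version
from xformers.components.attention.core import SparseCS

device = torch.device("cuda")
attention = ScaledDotProduct().to(device)

# Illustrative shapes: batch 16, sequence 1024, embedding 1024
inputs = torch.rand((16, 1024, 1024), device=device)

# A genuinely sparse boolean mask (~10% of entries kept)
mask = torch.rand((1024, 1024), device=device) < 0.1

# Fix 1: the keyword argument is `att_mask`, not `mask`
# Fix 2: wrap the boolean mask in SparseCS so the sparse kernels are actually used
sparse_mask = SparseCS(mask, device)

torch.cuda.reset_peak_memory_stats()
out = attention(q=inputs, k=inputs, v=inputs, att_mask=sparse_mask)
print(f"Sparse - Peak memory use: {torch.cuda.max_memory_allocated() // 2**20}MB")
```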
That being said, `xformers.components` (including the code you are using) are deprecated in favor of `xformers.ops.memory_efficient_attention` and will soon be removed. In light of FlashAttention and `xformers.ops.memory_efficient_attention`, the memory savings from using sparsity as in `SparseCS` are non-existent (because the attention matrix is never materialized). Plus, `xformers.components` don't have fp16 / bf16 support, so for any real-world models nowadays they won't be relevant.
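For completeness, a minimal sketch of the recommended path; the shapes and dtype here are illustrative, not taken from the tutorial.

```python
import torch
import xformers.ops as xops

device = torch.device("cuda")
dtype = torch.float16  # fp16 / bf16 are supported here, unlike xformers.components

# [batch, seq_len, num_heads, head_dim]
q = torch.rand(16, 1024, 8, 64, device=device, dtype=dtype)
k = torch.rand(16, 1024, 8, 64, device=device, dtype=dtype)
v = torch.rand(16, 1024, 8, 64, device=device, dtype=dtype)

torch.cuda.reset_peak_memory_stats()
# The attention matrix is never materialized, so memory stays low without sparsity tricks
out = xops.memory_efficient_attention(q, k, v)
# For causal attention, pass attn_bias=xops.LowerTriangularMask()
print(f"Peak memory use: {torch.cuda.max_memory_allocated() // 2**20}MB")
```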
❓ Questions and Help
Hi there, I tried the simple sparse attention example here and found that it does not work properly. Once you swap the order of execution of the two kernels, the peak memory results change. For example:
You will get:
which means there is no memory reduction from using sparse attention.
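Not the original snippet (the linked example isn't reproduced above), but a sketch of how such a comparison can be made order-independent: reset the CUDA peak-memory counter before each kernel, since `torch.cuda.max_memory_allocated()` reports a running high-water mark and otherwise carries over from whichever kernel ran first. The two kernels are placeholders here.

```python
import torch

def peak_memory_mb(fn) -> int:
    """Run `fn` and report its peak CUDA memory in MB, isolated from previous runs."""
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() // 2**20

# Usage (run_dense / run_sparse are hypothetical callables wrapping the two kernels):
# print("Dense  -", peak_memory_mb(run_dense), "MB")
# print("Sparse -", peak_memory_mb(run_sparse), "MB")
```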