facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Sparse attention will not reduce peak memory usage #1086

Closed ThisisBillhe closed 1 month ago

ThisisBillhe commented 1 month ago

❓ Questions and Help

Hi there, I tried the simple sparse attention example here and it does not seem to work as advertised. If you swap the execution order of the two attention calls, the reported peak memory changes. For example:

import torch

from xformers.components.attention import ScaledDotProduct

attention = ScaledDotProduct().cuda()

# FW a random bunch of data
inputs = torch.rand((16, 1024, 1024), device=torch.device("cuda"))

# Now use a very sparse mask and observe that memory use changes
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

mask = (torch.rand((1024, 1024)) < 0.1).cuda()
att = attention(q=inputs, k=inputs, v=inputs, mask=mask)

torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"Sparse - Peak memory use: {max_memory}MB")

# Not a very sparse mask to begin with
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

mask = (torch.rand((1024, 1024)) < 0.9).cuda()
att = attention(q=inputs, k=inputs, v=inputs, mask=mask)

torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"Dense - Peak memory use: {max_memory}MB")

# Now use a very sparse mask and observe that memory use changes
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

mask = (torch.rand((1024, 1024)) < 0.1).cuda()
att = attention(q=inputs, k=inputs, v=inputs, mask=mask)

torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"Sparse - Peak memory use: {max_memory}MB")

You will get:

Sparse - Peak memory use: 509MB
Dense - Peak memory use: 329MB
Sparse - Peak memory use: 329MB

which means there is no memory reduction from using sparse attention.

danthe3rd commented 1 month ago

cc @fmassa

fmassa commented 1 month ago

Hi,

There does indeed seem to be an issue with the tutorial. Your example has two problems:

- the boolean mask is passed as a dense tensor; to hit the sparse kernels it needs to be wrapped in `SparseCS` first,
- the mask is passed via the `mask=` keyword rather than the `att_mask=` keyword that `ScaledDotProduct` expects.

By fixing those two issues, you get the expected results:

Sparse - Peak memory use: 142MB
Dense - Peak memory use: 323MB
Sparse - Peak memory use: 577MB
Fixed code in here:

```
import torch

from xformers.components.attention import ScaledDotProduct
from xformers.components.attention.core import SparseCS

attention = ScaledDotProduct().cuda()

# FW a random bunch of data
inputs = torch.rand((16, 1024, 1024), device=torch.device("cuda"))

# Now use a very sparse mask and observe that memory use changes
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

mask = (torch.rand((1024, 1024)) < 0.1).cuda()
mask = SparseCS(mask, torch.device("cuda"))
att = attention(q=inputs, k=inputs, v=inputs, att_mask=mask)

torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"Sparse - Peak memory use: {max_memory}MB")

# Not a very sparse mask to begin with
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

mask = (torch.rand((1024, 1024)) < 0.9).cuda()
mask = SparseCS(mask, torch.device("cuda"))
att = attention(q=inputs, k=inputs, v=inputs, att_mask=mask)

torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"Dense - Peak memory use: {max_memory}MB")

# Now use a very sparse mask and observe that memory use changes
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

mask = (torch.rand((1024, 1024)) < 0.1).cuda()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
att = attention(q=inputs, k=inputs, v=inputs, att_mask=mask)

torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"Sparse - Peak memory use: {max_memory}MB")
```

Relevance of xformers.components

That being said, xformers.components (including the code you are using) is deprecated in favor of xformers.ops.memory_efficient_attention and will soon be removed. In light of FlashAttention and xformers.ops.memory_efficient_attention, the memory savings from sparsity as in SparseCS are non-existent (because the attention matrix is never materialized). In addition, xformers.components has no fp16 / bf16 support, so it is not relevant for any real-world model nowadays.
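For reference, a minimal sketch (not from this thread) of the recommended path: calling xformers.ops.memory_efficient_attention on tensors of shape [batch, seq_len, heads, head_dim]. The shapes, the fp16 dtype, and the LowerTriangularMask below are illustrative assumptions; the point is that no full attention matrix is ever allocated.

```
import torch
import xformers.ops as xops

# Illustrative shapes: batch=16, seq_len=1024, heads=8, head_dim=64 (assumed, not from the issue)
q = torch.rand(16, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.rand(16, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.rand(16, 1024, 8, 64, device="cuda", dtype=torch.float16)

torch.cuda.reset_peak_memory_stats()

# Attention is computed block-wise and the full attention matrix is never
# materialized, so peak memory stays low without an explicit sparse mask.
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

torch.cuda.synchronize()
print(f"memory_efficient_attention - Peak memory use: {torch.cuda.max_memory_allocated() // 2 ** 20}MB")
```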