lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers

Should the mask option to AttentionLayers(boolean) increase memory? #261

Closed: blasscoc closed this issue 2 months ago

blasscoc commented 2 months ago

Should the mask option in AttentionLayers be using up so much memory?

Behavior:

When I pass the mask, memory consumption is 6326 MiB; without the mask, it is 1364 MiB.

import torch

from x_transformers.x_transformers import AttentionLayers

attn_config = { "dim": 128, "depth": 4, "heads": 6, "ff_mult": 8, "attn_flash": True }

encoder = AttentionLayers(**attn_config)
encoder.cuda()

x = torch.randn(10, 4000, 128)
mask = torch.ones(10, 4000).bool()

x = x.to('cuda')
mask = mask.to('cuda')

while True:
    with torch.no_grad():
        output = encoder(x, mask=mask)   # ~6326 MiB reported by nvidia-smi
        # output = encoder(x)            # ~1364 MiB without the mask
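For reference, the two numbers can also be read off programmatically with PyTorch's peak-memory counters instead of nvidia-smi. This is a sketch that assumes encoder, x and mask are defined as above; the absolute values will differ from nvidia-smi (the CUDA context is not counted), but the gap should show up the same way.

import torch

def peak_mib(run):
    # reset the allocator's peak-memory counter, run one forward pass,
    # and report the peak allocation in MiB
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        run()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

print('with mask   :', peak_mib(lambda: encoder(x, mask=mask)), 'MiB')
print('without mask:', peak_mib(lambda: encoder(x)), 'MiB')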
blasscoc commented 2 months ago

I followed the memory consumption through to the call to scaled_dot_product_attention in the "flash_attn" function and saw that the memory increases inside that call. I verified that the mask's dtype, when initially bool, stayed bool throughout, and that no significant copies were being made.
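The same effect can be observed in isolation by calling scaled_dot_product_attention directly with shapes mirroring the repro above. This is only a sketch: the exact shape and broadcasting of the mask when it reaches the call inside attend.py is an assumption here, and whether the gap reproduces depends on the PyTorch version and which SDPA backend gets selected.

import torch
import torch.nn.functional as F

# batch, heads, sequence length, head dim, mirroring the repro config
b, h, n, d = 10, 6, 4000, 64

# half precision so the flash kernel is at least eligible when no mask is passed
q = torch.randn(b, h, n, d, device='cuda', dtype=torch.float16)
k = torch.randn(b, h, n, d, device='cuda', dtype=torch.float16)
v = torch.randn(b, h, n, d, device='cuda', dtype=torch.float16)

# boolean key-padding mask, broadcastable to (b, h, n, n); True = attend
# (the shape used inside attend.py is assumed, not confirmed)
bool_mask = torch.ones(b, 1, 1, n, device='cuda', dtype=torch.bool)

def peak_mib(**kwargs):
    # peak allocator memory, in MiB, for a single attention call
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        F.scaled_dot_product_attention(q, k, v, **kwargs)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

print('no attn_mask  :', peak_mib(), 'MiB')
print('bool attn_mask:', peak_mib(attn_mask=bool_mask), 'MiB')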

The memory consumption scales roughly as the sequence length squared times the number of heads, so it can become quite large. Interestingly, casting the mask from bool to float on line 238 of attend.py resulted in only a modest increase in memory consumption, which was unexpected.
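For a rough sense of scale under that quadratic growth: if the boolean key-padding mask (or the attention scores it forces to be materialized) ends up as a full (batch, heads, seq, seq) tensor, the sizes land in the right ballpark for the jump reported above. The shapes and dtypes below are assumptions for illustration, not taken from attend.py.

# Back-of-the-envelope size of a fully materialized (batch, heads, n, n) tensor
# for the repro config (batch 10, 6 heads, sequence length 4000).
b, h, n = 10, 6, 4000
elements = b * h * n * n                              # 960,000,000 elements

print(f'{elements * 1 / 2**30:.2f} GiB as bool')      # ~0.89 GiB
print(f'{elements * 4 / 2**30:.2f} GiB as float32')   # ~3.58 GiB

The observed difference of roughly 5 GiB is of the same order, consistent with one or two such tensors being materialized somewhere along the masked path.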