pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

Weird benchmarking results: FlexAttention vs SDPA #79

Closed ViktorooReps closed 5 days ago

ViktorooReps commented 6 days ago

Hi! I am experimenting with a char-level LLM with word masking. For now, I am struggling to get FlexAttention to work past compilation errors, even for simple causal masking.

What am I benchmarking:

  1. Generation speed
  2. Batch processing (with and without a backward pass)

The model: the dummy 36-layer attention-only stack defined below as DummyModel (64 heads of dim 128, hidden size 8192, char-level vocabulary of 128, bfloat16).

What I am expecting to see: FlexAttention at least on par with SDPA in speed, and ideally better in memory usage (memory is my main reason for using it).

What I actually see: Dynamo "WON'T CONVERT" warnings around flex_attention, slower batch processing than SDPA, and much higher peak memory in the backward runs (~73 GB vs ~29 GB with SDPA).

Can you help me? How do I fix it? Am I not compiling/calling FlexAttention correctly? Thank you so much in advance!
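
For context, this is the minimal standalone pattern I believe FlexAttention expects (a sketch against the public torch.nn.attention.flex_attention API; compiling flex_attention by itself and the toy shapes are just for illustration and are not taken from the script below):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # keep only keys at or before the query position
    return q_idx >= kv_idx

# compile flex_attention once and reuse it; the BlockMask is precomputed separately
flex_compiled = torch.compile(flex_attention)
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=2048, KV_LEN=2048, device='cuda')

q = torch.randn(1, 8, 2048, 64, device='cuda', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flex_compiled(q, k, v, block_mask=block_mask)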

Setup

H100 GPU, CUDA 12.4

pip freeze:

...
accelerate==0.34.2
...
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-ml-py==12.535.161
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
...
torch==2.6.0.dev20241114+cu124
torchaudio==2.5.0.dev20241114+cu124
torchvision==0.20.0.dev20241114+cu124
...
triton==3.1.0
...

Here is the full script to reproduce:

import abc
import json
import time
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path

import nltk
import torch
import re

from nltk.tokenize import PunktTokenizer
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.attention.flex_attention import create_block_mask, and_masks, flex_attention
from torch.nn.functional import scaled_dot_product_attention

from typing import Iterable, TypeVar
from itertools import chain

from tqdm.auto import tqdm

nltk.download('punkt')
nltk.download('punkt_tab')

import torch._dynamo
torch._dynamo.config.suppress_errors = True

_T = TypeVar('_T')

DATASET_PATH = Path('tinyshakespeare.txt')
PRETOKENIZED_DATASET_PATH = Path('data/tinyshakespeare')
DATASET_SOURCE = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

sdpa = scaled_dot_product_attention 

class Configurable(metaclass=abc.ABCMeta):
    @abstractmethod
    def to_config(self):
        pass

    @classmethod
    def from_config(cls: type[_T], config: dict) -> _T:
        return cls(**config)

    def save(self, path: str | Path) -> None:
        config = self.to_config()
        with open(path, 'w') as f:
            json.dump(config, f, indent=2)

    @classmethod
    def load(cls: type[_T], path: str | Path) -> _T:
        with open(path, 'r') as f:
            config = json.load(f)

        return cls.from_config(config)

@dataclass
class ModelOutput:
    # model predicts next character and whether the current segment has ended
    logits: torch.Tensor
    logits_segment: torch.Tensor
    # per layer
    keys: list[torch.Tensor]
    values: list[torch.Tensor]

class Generator(nn.Module, metaclass=abc.ABCMeta):
    @abstractmethod
    def forward(
            self,
            tokens: torch.Tensor,
            segment_mask: torch.Tensor,
            specials_mask: torch.Tensor,
            past_keys: list[torch.Tensor] | None = None,
            past_values: list[torch.Tensor] | None = None,
            **kwargs
    ) -> ModelOutput:
        pass

IS_SPECIAL_REGEX = re.compile(r'[^a-zA-Z0-9]')

class Tokenizer(Configurable):
    def __init__(
            self,
            vocab_size: int = 128,
            device: str = 'cpu',
            segment: str | None = None,  # "sentence", "word" or "sentence+word"
    ):
        self.size = vocab_size
        self.segment = segment if segment is not None else ''

        self.is_special = torch.tensor([
            IS_SPECIAL_REGEX.match(f'{chr(o)}') is not None
            for o in range(vocab_size)
        ], dtype=torch.bool, device=device)

        self.tokenizer_sent = None

        self.segment_sent = False
        self.segment_word = False

        self.unk_token = 0
        self.bot_token = 2
        self.eot_token = 3
        self.sep_token = ord(' ')
        self.pad_token = 127

        for segment_type in self.segment.split('+'):
            if segment_type == 'sentence':
                self.tokenizer_sent = PunktTokenizer('english')
                self.segment_sent = True
            if segment_type == 'word':
                self.segment_word = True

    @property
    def device(self):
        return self.is_special.device

    def to(self, device: str):
        self.is_special = self.is_special.to(device)
        return self

    def to_config(self):
        return {
            'size': self.size,
            'segment': self.segment,
        }

    def encode(self, text: str, *, add_sink: bool = True) -> dict:
        if self.segment_sent:
            segments = self.tokenizer_sent.span_tokenize(text)
        else:
            segments = [(0, len(text))]

        tokens = list(map(ord, text))

        token_segments = []
        segment_mask = []
        segment_starts = []

        prev_segm = None
        for segm in segments:
            start, end = segm

            if prev_segm is None:
                if add_sink:
                    # add sink
                    token_segments.append([self.bot_token])
                    segment_mask.append([False])
            elif prev_segm[1] == start:
                # separate the sentences
                token_segments.append([self.sep_token])
                segment_mask.append([False])
            else:
                prev_start, prev_end = prev_segm

                # everything between the sentences
                token_segments.append(tokens[prev_end:start])
                segment_mask.append([False] * (start - prev_end))

            token_segments.append(tokens[start:end])
            segment_mask.append([True] * (end - start))
            segment_starts.append(start)
            prev_segm = segm

        final_s, final_e = prev_segm
        if final_e != len(tokens):
            token_segments.append(tokens[final_e:len(tokens)])
            segment_mask.append([False] * (len(tokens) - final_e))

        all_tokens = chain.from_iterable(token_segments)
        all_segment_mask = chain.from_iterable(segment_mask)

        tokens_torch = torch.tensor(list(all_tokens), device=self.device, dtype=torch.int)
        tokens_torch[tokens_torch >= self.size] = self.unk_token

        return {
            'tokens': tokens_torch,
            'segment_mask': torch.tensor(list(all_segment_mask), device=self.device, dtype=torch.bool),
            'specials_mask': self.is_special[tokens_torch],
            'segment_starts': torch.tensor(segment_starts, device=self.device, dtype=torch.int),
        }

    def decode(self, tokens: Iterable[int]) -> str:
        # skip sink and pad tokens
        return ''.join(chr(o) for o in tokens if o not in (self.bot_token, self.pad_token))

class Collator:
    def __init__(self, padding_token: int, move_to: str = 'cpu'):
        self.padding_token = padding_token
        self.move_to = move_to

    def __call__(self, batch):
        tokens = tuple(item['tokens'] for item in batch)
        tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=self.padding_token)

        segment_mask = tuple(item['segment_mask'] for item in batch)
        segment_mask_padded = pad_sequence(segment_mask, batch_first=True, padding_value=False)

        specials_mask = tuple(item['specials_mask'] for item in batch)
        specials_mask_padded = pad_sequence(specials_mask, batch_first=True, padding_value=False)

        return {
            'tokens': tokens_padded.to(self.move_to),
            'segment_mask': segment_mask_padded.to(self.move_to),
            'specials_mask': specials_mask_padded.to(self.move_to),
        }

def create_mask(
        tokens: torch.Tensor,
        segment_mask: torch.Tensor,
        specials_mask: torch.Tensor,
        *,
        padding_token: int,
        cache_size: int = 0,
        mask_words_distance: int = 0,
        mask_segments_distance: int = 0,
        device: str | None = None
):
    if device is None:
        device = tokens.device

    batch_size, seq_length = tokens.shape

    def causal_mask(b, h, q_idx, kv_idx):
        return (q_idx + cache_size) >= kv_idx

    return create_block_mask(
        mask_mod=and_masks(causal_mask),#, padding_mask, word_mask, segment_mask),
        B=batch_size,
        H=None,
        Q_LEN=seq_length - cache_size,
        KV_LEN=seq_length,
        device=device,
        _compile=True,
        BLOCK_SIZE=128,
    )

@torch.no_grad()
def generate(
        seed: str,
        model: Generator,
        tokenizer: Tokenizer,
        *,
        device: str,
        max_length: int,
        amp_enabled: bool = True,
        progress_bar: bool = True,
) -> str:
    model.eval()

    model = model.to(device)
    tokenizer = tokenizer.to(device)

    inputs = tokenizer.encode(seed)
    tokens = inputs['tokens']
    segment_mask = inputs['segment_mask']
    specials_mask = inputs['specials_mask']

    past_keys = None
    past_values = None

    for _ in tqdm(
            range(len(tokens), max_length),
            total=max_length - len(tokens),
            desc="Generating",
            disable=not progress_bar
    ):
        with torch.amp.autocast(enabled=amp_enabled, device_type=device):
            outputs = model(
                # batch_size = 1
                tokens=tokens.unsqueeze(0),
                segment_mask=segment_mask.unsqueeze(0),
                specials_mask=specials_mask.unsqueeze(0),
                past_keys=past_keys,
                past_values=past_values,
                output_past=True
            )

        # update current state of masks and tokens

        predicted_token = torch.argmax(outputs.logits[0, -1])
        predicted_segment = torch.argmax(outputs.logits_segment[0, -1]).to(dtype=torch.bool)
        is_predicted_special = tokenizer.is_special[predicted_token.item()]

        tokens = torch.concatenate([tokens, predicted_token.view(1)], dim=0)
        segment_mask = torch.concatenate([segment_mask, predicted_segment.view(1)], dim=0)
        specials_mask = torch.concatenate([specials_mask, is_predicted_special.view(1)], dim=0)

        # update past keys and values

        past_keys = outputs.keys
        past_values = outputs.values

    return tokenizer.decode(tokens)

class DummyModel(Generator):
    def __init__(
            self,
            vocab_size: int = 128,
            n_layers: int = 36,
            head_dim: int = 128,
            n_heads: int = 64,
            pad_token: int = 127,
            device: str = 'cpu',
            attn_impl: str = 'flex',
    ):
        nn.Module.__init__(self)

        self.n_layers = n_layers
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.hidden_dim = self.n_heads * self.head_dim
        self.vocab_size = vocab_size
        self.pad_token = pad_token
        self.attn_impl = attn_impl

        self.emb = nn.Embedding(self.vocab_size, self.hidden_dim, device=device, dtype=torch.bfloat16)
        self.lm_head = nn.Linear(self.hidden_dim, self.vocab_size, device=device, dtype=torch.bfloat16)
        self.segment_head = nn.Linear(self.hidden_dim, 2, device=device, dtype=torch.bfloat16)

    def forward(
            self,
            tokens: torch.Tensor,
            segment_mask: torch.Tensor,
            specials_mask: torch.Tensor,
            past_keys: list[torch.Tensor] | None = None,
            past_values: list[torch.Tensor] | None = None,
            output_past: bool = False,
            **_,
    ):
        batch_size, seq_length = tokens.shape
        device = tokens.device
        assert device == self.emb.weight.device

        no_cache = (past_keys is None or not len(past_keys))
        cache_size = past_keys[0].shape[1] if not no_cache else 0

        hidden = self.emb(tokens[:, cache_size:])

        if no_cache:
            past_keys = [hidden.new_empty((batch_size, 0, self.hidden_dim)) for _ in range(self.n_layers)]
            past_values = [hidden.new_empty((batch_size, 0, self.hidden_dim)) for _ in range(self.n_layers)]

        new_len = seq_length - cache_size

        if self.attn_impl == 'flex':
            mask = create_mask(
                tokens, segment_mask, specials_mask,
                cache_size=cache_size,
                padding_token=self.pad_token
            )
        else:
            mask = None  # TODO

        for layer_idx in range(self.n_layers):
            hidden_k = torch.concatenate([hidden, past_keys[layer_idx]], dim=1)
            hidden_v = torch.concatenate([hidden, past_values[layer_idx]], dim=1)

            if output_past:
                # update cache
                past_keys[layer_idx] = hidden_k.detach()
                past_values[layer_idx] = hidden_v.detach()

            # calculate new hidden for next layer

            # (B, S, H) -> (B, n_heads, S, head_dim)
            hidden = hidden.view(batch_size, new_len, self.n_heads, self.head_dim).transpose(1, 2)
            hidden_k = hidden_k.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
            hidden_v = hidden_v.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)

            if self.attn_impl == 'flex':
                hidden = flex_attention(hidden, hidden_k, hidden_v, block_mask=mask)
            elif self.attn_impl == 'sdpa':
                hidden = sdpa(hidden, hidden_k, hidden_v, is_causal=True)
            hidden = hidden.transpose(1, 2).contiguous().view(batch_size, new_len, -1)

        return ModelOutput(
            # project from hidden space to logit space
            logits=self.lm_head(hidden),
            logits_segment=self.segment_head(hidden),
            keys=past_keys if output_past else None,
            values=past_values if output_past else None
        )

class Benchmark:
    def __enter__(self):
        self.start_time = time.time()
        torch.cuda.reset_peak_memory_stats(device=None)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        end_time = time.time()
        peak_memory = torch.cuda.max_memory_allocated(device=None) / (1024 ** 2)  # Convert bytes to MB
        time_taken = (end_time - self.start_time) * 1000  # Convert seconds to milliseconds

        print(f"Time taken: {time_taken:.2f} ms")
        print(f"Peak memory used: {peak_memory:.2f} MB")

if __name__ == '__main__':
    impl = 'sdpa'  # or 'flex'

    model = torch.compile(DummyModel(device='cuda', attn_impl=impl), dynamic=True)
    tokenizer = Tokenizer(device='cuda', segment='word+sentence')

    model.eval()

    print('\nTesting model...')

    print('Compilation warmup')

    test_str = 'O God, O God! ' * 100
    res = generate(test_str, model, tokenizer, max_length=len(test_str) + 100, device='cuda', amp_enabled=True)

    print('1000 tokens generation:')
    with Benchmark():
        generate(test_str, model, tokenizer, max_length=len(test_str) + 1000, device='cuda', amp_enabled=True)
    print()

    s = 'What is love??? '
    assert len(s) == 16
    batch_len = (128 // 16) * 20
    batch_size = 16

    print(f'{batch_size}x{batch_len * 16} batch processing (no backward):')
    with Benchmark():
        with torch.no_grad():
            model(**Collator(padding_token=tokenizer.pad_token, move_to='cuda')(
                [tokenizer.encode(s * batch_len, add_sink=False)] * batch_size
            ))
    print()

    print(f'{batch_size}x{batch_len * 16} batch processing (no backward):')
    with Benchmark():
        with torch.no_grad():
            model(**Collator(padding_token=tokenizer.pad_token, move_to='cuda')(
                [tokenizer.encode(s * batch_len, add_sink=False)] * batch_size
            ))
    print()

    print(f'{batch_size}x{batch_len * 16} batch processing (no backward):')
    with Benchmark():
        with torch.no_grad():
            model(**Collator(padding_token=tokenizer.pad_token, move_to='cuda')(
                [tokenizer.encode(s * batch_len, add_sink=False)] * batch_size
            ))
    print()

    print(f'{batch_size}x{batch_len * 16} batch processing (no backward):')
    with Benchmark():
        with torch.no_grad():
            model(**Collator(padding_token=tokenizer.pad_token, move_to='cuda')(
                [tokenizer.encode(s * batch_len, add_sink=False)] * batch_size
            ))
    print()

    print(f'{batch_size}x{batch_len * 16} batch processing (with backward):')
    with Benchmark():
        res = model(**Collator(padding_token=tokenizer.pad_token, move_to='cuda')(
            [tokenizer.encode(s * batch_len, add_sink=False)] * batch_size
        ))
        torch.sum(res.logits ** 2).backward()
    print()

    print(f'{batch_size}x{batch_len * 16} batch processing (with backward):')
    with Benchmark():
        res = model(**Collator(padding_token=tokenizer.pad_token, move_to='cuda')(
            [tokenizer.encode(s * batch_len, add_sink=False)] * batch_size
        ))
        torch.sum(res.logits ** 2).backward()
    print()

    print(f'{batch_size}x{batch_len * 16} batch processing (with backward):')
    with Benchmark():
        res = model(**Collator(padding_token=tokenizer.pad_token, move_to='cuda')(
            [tokenizer.encode(s * batch_len, add_sink=False)] * batch_size
        ))
        torch.sum(res.logits ** 2).backward()
    print()

    print(f'{batch_size}x{batch_len * 16} batch processing (with backward):')
    with Benchmark():
        res = model(**Collator(padding_token=tokenizer.pad_token, move_to='cuda')(
            [tokenizer.encode(s * batch_len, add_sink=False)] * batch_size
        ))
        torch.sum(res.logits ** 2).backward()
    print()

When I run it with impl = 'sdpa', I get no compilation errors, and here are the results:

Testing model...
Compilation warmup
Generating: 100%|██████████| 99/99 [00:26<00:00,  3.76it/s]
1000 tokens generation:
Generating: 100%|██████████| 999/999 [00:02<00:00, 333.34it/s]
Time taken: 3057.84 ms
Peak memory used: 5533.08 MB

16x2560 batch processing (no backward):
Time taken: 5929.39 ms
Peak memory used: 2629.27 MB

16x2560 batch processing (no backward):
Time taken: 3.93 ms
Peak memory used: 2629.27 MB

16x2560 batch processing (no backward):
Time taken: 204.47 ms
Peak memory used: 2629.27 MB

16x2560 batch processing (no backward):
Time taken: 203.93 ms
Peak memory used: 2629.27 MB

16x2560 batch processing (with backward):
Time taken: 10406.60 ms
Peak memory used: 28717.68 MB

16x2560 batch processing (with backward):
Time taken: 1472.95 ms
Peak memory used: 28721.72 MB

16x2560 batch processing (with backward):
Time taken: 6.67 ms
Peak memory used: 28721.72 MB

16x2560 batch processing (with backward):
Time taken: 840.58 ms
Peak memory used: 28721.72 MB

Here is the output for impl = 'flex':

Testing model...
Compilation warmup
Generating: 100%|██████████| 99/99 [00:28<00:00,  3.48it/s]
1000 tokens generation:
Generating: 100%|██████████| 999/999 [00:03<00:00, 265.92it/s]
Time taken: 3817.95 ms
Peak memory used: 5511.12 MB

16x2560 batch processing (no backward):
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299] WON'T CONVERT torch_dynamo_resume_in_forward_at_514 /mloscratch/homes/shcherba/dl-char-llm/reproduce.py line 514 
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299] due to: 
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299] Traceback (most recent call last):
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   File "/mloscratch/homes/shcherba/conda/envs/char-llm/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1230, in __call__
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     result = self._inner_convert(
...
(removed rows)
...
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   target: flex_attention
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[0]: TensorBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     ReinterpretView(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]         InputBuffer(name='arg5_1', layout=FixedLayout('cuda', torch.bfloat16, size=[s1, s12, 8192], stride=[8192*s12, 8192, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       ),
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       FixedLayout('cuda', torch.bfloat16, size=[s1, 64, s12, 128], stride=[8192*s12, 128, 8192, 1]),
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       origins=OrderedSet([permute])
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     )
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[1]: TensorBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     ReinterpretView(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]         InputBuffer(name='arg5_1', layout=FixedLayout('cuda', torch.bfloat16, size=[s1, s12, 8192], stride=[8192*s12, 8192, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       ),
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       FixedLayout('cuda', torch.bfloat16, size=[s1, 64, s12, 128], stride=[8192*s12, 128, 8192, 1]),
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       origins=OrderedSet([permute_1])
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     )
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[2]: TensorBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     ReinterpretView(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]         InputBuffer(name='arg5_1', layout=FixedLayout('cuda', torch.bfloat16, size=[s1, s12, 8192], stride=[8192*s12, 8192, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       ),
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       FixedLayout('cuda', torch.bfloat16, size=[s1, 64, s12, 128], stride=[8192*s12, 128, 8192, 1]),
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]       origins=OrderedSet([permute_2])
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     )
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[4]: (TensorBox(StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     InputBuffer(name='arg14_1', layout=FixedLayout('cuda', torch.int32, size=[s13, 1, s14], stride=[s14, s14, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )), TensorBox(StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     InputBuffer(name='arg19_1', layout=FixedLayout('cuda', torch.int32, size=[s16, 1, s17, s18], stride=[s17*s18, s17*s18, s18, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )), TensorBox(StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     InputBuffer(name='arg22_1', layout=FixedLayout('cuda', torch.int32, size=[s19, 1, s20], stride=[s20, s20, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )), TensorBox(StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     InputBuffer(name='arg26_1', layout=FixedLayout('cuda', torch.int32, size=[s21, 1, s22, s23], stride=[s22*s23, s22*s23, s23, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )), TensorBox(StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     InputBuffer(name='arg29_1', layout=FixedLayout('cuda', torch.int32, size=[s24, 1, s25], stride=[s25, s25, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )), TensorBox(StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     InputBuffer(name='arg33_1', layout=FixedLayout('cuda', torch.int32, size=[s26, 1, s27, s28], stride=[s27*s28, s27*s28, s28, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )), TensorBox(StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     InputBuffer(name='arg36_1', layout=FixedLayout('cuda', torch.int32, size=[s29, 1, s30], stride=[s30, s30, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )), TensorBox(StorageBox(
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]     InputBuffer(name='arg40_1', layout=FixedLayout('cuda', torch.int32, size=[s31, 1, s32, s33], stride=[s32*s33, s32*s33, s33, 1]))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   )), s34, s35, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[5]: 0.08838834764831843
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': False}
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[7]: ()
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299]   args[8]: (s15,)
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299] 
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299] Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
W1116 16:48:52.545000 55202 site-packages/torch/_dynamo/convert_frame.py:1299] 
Time taken: 12754.62 ms
Peak memory used: 2607.39 MB

16x2560 batch processing (no backward):
Time taken: 11.01 ms
Peak memory used: 2607.39 MB

16x2560 batch processing (no backward):
Time taken: 318.14 ms
Peak memory used: 2607.39 MB

16x2560 batch processing (no backward):
Time taken: 318.26 ms
Peak memory used: 2607.39 MB

16x2560 batch processing (with backward):
Time taken: 5393.32 ms
Peak memory used: 73441.23 MB

16x2560 batch processing (with backward):
Time taken: 765.57 ms
Peak memory used: 73443.23 MB

16x2560 batch processing (with backward):
Time taken: 1088.62 ms
Peak memory used: 73443.23 MB

16x2560 batch processing (with backward):
Time taken: 1082.83 ms
Peak memory used: 73443.23 MB
ViktorooReps commented 5 days ago

Solved by removing the _compile=True argument from the mask creation and dynamic=True from torch.compile. I don't really see speed improvements (especially in generation, which I suspect is due to mask recomputation), but I do see some minor memory usage improvements, which is my main goal for using FlexAttention.
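
Concretely, the two changes look like this (sketched against the script above):

# in create_mask(): build the BlockMask without _compile=True
return create_block_mask(
    mask_mod=and_masks(causal_mask),
    B=batch_size,
    H=None,
    Q_LEN=seq_length - cache_size,
    KV_LEN=seq_length,
    device=device,
    BLOCK_SIZE=128,
)

# in __main__: compile the model without dynamic=True
model = torch.compile(DummyModel(device='cuda', attn_impl=impl))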