pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

CUDA OOM with sliding window attention #15

Closed mishooax closed 1 month ago

mishooax commented 1 month ago

Hi - I'm trying to implement sliding window attention with flex-attention, as shown in the snippet below, inspired by the sliding window attention example in attention-gym. Note that I use a rather long sequence length (40320) but a small sliding window (512).

import torch

from torch.nn.attention.flex_attention import _mask_mod_signature
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def generate_sliding_window(window_size: int) -> _mask_mod_signature:
    def sliding_window(b, h, q_idx, kv_idx):
        del b, h # not used
        return torch.abs(q_idx - kv_idx) <= window_size // 2

    sliding_window_mask = sliding_window
    sliding_window_mask.__name__ = f"sliding_window_{window_size}"
    return sliding_window_mask

if __name__ == "__main__":

    B, H, SEQ_LEN, HEAD_DIM = 1, 16, 40320, 2
    WINDOW_SIZE = 512

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)

    q, k, v = make_tensor(), make_tensor(), make_tensor()

    sliding_window_mask = generate_sliding_window(window_size=WINDOW_SIZE)

    block_mask = create_block_mask(
        sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=False
    )
    out = flex_attention(q, k, v, block_mask=block_mask)
    print(f"Shape of output tensor: {list(out.shape)}")

Unfortunately, this OOMs:

  File "python3.11/site-packages/torch/_higher_order_ops/flex_attention.py", line 138, in _math_attention_inner
    scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)
              ~~~~~~^~~~~~~~~~~~~~~~~~~~~~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 96.90 GiB. GPU 0 has a total capacity of 39.56 GiB of which 26.49 GiB is free. Including non-PyTorch memory, this process has 13.06 GiB memory in use. Of the allocated memory 790.52 MiB is allocated by PyTorch, and 11.81 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Looks like flex-attention is trying to create a SEQ_LEN x SEQ_LEN matrix at the Q * K^T step. What am I doing wrong here? Flash attention's flash_attn_func can handle this sequence length just fine, as long as I pass in the correct window_size. Thank you!

drisspg commented 1 month ago

Ah, yeah, we need to better document this. There are two things going on here:

block_mask = create_block_mask(
    sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=False
)

You need to call this with _compile=True. We essentially evaluate your mask_mod over a full Q_LEN x KV_LEN grid in order to produce the block mask; without compile, we have to materialize that full grid, which can cause OOMs on long sequences.

You also need to run flex_attention = torch.compile(flex_attention). Without compile, flex_attention falls back to a non-fused eager implementation that is great for debugging, but it is much slower and materializes the full scores matrix.
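Concretely, the two changes to the snippet above look something like this (a sketch reusing the names from your script):

block_mask = create_block_mask(
    # 1) _compile=True so the mask_mod is evaluated lazily instead of
    #    materializing a full SEQ_LEN x SEQ_LEN tensor
    sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=True
)
# 2) compile flex_attention so it lowers to the fused kernel instead of the
#    eager fallback that builds the full scores matrix
flex_attention = torch.compile(flex_attention)
out = flex_attention(q, k, v, block_mask=block_mask)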

mishooax commented 1 month ago

Thanks @drisspg! This seems to work:

from functools import partial
import torch

from torch.nn.attention.flex_attention import _mask_mod_signature
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def generate_sliding_window(window_size: int) -> _mask_mod_signature:

    def sliding_window(b, h, q_idx, kv_idx):
        del b, h # not used
        return torch.abs(q_idx - kv_idx) <= window_size // 2

    sliding_window_mask = sliding_window
    sliding_window_mask.__name__ = f"sliding_window_{window_size}"
    return sliding_window_mask

if __name__ == "__main__":
    B, H, SEQ_LEN, HEAD_DIM = 1, 16, 40320, 32
    WINDOW_SIZE = 512

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)

    q, k, v = make_tensor(), make_tensor(), make_tensor()

    sliding_window_mask = generate_sliding_window(window_size=WINDOW_SIZE)

    block_mask = create_block_mask(
        sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=True
    )
    opt_flex_attention = torch.compile(partial(flex_attention, block_mask=block_mask))
    out = opt_flex_attention(q, k, v)  # block_mask is already bound via partial
    print(f"Shape of output tensor: {list(out.shape)}")

FWIW, I got shape-related compilation errors with my initial tensor sizes (HEAD_DIM = 2):

AssertionError: All non-batch values in both first input shape ([constexpr[128], constexpr[8]]) and second input shape ([constexpr[8], constexpr[64]]) must be >= 16!

The above exception was the direct cause of the following exception:

triton.compiler.errors.CompilationError: at 44:13:

    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    # loop over k, v and update accumulator
    for start_n in range(block_n_start, block_n_end):
        # -- load k --
        k = tl.load(K_block_ptr)
        # -- compute qk ---
        qk = tl.dot(q, k) # TODO: use cuda matmul when q_len <= 2.

Is this a known limitation of flex-attention (or the torch compiler)?

drisspg commented 1 month ago

Yeah, I think this might be fixed in a later Triton version than the one we are currently shipping in PyTorch today; let me create a repro for tracking, though.

foreverpiano commented 1 month ago

https://github.com/pytorch/pytorch/issues/133321 @drisspg

drisspg commented 1 month ago

@foreverpiano I am going to close this and we can use the pytorch issue for tracking

fteufel commented 1 month ago

Hi @drisspg ,

are there any pointers for how to structure this so it works inside an nn.Module?

Using

    sliding_window_mask = generate_sliding_window(window_size=WINDOW_SIZE) # defined this outside of nn.Module

    block_mask = create_block_mask(
        sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=True
    )

inside forward() throws compilation errors when I run torch.compile(module):

File "/home/mytransformer.py", line 226, in torch_dynamo_resume_in_forward_at_211
    sliding_window = generate_sliding_window(self.window_size)
  File "/home/mytransformer.py", line 232, in torch_dynamo_resume_in_forward_at_226
    sliding_window_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=q.shape[1], KV_LEN=kv_seq_len, device=device, _compile=True)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 800, in create_block_mask
    inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 802, in torch_dynamo_resume_in_create_block_mask_at_800
    block_mask = inner_func(
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1238, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
    return _compile(
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 902, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 653, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
    return fn(*args, **kwargs)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
    tracer.run()
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2731, in run
    super().run()
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 958, in run
    while self.step():
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 870, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2922, in RETURN_VALUE
    self._return(inst)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2907, in _return
    self.output.compile_subgraph(
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1134, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1361, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1408, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1457, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
SubprocException: An exception occurred in a subprocess:

Traceback (most recent call last):
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 270, in do_job
    result = job()
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/compile_tasks.py", line 68, in _worker_compile_triton
    load_kernel().precompile(warm_cache_only=True)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 234, in precompile
    compiled_binary, launcher = self._precompile_config(
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 418, in _precompile_config
    triton.compile(*compile_args, **compile_kwargs),
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 317, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
  File "/home/.conda/envs/debug/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 189, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

Thanks!
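For reference, a minimal sketch (an assumption, not a reply from this thread) of one way to keep create_block_mask out of the compiled region: build the BlockMask once, outside forward(), and pass it in as an argument, so torch.compile(module) only traces the flex_attention call itself. The module name below is hypothetical and only for illustration.

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

class SlidingWindowSelfAttention(torch.nn.Module):  # hypothetical module, illustration only
    def forward(self, q, k, v, block_mask):
        # The BlockMask is built once per sequence length outside forward(),
        # so the compiled graph only contains the flex_attention call.
        return flex_attention(q, k, v, block_mask=block_mask)

B, H, SEQ_LEN, HEAD_DIM, WINDOW_SIZE = 1, 16, 40320, 32, 512

def sliding_window(b, h, q_idx, kv_idx):
    del b, h  # not used
    return torch.abs(q_idx - kv_idx) <= WINDOW_SIZE // 2

block_mask = create_block_mask(
    sliding_window, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=True
)

attn = torch.compile(SlidingWindowSelfAttention())
q = k = v = torch.ones(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)
out = attn(q, k, v, block_mask)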