Ah, yeah, we need to document this better. There are two things going on here:
block_mask = create_block_mask(
    sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=False
)
You need to call this with _compile=True. We essentially evaluate your mask function over the full Q_LEN x KV_LEN matrix in order to produce the block mask. Without compilation, we need to materialize that full matrix, which can cause OOMs on long sequences.
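For intuition, here is a rough sketch (my illustration, not create_block_mask's actual internals) of why the uncompiled path runs out of memory: a dense evaluation of the sliding-window predicate implies a full Q_LEN x KV_LEN boolean tensor before any block-level reduction, roughly 1.5 GB of bools at these sizes:

import torch

# Rough illustration only: densely evaluating the predicate at SEQ_LEN = 40320
# allocates a 40320 x 40320 boolean tensor (~1.5 GB) before it is reduced to
# block granularity.
SEQ_LEN, WINDOW_SIZE = 40320, 512
q_idx = torch.arange(SEQ_LEN, device="cuda")[:, None]
kv_idx = torch.arange(SEQ_LEN, device="cuda")[None, :]
dense_mask = (q_idx - kv_idx).abs() <= WINDOW_SIZE // 2  # [SEQ_LEN, SEQ_LEN] bools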
You also need to run flex_attention = torch.compile(flex_attention). Without compilation, flex falls back to a non-fused eager implementation that is great for debugging, but it is much slower and materializes the full scores matrix.
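Roughly speaking (a sketch of the idea, not the actual fallback code), the eager path has to do something like the following, which is why its memory scales with Q_LEN * KV_LEN:

import torch

# Sketch of a non-fused attention computation: the full [B, H, Q_LEN, KV_LEN]
# scores matrix is materialized before the softmax, unlike the fused kernel.
def eager_attention_sketch(q, k, v, dense_mask, scale):
    scores = q @ k.transpose(-2, -1) * scale
    scores = scores.masked_fill(~dense_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v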
Thanks @drisspg! This seems to work:
from functools import partial

import torch
from torch.nn.attention.flex_attention import _mask_mod_signature
from torch.nn.attention.flex_attention import flex_attention, create_block_mask


def generate_sliding_window(window_size: int) -> _mask_mod_signature:
    def sliding_window(b, h, q_idx, kv_idx):
        del b, h  # not used
        return torch.abs(q_idx - kv_idx) <= window_size // 2

    sliding_window_mask = sliding_window
    sliding_window_mask.__name__ = f"sliding_window_{window_size}"
    return sliding_window_mask


if __name__ == "__main__":
    B, H, SEQ_LEN, HEAD_DIM = 1, 16, 40320, 32
    WINDOW_SIZE = 512

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)

    q, k, v = make_tensor(), make_tensor(), make_tensor()

    sliding_window_mask = generate_sliding_window(window_size=WINDOW_SIZE)
    # _compile=True avoids materializing the full SEQ_LEN x SEQ_LEN mask
    block_mask = create_block_mask(
        sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=True
    )

    # compile flex_attention so it runs as a fused kernel; the block mask is
    # already bound via partial, so it is not passed again at call time
    opt_flex_attention = torch.compile(partial(flex_attention, block_mask=block_mask))
    out = opt_flex_attention(q, k, v)
    print(f"Shape of output tensor: {list(out.shape)}")
FWIW, I got shape-related compilation errors with my initial tensor sizes:
AssertionError: All non-batch values in both first input shape ([constexpr[128], constexpr[8]]) and second input shape ([constexpr[8], constexpr[64]]) must be >= 16!

The above exception was the direct cause of the following exception:

triton.compiler.errors.CompilationError: at 44:13:
    RCP_LN2: tl.constexpr = 1.44269504
    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    # loop over k, v and update accumulator
    for start_n in range(block_n_start, block_n_end):
        # -- load k --
        k = tl.load(K_block_ptr)
        # -- compute qk ---
        qk = tl.dot(q, k) # TODO: use cuda matmul when q_len <= 2.
Is this a known limitation of flex-attention (or the torch compiler)?
Yeah, I think this might be fixed in a later Triton version than the one we are currently shipping in PyTorch today; let me create an issue for tracking.
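(In the meantime, one possible workaround sketch, assuming the failure comes from a head dim smaller than 16, which the [128, 8] / [8, 64] shapes in the assertion suggest: zero-pad the head dimension up to 16 and slice the output back. flex_with_small_head_dim is a hypothetical helper, not a library API.)

import torch.nn.functional as F

def flex_with_small_head_dim(q, k, v, block_mask, compiled_flex):
    # Zero-padding q/k/v along head_dim adds nothing to q @ k^T, and the padded
    # output channels come back as zeros, so slicing recovers the exact result.
    d = q.shape[-1]
    qp, kp, vp = (F.pad(t, (0, 16 - d)) for t in (q, k, v))
    # Pass the original softmax scale; the default would use the padded dim.
    return compiled_flex(qp, kp, vp, block_mask=block_mask, scale=d**-0.5)[..., :d]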
@foreverpiano I am going to close this, and we can use the PyTorch issue for tracking.
Hi @drisspg, are there any pointers for how we need to code this to make it work inside an nn.Module? Using

sliding_window_mask = generate_sliding_window(window_size=WINDOW_SIZE)  # defined outside the nn.Module
block_mask = create_block_mask(
    sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=True
)

inside forward() throws compilation errors when I run torch.compile(module):
File "/home/mytransformer.py", line 226, in torch_dynamo_resume_in_forward_at_211
sliding_window = generate_sliding_window(self.window_size)
File "/home/mytransformer.py", line 232, in torch_dynamo_resume_in_forward_at_226
sliding_window_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=q.shape[1], KV_LEN=kv_seq_len, device=device, _compile=True)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 800, in create_block_mask
inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 802, in torch_dynamo_resume_in_create_block_mask_at_800
block_mask = inner_func(
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1238, in __call__
return self._torchdynamo_orig_callable(
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
return _compile(
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 902, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 653, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
out_code = transform_code_object(code, transform)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
return fn(*args, **kwargs)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
tracer.run()
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2731, in run
super().run()
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 958, in run
while self.step():
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 870, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2922, in RETURN_VALUE
self._return(inst)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2907, in _return
self.output.compile_subgraph(
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1134, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1361, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1408, in call_user_compiler
return self._call_user_compiler(gm)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1457, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
SubprocException: An exception occurred in a subprocess:
Traceback (most recent call last):
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 270, in do_job
result = job()
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/compile_tasks.py", line 68, in _worker_compile_triton
load_kernel().precompile(warm_cache_only=True)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 234, in precompile
compiled_binary, launcher = self._precompile_config(
File "/home/.conda/envs/debug/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 418, in _precompile_config
triton.compile(*compile_args, **compile_kwargs),
File "/home/.conda/envs/debug/lib/python3.10/site-packages/triton/compiler/compiler.py", line 282, in compile
next_module = compile_ir(module, metadata)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 317, in <lambda>
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
File "/home/.conda/envs/debug/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 189, in make_ttgir
pm.run(mod)
RuntimeError: PassManager::run failed
Thanks!
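For what it's worth, one pattern that sidesteps this (a sketch, assuming a fixed sequence length known at construction time) is to build the BlockMask once, eagerly, outside forward(), so torch.compile never has to trace create_block_mask itself:

import torch
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

class SlidingWindowAttention(nn.Module):
    # Hypothetical module for illustration: the BlockMask is created eagerly
    # in __init__, and only flex_attention runs inside the compiled forward().
    def __init__(self, window_size: int, seq_len: int):
        super().__init__()

        def sliding_window(b, h, q_idx, kv_idx):
            return torch.abs(q_idx - kv_idx) <= window_size // 2

        self.block_mask = create_block_mask(
            sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True
        )

    def forward(self, q, k, v):
        return flex_attention(q, k, v, block_mask=self.block_mask)

# module = torch.compile(SlidingWindowAttention(window_size=512, seq_len=40320))

If the sequence length varies between calls, the BlockMask could instead be built (or cached) outside the compiled region and passed into forward() as an argument.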
Hi - I'm trying to implement sliding window attention with flex-attention, as described in the snippet below, inspired by the sliding window attention example in attention-gym. Note that I use a rather long sequence length (40320) but a small sliding window (512).
Unfortunately, this OOMs: it looks like flex-attention is trying to create a SEQ_LEN x SEQ_LEN matrix at the Q * K^T step. What am I doing wrong here? Flash attention's flash_attn_func can handle this sequence length just fine, as long as I pass in the correct window_size. Thank you!
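For reference, the flash-attn call being compared against would look something like this (a sketch assuming the flash_attn package; note that flash_attn_func takes [B, S, H, D] tensors rather than flex's [B, H, S, D], and window_size=(left, right) counts tokens on each side of the query):

from flash_attn import flash_attn_func

# 512-token window centered on each query: |q_idx - kv_idx| <= 256,
# i.e. window_size=(256, 256) in flash-attn's (left, right) convention.
q_f, k_f, v_f = (t.transpose(1, 2) for t in (q, k, v))  # -> [B, SEQ_LEN, H, HEAD_DIM]
out_fa = flash_attn_func(q_f, k_f, v_f, window_size=(256, 256))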