pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org
Other
84.41k stars 22.73k forks source link

flex attention compile fails #139548

Open clessig opened 3 weeks ago

clessig commented 3 weeks ago

🐛 Describe the bug

The following code generates the compile error below:

import code
import time

import warnings
import numpy as np
import torch
from torch.nn.attention.flex_attention import flex_attention, create_mask, create_block_mask

import astropy_healpix as hp

# Problem size. The cell count is derived from the refinement level `hlc`
# as 12 * 4**hlc and is reported as the sequence length.
hlc = 4
num_healpix_cells = 12 * 4**hlc
print(f'seq_length : {num_healpix_cells}')
num_heads, dim_embed, bs = 8, 64, 4

# All-ones half-precision query/key/value tensors on the GPU, laid out as
# (batch, heads, sequence, embedding).
qkv_shape = (bs, num_heads, num_healpix_cells, dim_embed)
q, k, v = (torch.ones(qkv_shape, dtype=torch.float16, device='cuda') for _ in range(3))

# Neighbour table of the nested grid, transposed so that row i lists the
# neighbour indices of cell i. Warnings raised by the lookup are suppressed.
with warnings.catch_warnings(action="ignore"):
    nbours = hp.neighbours(np.arange(num_healpix_cells), 2**hlc, order='nested').transpose()

# Build the boolean adjacency matrix: entry [i, j] is True iff j is listed as
# a neighbour of cell i.
# Negative entries mark a missing neighbour and must be skipped entirely:
# writing to nbours_mat[i, -1] would wrap around to the last column and could
# erase a genuine adjacency to the last cell.
nbours_idx = torch.as_tensor(np.ascontiguousarray(nbours), dtype=torch.int64, device='cuda')
has_nbour = nbours_idx >= 0
# Row index of every entry in the neighbour table, same shape as nbours_idx.
cell_idx = torch.arange(num_healpix_cells, device='cuda').unsqueeze(1).expand_as(nbours_idx)
nbours_mat = torch.zeros((num_healpix_cells, num_healpix_cells), dtype=torch.bool, device='cuda')
# One vectorised scatter instead of one scalar device write per entry.
nbours_mat[cell_idx[has_nbour], nbours_idx[has_nbour]] = True
hp_adjacency = nbours_mat

# Token embeddings; the real data is replaced by an all-ones placeholder.
# tc_tokens = torch.from_numpy( np.load( 'tc_tokens.npy')).to(torch.float16).to('cuda')
tc_tokens = torch.ones([204458, 256], dtype=torch.float16, device='cuda')
# Token counts loaded from 'tcs_lens.npy'.
# NOTE(review): assumed to be 1-D with one entry per cell and to sum to
# tc_tokens.shape[0] -- confirm against the data file.
tcs_lens = torch.from_numpy(np.load('tcs_lens.npy')).to(torch.int32).to('cuda')
print(f'tc_tokens = {tc_tokens.shape}')
print(f'tcs_lens = {tcs_lens.shape}')

# For every token, the index of the tcs_lens entry it belongs to
# (entry i is repeated tcs_lens[i] times).
# The index is created on the same device as tcs_lens so it can be indexed
# with CUDA positions and used to index the CUDA adjacency matrix; building
# it from default-device torch.ones pieces would leave it on the CPU.
tc_tokens_cell_idx = torch.repeat_interleave(
    torch.arange(tcs_lens.shape[0], dtype=torch.int64, device=tcs_lens.device),
    tcs_lens.to(torch.int64),
)
def sparsity_mask(score, b, h, q_idx, kv_idx):
    """Score modification that restricts attention to adjacent cells.

    Each query/key position is mapped to a cell through
    ``tc_tokens_cell_idx`` and the pair of cells is looked up in
    ``hp_adjacency``.

    Args:
        score: attention score for the (q_idx, kv_idx) pair.
        b: batch index (unused).
        h: head index (unused).
        q_idx: query position.
        kv_idx: key/value position.

    Returns:
        ``score`` where the two cells are marked adjacent, ``-inf`` otherwise.
    """
    allowed = hp_adjacency[tc_tokens_cell_idx[q_idx], tc_tokens_cell_idx[kv_idx]]
    # A score_mod has to return a score, not a boolean mask: masked-out pairs
    # are pushed to -inf so they receive zero weight after the softmax.
    return torch.where(allowed, score, -float("inf"))

# Compile flex_attention once, with dynamic shapes disabled.
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

# Poor man's head projection: keep the first 64 features of every token and
# prepend singleton batch and head dimensions.
toks = tc_tokens[None, None, :, :64]
# The same tensor serves as query, key and value.
out = compiled_flex_attention(toks, toks, toks, score_mod=sparsity_mask)

Required input can be found here: https://cloud.ovgu.de/s/355z3P6ySK4WorB (tcs_lens.npy)

Error logs

Traceback (most recent call last): File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler compiled_fn = compiler_fn(gm, self.example_inputs()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in call compiled_gm = compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/init.py", line 2234, in call return compilefx(model, inputs_, config_patches=self.config) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx return aot_autograd( ^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in call cg = aot_module_simplified(gm, example_inputs, self.kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified compiled_fn = dispatch_and_compile() ^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile compiledfn, = create_aot_dispatcher_function( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function return _create_aot_dispatcher_function( 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function compiled_fn, fw_metadata = compiler_fn( ^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base compiled_fw = compiler(fw_module, updated_flat_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base return _fw_compiler_base(model, example_inputs, is_inference) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base return inner_compile( ^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper inner_compiled_fn = compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner compiled_graph = FxGraphCache.load( ^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1370, in load 
compiled_graph = compile_fx_fn( ^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile compiled_graph = fx_codegen_and_compile(gm, example_inputs, fx_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile graph.run(example_inputs) File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run return super().run(args) ^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run self.env[node] = self.run_node(node) ^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node result = super().run_node(n) ^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node return getattr(self, n.op)(n.target, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function raise LoweringException(e, target, args, kwargs).with_traceback( File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function out = lowerings[target](*args, *kwargs) # type: ignore[index] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File 
"/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped out = decomp_fn(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py", line 849, in flex_attention flex_attention_template.maybe_append_choice( File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/codegen/common.py", line 2158, in maybe_append_choice choices.append(self.generate(kwargs)) ^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 676, in generate template = kernel.render(self.template, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 484, in render template.render(self.template_env(), kwargs), ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/jinja2/environment.py", line 1304, in render self.environment.handle_exception() File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/jinja2/environment.py", line 939, in handle_exception raise rewrite_traceback_stack(source=source) File "