ELS-RD / kernl

Kernl lets you run PyTorch transformer models several times faster on GPU with a single line of code, and is designed to be easily hackable.
http://www.kernl.ai
Apache License 2.0
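
For context, the "single line of code" mentioned above is kernl's model optimization entry point. A minimal usage sketch; the import path and helper name are assumptions taken from the docs, see http://www.kernl.ai for the exact, up-to-date API:

import torch
from transformers import AutoModel
from kernl.model_optimization import optimize_model  # assumed entry point

model = AutoModel.from_pretrained("t5-small").eval().cuda()
optimize_model(model)  # the "single line": replaces supported patterns with fused Triton kernels (via TorchDynamo)

with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
    ...  # run inference as usual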

Fix crash on T5 inference for some specific shapes #176

Closed pommedeterresautee closed 1 year ago

pommedeterresautee commented 1 year ago

When running the tests on main, there is a crash. We have had other issues with short sequence lengths and large batch sizes on T5; it is not clear why...

❯ pytest test/test_torchdynamo.py -k "dynamo_optimized_cuda_graphs-8x16-t5"
===================================================================================================== test session starts =====================================================================================================
platform linux -- Python 3.9.15, pytest-7.2.0, pluggy-1.0.0
rootdir: /mnt/workspace/kernl
plugins: anyio-3.6.1
collected 199 items / 198 deselected / 1 selected                                                                                                                                                                             

test/test_torchdynamo.py F                                                                                                                                                                                              [100%]

========================================================================================================== FAILURES ===========================================================================================================
_________________________________________________________________________ test_benchmark_implementations[dynamo_optimized_cuda_graphs-8x16-t5-small] __________________________________________________________________________

C = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0408f5a860>, ACT_INPUTS = None
A = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0408f5a6d0>
B = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Parameter object at 0x7f03f6968ea0>
bias = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0408f750e0>, M = 128, N = 2048, K = 512, CACHE_KEY_M = 4, CACHE_KEY_N = 64, CACHE_KEY_K = 16, stride_om = 2048
stride_on = 1, stride_im = 512, stride_ik = 1, stride_wn = 512, stride_wk = 1, BLOCK_M = 128, GROUP_M = 8, BLOCK_N = 64, BLOCK_K = 32, SPLIT_K = 1, EVEN_K = True, BIAS = False, SAVE_ACT_INPUTS = False, ACTIVATION = 'relu'
grid = (32,), num_warps = 4, num_stages = 4, extern_libs = None, stream = 2040603040, warmup = False

>   ???
E   KeyError: ('2-.-0-.-0-d6e5675c89b63c389326c8b846421ab2-7929002797455b30efce6e41eddc6b57-3aa563e00c5c695dd945e23b09a86848-f24b6aa9b101a518b6a4a6bddded372e-ff946bd4b3b4a4cbdf8cedc6e1c658e0-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, None, torch.float16, torch.float16, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 8, 64, 32, 1, True, False, False, 'relu'), (True, (False,), True, True, True, (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (False, True), (True, False), (False, True), (True, False), (False, True)))

<string>:21: KeyError

During handling of the above exception, another exception occurred:

self = <torch._dynamo.output_graph.OutputGraph object at 0x7f0412c54220>
gm = GraphModule(
  (base_encoder_embed_tokens): Embedding(32128, 512)
  (base_encoder_block_0_layer_0_SelfAttention_q): Li...s=False)
  (base_decoder_block_5_layer__1__DenseReluDense_wo): Linear(in_features=2048, out_features=512, bias=False)
)

    def call_user_compiler(self, gm):
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
>           compiled_fn = self.compiler_fn(gm, self.example_inputs())

/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:433: 
...
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <torch._dynamo.output_graph.OutputGraph object at 0x7f0412c54220>
gm = GraphModule(
  (base_encoder_embed_tokens): Embedding(32128, 512)
  (base_encoder_block_0_layer_0_SelfAttention_q): Li...s=False)
  (base_decoder_block_5_layer__1__DenseReluDense_wo): Linear(in_features=2048, out_features=512, bias=False)
)

    def call_user_compiler(self, gm):
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiled_fn = self.compiler_fn(gm, self.example_inputs())
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except Exception as e:
            log.warning("-" * 40 + "\n")
            log.warning("TORCHDYNAMO: backend compiler failed\n")
            log.warning(e, exc_info=True)
            log.warning("-" * 40 + "\n")
            compiled_fn = gm.forward
>           raise BackendCompilerFailed(self.compiler_fn, e) from e
E           torch._dynamo.exc.BackendCompilerFailed: compiler raised RuntimeError: CUDA: Error- illegal address
E           
E           You can suppress this exception and fall back to eager by setting:
E               torchdynamo.config.suppress_errors = True

/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:442: BackendCompilerFailed

During handling of the above exception, another exception occurred:

benchmark = <kernl.benchmark.benchmark_fixture.BenchmarkFixture object at 0x7f0413b4d3a0>
reference_fp32 = T5Model(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block)...     )
        )
      )
    )
    (final_layer_norm): T5LayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
shape = (8, 16), implementation = Implementation(name='dynamo_optimized_cuda_graphs', model=<function get_model_optimized_cuda_graphs at 0x7f0414304790>, is_causal=False)

    @setup_dynamo()
    @set_seed()
    @pytest.mark.parametrize(
        "reference_fp32",
        ["bert-base-uncased", "t5-small"],
        indirect=True,
    )
    @pytest.mark.parametrize(
        "shape",
        # TODO add shape 32x32 which may be unstable with T5
        [(bs, seq_l) for bs in [1, 8, 32] for seq_l in [16, 33, 128, 256, 384, 512] if bs * seq_l < 10000],
        ids=lambda x: f"{x[0]}x{x[1]}",
    )
    @pytest.mark.parametrize("implementation", implementations, ids=lambda v: v.name)
    def test_benchmark_implementations(benchmark, reference_fp32, shape: (int, int), implementation: Implementation):
        if "nvfuser" in implementation.name and reference_fp32.config.name_or_path != "bert-base-uncased":
            pytest.skip("Only supported for BERT")

        inputs = get_input(reference_fp32, shape, is_causal=implementation.is_causal)
        with torch.inference_mode():
            expected = reference_fp32(**inputs)
            model = implementation.model(reference_fp32)
            with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
>               value = benchmark(model, **inputs)

test/test_torchdynamo.py:85: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/kernl/benchmark/benchmark_fixture.py:53: in __call__
    function_to_benchmark(*args, **kwargs)
/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:166: in _fn
    return fn(*args, **kwargs)
/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:249: in catch_errors
    return callback(frame, cache_size)
/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:452: in _convert_frame
    result = inner_convert(frame, cache_size)
/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:123: in _fn
    torch.cuda.set_rng_state(cuda_rng_state)
/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/cuda/random.py:64: in set_rng_state
    _lazy_call(cb)
/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/cuda/__init__.py:176: in _lazy_call
    callable()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    def cb():
        idx = cast(torch.device, device).index
        if idx is None:
            idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
>       default_generator.set_state(new_state_copy)
E       RuntimeError: false INTERNAL ASSERT FAILED at "../c10/cuda/CUDAGraphsC10Utils.h":73, please report a bug to PyTorch. Unknown CUDA graph CaptureStatus32516

/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/cuda/random.py:62: RuntimeError
------------------------------------------------------------------------------------------------------ Captured log call ------------------------------------------------------------------------------------------------------
WARNING  torch._dynamo.output_graph:output_graph.py:437 ----------------------------------------

WARNING  torch._dynamo.output_graph:output_graph.py:438 TORCHDYNAMO: backend compiler failed

WARNING  torch._dynamo.output_graph:output_graph.py:439 CUDA: Error- illegal address
Traceback (most recent call last):
  File "<string>", line 21, in kernel_fma
KeyError: ('2-.-0-.-0-d6e5675c89b63c389326c8b846421ab2-7929002797455b30efce6e41eddc6b57-3aa563e00c5c695dd945e23b09a86848-f24b6aa9b101a518b6a4a6bddded372e-ff946bd4b3b4a4cbdf8cedc6e1c658e0-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, None, torch.float16, torch.float16, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 8, 64, 32, 1, True, False, False, 'relu'), (True, (False,), True, True, True, (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (False, True), (True, False), (False, True), (True, False), (False, True)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 433, in call_user_compiler
    compiled_fn = self.compiler_fn(gm, self.example_inputs())
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 819, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/mnt/workspace/kernl/test/models/bert.py", line 75, in compiler
    return cuda_graphs_wrapper(gm, example_inputs)
  File "/mnt/workspace/kernl/src/kernl/implementations/cuda_graph.py", line 40, in cuda_graphs_wrapper
    model(*inputs)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/fx/graph_module.py", line 660, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/fx/graph_module.py", line 279, in __call__
    raise e
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/fx/graph_module.py", line 269, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "<eval_with_key>.165", line 64, in forward
    linear_wrapper = kernl_optimizer_linear_linear_wrapper(layer_norm_rms_wrapper_1, linear, activation = 'relu');  layer_norm_rms_wrapper_1 = linear = None
  File "/mnt/workspace/kernl/src/kernl/optimizer/linear.py", line 31, in linear_wrapper
    return linear_layer(v, linear.weight, linear.bias, activation=activation)
  File "/mnt/workspace/kernl/src/kernl/implementations/linear_layer.py", line 284, in linear_layer
    return LinearLayer.apply(x, weight, bias, activation, act_inputs)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py", line 104, in decorate_fwd
    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
  File "/mnt/workspace/kernl/src/kernl/implementations/linear_layer.py", line 248, in forward
    kernel_fma[grid](
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 73, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 73, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 63, in _bench
    return do_bench(kernel_call)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/testing.py", line 140, in do_bench
    fn()
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 62, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 200, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 41, in kernel_fma
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/compiler.py", line 1268, in compile
    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/compiler.py", line 1301, in __init__
    mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
RuntimeError: CUDA: Error- illegal address
WARNING  torch._dynamo.output_graph:output_graph.py:440 ----------------------------------------

ERROR    torch._dynamo.convert_frame:convert_frame.py:252 WON'T CONVERT run /mnt/workspace/kernl/test/models/bert.py line 77 
due to: 
Traceback (most recent call last):
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/triton/compiler.py", line 1301, in __init__
    mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
RuntimeError: CUDA: Error- illegal address

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 442, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: compiler raised RuntimeError: CUDA: Error- illegal address

You can suppress this exception and fall back to eager by setting:
    torchdynamo.config.suppress_errors = True

from user code:
   File "/mnt/workspace/kernl/test/models/bert.py", line 79, in run
    return base(*args, **kwargs)

Set torch._dynamo.config.verbose=True for more information
==========
ERROR    torch._dynamo.eval_frame:eval_frame.py:251 Error while processing frame
Traceback (most recent call last):
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 87, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 325, in _convert_frame_assert
    return _compile(
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 380, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 368, in transform
    tracer.run()
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1448, in run
    super().run()
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 349, in run
    and self.step()
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 322, in step
    getattr(self, inst.opname)(inst)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1510, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 355, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 401, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 442, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: compiler raised RuntimeError: CUDA: Error- illegal address

You can suppress this exception and fall back to eager by setting:
    torchdynamo.config.suppress_errors = True

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 249, in catch_errors
    return callback(frame, cache_size)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 452, in _convert_frame
    result = inner_convert(frame, cache_size)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 123, in _fn
    torch.cuda.set_rng_state(cuda_rng_state)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/cuda/random.py", line 64, in set_rng_state
    _lazy_call(cb)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/cuda/__init__.py", line 176, in _lazy_call
    callable()
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/cuda/random.py", line 62, in cb
    default_generator.set_state(new_state_copy)
RuntimeError: false INTERNAL ASSERT FAILED at "../c10/cuda/CUDAGraphsC10Utils.h":73, please report a bug to PyTorch. Unknown CUDA graph CaptureStatus32516
=================================================================================================== short test summary info ===================================================================================================
FAILED test/test_torchdynamo.py::test_benchmark_implementations[dynamo_optimized_cuda_graphs-8x16-t5-small] - RuntimeError: false INTERNAL ASSERT FAILED at "../c10/cuda/CUDAGraphsC10Utils.h":73, please report a bug to PyTorch. Unknown CUDA graph CaptureStatus32516
============================================================================================= 1 failed, 198 deselected in 17.24s ==============================================================================================
pommedeterresautee commented 1 year ago

The error says it is related to CUDA addresses, and it only reproduces with CUDA graphs. If the mask is hard-coded to None, the crash does not happen.

# for 8x16 on t5 small
q.shape=torch.Size([8, 8, 16, 64])
k.shape=torch.Size([8, 8, 16, 64])
v.shape=torch.Size([8, 8, 16, 64])
attention_mask.shape=torch.Size([8, 8, 16, 16])

The attention kernels are not tested with CUDA Graphs for now.

Failing shapes: 8x16 and 1x32 (T5 with CUDA graphs).

All other shapes from the unit tests do not crash on T5, and no test fails for BERT, including the 32x32 shape (reintroduced for the occasion).

# without CUDA graphs: OK
CUDA_LAUNCH_BLOCKING=1 pytest test/test_torchdynamo.py -k "dynamo_optimized and not cuda_graphs and t5 and (8x16 or 1x32)"
# BERT with CUDA graphs: OK
CUDA_LAUNCH_BLOCKING=1 pytest test/test_torchdynamo.py -k "dynamo_optimized_cuda_graphs and bert"
# all T5 shapes with CUDA graphs except the 2 above: OK
CUDA_LAUNCH_BLOCKING=1 pytest test/test_torchdynamo.py -k "dynamo_optimized_cuda_graphs and t5 and not 8x16 and not 1x32"
pommedeterresautee commented 1 year ago

Somewhere in the middle of the error trace there is:

heads = 8, size_m = 32, size_n = 32, size_m_rounded = 128, Q = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c1615b040>
K = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c1aa512c0>
V = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c1aa572c0>, sm_scale = 1.0
attention_mask = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c162645e0>
TMP = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c162644a0>
output = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c1aa57d60>, q_batch_stride = 16384, q_head_stride = 64, q_m_stride = 512, q_k_stride = 1, k_batch_stride = 16384
k_head_stride = 64, k_n_stride = 512, k_k_stride = 1, v_batch_stride = 16384, v_head_stride = 64, v_k_stride = 512, v_n_stride = 1, o_batch_stride = 16384, o_head_stride = 64, o_m_stride = 512, o_n_stride = 1
attention_mask_batch_stride = 8, attention_mask_head_stride = 1, attention_mask_m_stride = 256, attention_mask_k_stride = 8, min_clamp_value = -65504.0, NEED_LOAD_MASK_SIZE_N = True, NEED_LOAD_MASK_SIZE_M = True
MASK_BATCH_SIZE = 1, MASK_HEAD_SIZE = 8, MASK_M_SIZE = 32, MASK_K_SIZE = 32, HAS_MASK = True, IS_CAUSAL = False, BLOCK_M = 128, BLOCK_DHEAD = 64, BLOCK_N = 128, grid = (1, 8), num_warps = 4, num_stages = 1
extern_libs = None, stream = 2050343696, warmup = False

Printing the values shows:

q.shape=torch.Size([1, 8, 32, 64]), k.shape=torch.Size([1, 8, 32, 64]), v.shape=torch.Size([1, 8, 32, 64]), output.shape=torch.Size([1, 8, 32, 64])
q.stride()=(16384, 64, 512, 1), k.stride()=(16384, 64, 512, 1), v.stride()=(16384, 64, 512, 1), output.stride()=(16384, 64, 512, 1)
q.dtype=torch.float16, k.dtype=torch.float16, v.dtype=torch.float16, output.dtype=torch.float16
attention_mask.shape=torch.Size([1, 8, 32, 32])
attention_mask.stride()=(8, 1, 256, 8)
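
Side note on how these strides typically arise (illustration only, not kernl code): the q/k/v strides above are exactly what a transposed view of a contiguous (batch, seq_len, heads, d_head) tensor looks like, i.e. the tensors handed to the kernel are not contiguous. The attention_mask strides (8, 1, 256, 8) are likewise non-contiguous.

import torch

x = torch.empty(1, 32, 8, 64, dtype=torch.float16)  # (batch, seq, heads, d_head), contiguous
q = x.transpose(1, 2)                                # viewed as (batch, heads, seq, d_head)

print(q.shape)            # torch.Size([1, 8, 32, 64])
print(q.stride())         # (16384, 64, 512, 1) -- same strides as printed above
print(q.is_contiguous())  # False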

If we comment out the following line, the code no longer crashes but the output is wrong:

# reminder: attention_mask_m_stride=attention_mask.stride(2) if HAS_MASK else 0,
offs_mask += offs_m[:, None] * attention_mask_m_stride
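
Back-of-the-envelope reading of the numbers (my interpretation, not a confirmed root cause; assuming the mask storage holds exactly shape.numel() = 8192 elements and that offs_m spans 0..size_m_rounded-1 = 0..127): rows beyond the mask's 32 real rows push the offsets far past the end of the storage, which would be consistent with the illegal address when those rows are not excluded from the load.

# illustration only, using the constants printed in the trace above
batch_s, head_s, m_s, k_s = 8, 1, 256, 8  # attention_mask strides
mask_numel = 1 * 8 * 32 * 32              # 8192 elements backing the mask (assumption)

largest_valid = 0 * batch_s + 7 * head_s + 31 * m_s + 31 * k_s     # 8191, last in-bounds element
largest_touched = 0 * batch_s + 7 * head_s + 127 * m_s + 31 * k_s  # 32767, if rows 32..127 are not masked out

print(largest_valid, mask_numel - 1)  # 8191 8191
print(largest_touched)                # 32767 -> roughly 4x past the end of the storage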