I modified the flash attention code to make it work for the case where q has a sequence length of 1 while k and v have a larger length. This is critical for text generation at inference time, where one token is generated at a time.
Since tl.dot does not support matrix-vector products, and a matrix-vector product cannot utilize Tensor Cores anyway, I replaced it with broadcasting + tl.sum. But I run into the following error, which isn't very informative to me. Can anyone help?
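For reference, here is a minimal NumPy sketch of the broadcasting + sum substitution I mean (the real kernel uses Triton broadcasting and tl.sum; the shapes below are just illustrative assumptions, e.g. head dim 64 and context length 2048):

```python
import numpy as np

# Illustrative shapes (assumptions, not the kernel's actual block sizes)
d_head, n_ctx = 64, 2048
rng = np.random.default_rng(0)
q = rng.standard_normal((1, d_head)).astype(np.float32)   # single query token
k = rng.standard_normal((n_ctx, d_head)).astype(np.float32)

# What tl.dot would compute: (1, d) @ (d, n) -> (1, n)
qk_dot = q @ k.T

# Broadcasting + sum replacement: (1, d) * (n, d) broadcasts to (n, d),
# then reducing over the head dimension gives the same scores as a vector (n,)
qk_bcast = np.sum(q * k, axis=-1)

assert np.allclose(qk_dot[0], qk_bcast, atol=1e-5)
```

The same idea carries over to the kernel by replacing the q @ k dot with an elementwise multiply followed by tl.sum over the head-dimension axis.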
Traceback (most recent call last):
File "/workspace/triton/python/tutorials/fused_attn_inference.py", line 163, in <module>
test_op(3, 2, 2048, 64, 1)
File "/workspace/triton/python/tutorials/fused_attn_inference.py", line 158, in test_op
tri_out = attention(q, k, v, sm_scale)
File "/workspace/triton/python/tutorials/fused_attn_inference.py", line 125, in attention
_fwd_kernel[grid](
File "/workspace/triton/python/triton/code_gen.py", line 999, in __call__
return self.kernel(*wargs, **kwargs, grid=self.grid)
File "/workspace/triton/python/triton/code_gen.py", line 988, in __call__
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
File "/workspace/triton/python/triton/code_gen.py", line 956, in add_to_cache
return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages,
File "/workspace/triton/python/triton/code_gen.py", line 1285, in _warmup
binary = self._compile(**compile)
File "/workspace/triton/python/triton/code_gen.py", line 1320, in _compile
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs)
IndexError: map::at
Here is my code:
BTW, I am using an A100 with CUDA 11.6.