LightLLM is a Python-based LLM (Large Language Model) inference and serving framework, notable for its lightweight design, easy scalability, and high-speed performance.
question about fp8 version of context_flashattention_nopad.py #479
context_flashattention_nopad_fp16_fp8.txt
We have implemented an fp8 version of context_flashattention_nopad.py (attached above). The V shape needs to be changed for the performance improvement described in https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html. However, the current result is not correct; could you help us?
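For reference, the tutorial's fp8 path re-lays-out V on the host so that the sequence axis becomes the contiguous one before V is cast to fp8, and it casts the softmax block P to tl.float8e5 right before the second tl.dot. The sketch below only illustrates that host-side relayout under LightLLM's flat (total_tokens, head_num, head_dim) layout; the helper name is hypothetical, and the strides handed to the kernel (and the kernel's V pointer order) would have to be adjusted to match.

```python
import torch

# Hypothetical helper, not part of the attachment: mimic the V relayout that
# the Triton fused-attention tutorial performs for its fp8 path, assuming a
# flat (total_tokens, head_num, head_dim) tensor as used in LightLLM.
def relayout_v_for_fp8(v: torch.Tensor) -> torch.Tensor:
    # store V physically as (head_num, head_dim, total_tokens), then permute
    # the view back so the logical shape is unchanged but the token axis has
    # stride 1
    v_t = v.permute(1, 2, 0).contiguous()
    v_t = v_t.permute(2, 0, 1)
    # cast last; .to() keeps the strides of a dense, non-overlapping tensor
    return v_t.to(torch.float8_e5m2)
```

The relevant entry points from the attached file: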
```python
@triton.jit
def _fwd_kernel_fp8(
    Q, K, V, B_Loc, sm_scale, B_Start_Loc, B_Seqlen, B_Ctxlen,
    Out,
    stride_b_loc_b, stride_b_loc_s,
    stride_qbs, stride_qh, stride_qd,
    stride_kbs, stride_kh, stride_kd,
    stride_vbs, stride_vh, stride_vd,
    stride_obs, stride_oh, stride_od,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SLIDING_WINDOW: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)
```
```python
@torch.inference_mode()
def context_attention_fwd_fp8(q, k, v, o, b_loc, b_start_loc, b_seq_len,
                              b_ctx_len, max_input_len,
                              alibi_slopes=None, sliding_window=None):
```
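To narrow down where the result diverges, one option is a minimal correctness check: run the fp8 wrapper on a single sequence with no cached prefix (b_ctx_len = 0) and compare against a plain fp32 causal-attention reference. The sketch below is built on assumptions, not on the attachment: the input shapes follow the flat (total_tokens, head_num, head_dim) layout above, b_loc is treated as unused when the context length is zero, and ref_causal_attention is our own helper. Adjust the input construction (and any fp8 casting the wrapper expects) to how the attached kernel actually indexes its inputs.

```python
import torch
# context_attention_fwd_fp8 is the wrapper from the attached file

def ref_causal_attention(q, k, v):
    # q, k, v: (seq_len, head_num, head_dim); computed in fp32 as a stable reference
    seq_len, _, head_dim = q.shape
    qf, kf, vf = (t.float().transpose(0, 1) for t in (q, k, v))  # (head_num, seq_len, head_dim)
    scores = qf @ kf.transpose(1, 2) * head_dim ** -0.5
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), 1)
    scores = scores.masked_fill(causal, float("-inf"))
    return (torch.softmax(scores, dim=-1) @ vf).transpose(0, 1)  # (seq_len, head_num, head_dim)

def check_fp8_kernel(seq_len=128, head_num=8, head_dim=128):
    torch.manual_seed(0)
    q = torch.randn(seq_len, head_num, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    o = torch.empty_like(q)
    b_start_loc = torch.tensor([0], dtype=torch.int32, device="cuda")
    b_seq_len = torch.tensor([seq_len], dtype=torch.int32, device="cuda")
    b_ctx_len = torch.zeros(1, dtype=torch.int32, device="cuda")       # no cached prefix
    b_loc = torch.zeros(1, seq_len, dtype=torch.int32, device="cuda")  # assumed unused when ctx_len == 0
    # depending on the attached implementation, k/v (and possibly q) may first need
    # to be cast / re-laid-out for fp8, e.g. v = relayout_v_for_fp8(v) from above
    context_attention_fwd_fp8(q, k, v, o, b_loc, b_start_loc, b_seq_len,
                              b_ctx_len, max_input_len=seq_len)
    ref = ref_causal_attention(q, k, v)
    print("max abs err:", (o.float() - ref).abs().max().item())

if __name__ == "__main__":
    check_fp8_kernel()
```

If this already disagrees with the reference at zero context length, the problem is in the main prefill loop rather than in the cached-prefix path, which at least halves the search space.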