Open Abhijit89Kumar opened 1 month ago
@triton.jit def triton_softmax_kernel( logits, output, stride_lm, stride_ln, stride_om, stride_on, M: tl.constexpr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, ): pid = tl.program_id(0) offs_l = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offs_n = tl.arange(0, BLOCK_N) mask = (offs_l[:, None] < M) & (offs_n[None, :] < N) # Load logits logits_ptr = logits + offs_l[:, None] * stride_lm + offs_n[None, :] * stride_ln logits = tl.load(logits_ptr, mask=mask, other=-float('inf')) # Compute softmax max_logits = tl.max(logits, axis=1) exp_logits = tl.exp(logits - max_logits[:, None]) sum_exp_logits = tl.sum(exp_logits, axis=1) output_vals = exp_logits / sum_exp_logits[:, None] # Store results output_ptr = output + offs_l[:, None] * stride_om + offs_n[None, :] * stride_on tl.store(output_ptr, output_vals, mask=mask) THE USAGE OF THIS KERNEL:- # Triton Softmax attn_output = torch.empty_like(attn_weights) M, N = attn_weights.shape[2], attn_weights.shape[3] # q_len and kv_seq_len respectively self.triton_softmax_kernel[grid]( attn_weights, attn_output, attn_weights.stride(0), attn_weights.stride(2), attn_output.stride(0), attn_output.stride(2), M=M, N=N, BLOCK_SIZE=BLOCK_SIZE, BLOCK_N=32 ) THE ERROR:- RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
Need some help here, also can someone explain why this error occurs?