triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered #4577

Open · Abhijit89Kumar opened this issue 1 month ago

Abhijit89Kumar commented 1 month ago
import triton
import triton.language as tl


@triton.jit
def triton_softmax_kernel(
    logits,
    output,
    stride_lm,
    stride_ln,
    stride_om,
    stride_on,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program instance handles BLOCK_SIZE rows and the first BLOCK_N columns.
    pid = tl.program_id(0)
    offs_l = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = tl.arange(0, BLOCK_N)

    # Only rows < M and columns < N are loaded or stored.
    mask = (offs_l[:, None] < M) & (offs_n[None, :] < N)

    # Load logits; masked-out lanes get -inf so they contribute 0 after exp().
    logits_ptr = logits + offs_l[:, None] * stride_lm + offs_n[None, :] * stride_ln
    logits = tl.load(logits_ptr, mask=mask, other=-float('inf'))

    # Compute softmax, subtracting the row max for numerical stability.
    max_logits = tl.max(logits, axis=1)
    exp_logits = tl.exp(logits - max_logits[:, None])
    sum_exp_logits = tl.sum(exp_logits, axis=1)

    output_vals = exp_logits / sum_exp_logits[:, None]

    # Store results
    output_ptr = output + offs_l[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(output_ptr, output_vals, mask=mask)
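The per-row math the kernel performs can be checked on the CPU without Triton or a GPU. The sketch below is a plain-Python emulation of a single row: lanes at or beyond N are padded with -inf (mirroring `other=-float('inf')` under the mask), the row maximum is subtracted before exponentiating, and only the valid lanes are returned (mirroring the masked store). The helper name `masked_row_softmax` is hypothetical and not part of Triton.

```python
import math


def masked_row_softmax(row, block_n):
    """CPU emulation of one row of triton_softmax_kernel (hypothetical helper).

    `row` holds the N valid logits and `block_n` plays the role of BLOCK_N.
    Only the first min(N, block_n) values fall inside the tile, as with the
    kernel's `offs_n < N` mask over `tl.arange(0, BLOCK_N)`.
    """
    n = len(row)
    # Masked load: lanes with offs_n >= N receive -inf (the `other=` value).
    tile = [row[j] if j < n else -math.inf for j in range(block_n)]
    # Subtract the row maximum for numerical stability (tl.max over axis 1).
    row_max = max(tile)
    exps = [math.exp(x - row_max) for x in tile]  # exp(-inf) == 0.0
    total = sum(exps)
    out = [e / total for e in exps]
    # Masked store: only lanes with offs_n < N are written back.
    return out[: min(n, block_n)]


probs = masked_row_softmax([1.0, 2.0, 3.0], block_n=4)
print([round(p, 4) for p in probs])
```

Calling it with `block_n=2` on the same three logits returns two probabilities that sum to 1, i.e. the row is normalized over its first two entries only; a launch with `BLOCK_N=32` does the same to every row longer than 32 columns.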

THE USAGE OF THIS KERNEL:-
        # Triton Softmax
        attn_output = torch.empty_like(attn_weights)
        M, N = attn_weights.shape[2], attn_weights.shape[3]  # q_len and kv_seq_len respectively
        self.triton_softmax_kernel[grid](
            attn_weights, attn_output,
            attn_weights.stride(0), attn_weights.stride(2),
            attn_output.stride(0), attn_output.stride(2),
            M=M, N=N,
            BLOCK_SIZE=BLOCK_SIZE,
            BLOCK_N=32
        )
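One way to see where an illegal access can come from is to compute, on the CPU, the largest element offset the kernel dereferences with the strides passed in this call. The sketch assumes `attn_weights` is a contiguous tensor of shape (bsz, num_heads, q_len, kv_seq_len), which the comment on `M, N` suggests; the concrete sizes are made up for illustration. Under that assumption `stride(0)` is the batch stride (num_heads * q_len * kv_seq_len elements), so using it as the row stride `stride_lm` pushes row `M - 1` well past the end of the buffer, while `stride(2)` (= kv_seq_len) is used as the column stride where `stride(3)` (= 1) would address adjacent columns. The helpers `contiguous_strides` and `max_offset` are hypothetical.

```python
def contiguous_strides(shape):
    """Element strides of a contiguous row-major tensor (what torch's .stride() reports)."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def max_offset(M, N, block_n, stride_row, stride_col):
    """Largest offset the kernel touches: last unmasked row and last unmasked column."""
    last_col = min(N, block_n) - 1
    return (M - 1) * stride_row + last_col * stride_col


# Hypothetical attention shape: (bsz, num_heads, q_len, kv_seq_len).
shape = (2, 8, 64, 64)
numel = 1
for s in shape:
    numel *= s
st = contiguous_strides(shape)
M, N = shape[2], shape[3]

# Strides as passed in the launch: stride(0) for rows, stride(2) for columns.
bad = max_offset(M, N, 32, st[0], st[2])
# Strides of the last two dimensions: stride(2) for rows, stride(3) for columns.
good = max_offset(M, N, 32, st[2], st[3])

print(f"numel={numel} offset_as_launched={bad} offset_with_dims_2_3={good}")
```

Even with `stride(2)` and `stride(3)`, this kernel only ever addresses the first (batch, head) slice, because it has no batch or head offset; the rest of `attn_output` (allocated with `torch.empty_like`) is never written.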

THE ERROR:-
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
Abhijit89Kumar commented 1 month ago

I need some help here. Also, can someone explain why this error occurs?
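A sketch of one possible fix, assuming `attn_weights` and `attn_output` are contiguous 4-D tensors of shape (bsz, num_heads, q_len, kv_seq_len): view each as `bsz * num_heads * q_len` independent rows of `kv_seq_len` elements (for example with `.view(-1, N)` in PyTorch, which raises an error rather than copying if the tensor is not contiguous), pass `stride(0)` and `stride(1)` of those 2-D views, set `M` to the flattened row count, launch `ceil(M / BLOCK_SIZE)` programs, and make `BLOCK_N` a power of two no smaller than N so a whole row fits in one tile (`tl.arange` needs a power-of-two range; Triton provides `triton.next_power_of_2` for this). The plain-Python check below, with hypothetical sizes, only verifies that this addressing writes every element exactly once and stays in bounds; it does not run the Triton kernel. The helper names are hypothetical.

```python
def next_power_of_2(n):
    """Smallest power of two >= n (same idea as triton.next_power_of_2)."""
    p = 1
    while p < n:
        p *= 2
    return p


def touched_offsets(num_rows, N, block_size, block_n, stride_row, stride_col):
    """Offsets stored by a 1-D grid of programs over the flattened rows.

    Mirrors the kernel's masks: rows with offs_l >= num_rows and columns
    with offs_n >= N are skipped.
    """
    grid = (num_rows + block_size - 1) // block_size  # ceil division
    offsets = []
    for pid in range(grid):
        for i in range(block_size):
            row = pid * block_size + i
            if row >= num_rows:
                continue
            for col in range(block_n):
                if col < N:
                    offsets.append(row * stride_row + col * stride_col)
    return offsets


# Hypothetical sizes: bsz=2, num_heads=3, q_len=5, kv_seq_len=40.
bsz, heads, q_len, kv_len = 2, 3, 5, 40
num_rows = bsz * heads * q_len      # rows of the (-1, N) view
BLOCK_N = next_power_of_2(kv_len)   # 64, so a whole row fits in one tile
# For a contiguous (-1, N) view, stride(0) == N and stride(1) == 1.
offs = touched_offsets(num_rows, kv_len, block_size=4, block_n=BLOCK_N,
                       stride_row=kv_len, stride_col=1)
print(BLOCK_N, len(offs), sorted(offs) == list(range(num_rows * kv_len)))
```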