triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Precision issue in Triton kernel: zero gradients for k and v when using fp32 for intermediate calculations with fp16 inputs #4701

Open uniartisan opened 1 week ago

uniartisan commented 1 week ago

Description

In our RWKV6 implementation using Triton for CUDA, we've discovered a critical precision-related issue in the fused_recurrent_rwkv6_bwd_kernel_dkv kernel. When inputs are in fp16 and intermediate calculations are performed in fp32 before casting back to fp16, the gradients for k and v become zero. Importantly, this issue does not occur when inputs are in bf16 and intermediate calculations are in fp32.

Affected Code

The issue occurs in the following kernel:

@triton.jit
def fused_recurrent_rwkv6_bwd_kernel_dkv(
    # ... (parameters)
    TLTYPE: tl.constexpr,  # data type
):
    # ... (kernel implementation)
    b_dk *= scale
    b_dv *= scale
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
    # ... (rest of the kernel)

Problem Details

  1. Input tensors (q, k, v, w, u) are in fp16 format.
  2. The issue manifests differently depending on the head dimension D:
    a. When D = 64, using tl.float16 for TLTYPE does not cause the issue.
    b. When D = 128, the issue occurs even when using tl.float16 for TLTYPE.
  3. When TLTYPE is set to tl.float32, intermediate calculations are performed in fp32.
  4. Results are then cast back to fp16 when storing to p_dk and p_dv.
  5. This casting operation results in the gradients for k and v becoming zero (a small cast illustration follows this list).
  6. The issue does not occur when inputs are in bf16 and intermediate calculations are in fp32.
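
To make item 5 concrete: sufficiently small fp32 magnitudes flush to zero when downcast to fp16, which is our current (unconfirmed) suspicion for where the gradients are lost. A standalone illustration, not a repro of the kernel:

import torch

x = torch.tensor([1e-4, 1e-7, 1e-8, 1e-9], dtype=torch.float32)
# fp32 values below ~3e-8 (half of fp16's smallest subnormal, ~6e-8) round to exactly 0.0,
# so the last two entries print as 0 after the cast
print(x.to(torch.float16))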

Reproduction Steps

  1. Ensure input tensors are in fp16 format.
  2. Test with two configurations:
    a. Set TLTYPE = tl.float16 and D = 128.
    b. Set TLTYPE = tl.float32 for both D = 64 and D = 128.
  3. Run the backward pass with the following parameters (a quick zero-gradient check is sketched right after this list):
    • B = 4, H = 4, T = 1024
    • dtype = torch.float16
    • scale = -1.0
    • use_h = False
    • u_2d = True
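
One quick way to see the failure with the reproduction script at the bottom of this issue is to check the cloned gradients for exact zeros right after the backward pass (tri_dk and tri_dv are the names used in its __main__ block; the check itself is our addition):

    # appended after the grad clones in the __main__ block of the script below
    print(tri_dk.abs().max(), tri_dv.abs().max())
    assert not torch.all(tri_dk == 0), 'dk is all zeros'
    assert not torch.all(tri_dv == 0), 'dv is all zeros'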

Expected Behavior

The gradients for k and v should be non-zero and match the results obtained from a native PyTorch implementation.
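
For reference, the native PyTorch implementation we compare against is, in spirit, a per-timestep fp32 recurrence like the sketch below. It mirrors the scaling convention of fused_recurrent_rwkv6_fwd_kernel in the script further down (q, k and v are each multiplied by scale) but omits the initial/final state handling; it is a simplified sketch, not our exact reference code.

import torch

def naive_recurrent_rwkv6(q, k, v, w, u, scale):
    # q/k/w: [B, H, T, K], v: [B, H, T, V], u: [H, K]; everything computed in fp32
    q, k, v, w, u = (x.float() for x in (q, k, v, w, u))
    B, H, T, K = q.shape
    V = v.shape[-1]
    h = q.new_zeros(B, H, K, V)                      # running state
    o = q.new_zeros(B, H, T, V)
    for t in range(T):
        qt, kt, vt = q[:, :, t] * scale, k[:, :, t] * scale, v[:, :, t] * scale
        kv = kt.unsqueeze(-1) * vt.unsqueeze(-2)     # outer(k_t, v_t): [B, H, K, V]
        o[:, :, t] = ((h + kv * u.unsqueeze(-1)) * qt.unsqueeze(-1)).sum(-2)
        h = h * w[:, :, t].exp().unsqueeze(-1) + kv  # decay the state, then add outer(k_t, v_t)
    return o

Gradients for k and v can then be obtained from this reference via torch.autograd and compared against the Triton outputs.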

Actual Behavior

The gradients for k and v are all zeros when inputs are fp16 and intermediate calculations are fp32.

Additional Information

Questions

  1. Is this a known limitation of Triton's CUDA backend when dealing with fp32 to fp16 conversions, specifically when inputs are in fp16?
  2. Why does this issue not occur with bf16 inputs when using fp32 for intermediate calculations?
  3. Are there any recommended strategies for handling mixed-precision calculations in Triton CUDA kernels to avoid such issues, particularly when working with fp16 inputs? (A small sketch of the pattern we mean follows this list.)
  4. Could the difference in behavior between CUDA and XPU implementations provide any insights into potential workarounds?
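
To make question 3 concrete, the pattern we mean by "fp32 intermediates with a single downcast at the store" is roughly the following standalone sketch (the toy reduction and all names are ours, not part of the RWKV6 code):

import torch
import triton
import triton.language as tl

@triton.jit
def fp32_accum_kernel(x, y, T, BK: tl.constexpr):
    o_k = tl.arange(0, BK)
    b_acc = tl.zeros([BK], dtype=tl.float32)            # accumulator stays fp32
    for t in range(T):
        b_x = tl.load(x + t * BK + o_k).to(tl.float32)  # upcast the fp16 loads
        b_acc += b_x * b_x                               # all math in fp32
    tl.store(y + o_k, b_acc.to(y.dtype.element_ty))      # single downcast at the end

x = torch.randn(1024, 64, device='cuda', dtype=torch.float16)
y = torch.empty(64, device='cuda', dtype=torch.float16)
fp32_accum_kernel[(1,)](x, y, x.shape[0], BK=64)
# any difference here reflects only the final fp16 rounding of the fp32 sum
print((y.float() - (x.float() ** 2).sum(0)).abs().max())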

Environment

We would greatly appreciate any insights or suggestions on how to resolve this precision-related issue while maintaining performance, especially considering the different behaviors observed with fp16 and bf16 inputs.

# -*- coding: utf-8 -*-
from typing import Tuple

import torch
import triton
import triton.language as tl

from fla.ops.utils import chunk_global_reversed_cumsum
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=2, num_stages=1),
        triton.Config({}, num_warps=4, num_stages=1),
        triton.Config({}, num_warps=8, num_stages=1),
        triton.Config({}, num_warps=16, num_stages=1),
        triton.Config({}, num_warps=2, num_stages=2),
        triton.Config({}, num_warps=4, num_stages=2),
        triton.Config({}, num_warps=8, num_stages=2),
        triton.Config({}, num_warps=16, num_stages=2),
    ],
    key=['K', 'V', 'T']
)
@triton.jit
def fused_recurrent_rwkv6_fwd_kernel(
    q,  # query [B, H, T, K]
    k,  # key [B, H, T, K]
    v,  # value [B, H, T, V]
    w,  # log gate [B, H, T, K]
    u,  # bonus [B, H, K] or [H, K]
    o,  # output [B, H, T, V]
    # initial hidden state initialization [B, H, K, V]
    h0,
    ht,  # final hidden state [B, H, K, V]
    s_k_h,  # stride size: T * K
    s_v_h,  # stride size: T * V
    scale: tl.constexpr,  # K ** -0.5
    B: tl.constexpr,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction
    U_2D: tl.constexpr,  # whether u is 2D
    TLTYPE: tl.constexpr,  # data type
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_u = i_bh if not U_2D else i_bh % H
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if REVERSE else 0)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if REVERSE else 0)
    p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if REVERSE else 0)
    p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if REVERSE else 0)
    p_u = u + i_u * K + tl.arange(0, BK) + i_k * BK

    mask_bk = (i_k * BK + tl.arange(0, BK)) < K
    mask_bv = (i_v * BV + tl.arange(0, BV)) < V
    mask_kv = mask_bv[:, None] & mask_bk[None, :]

    b_h = tl.zeros([BV, BK], dtype=TLTYPE)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(TLTYPE)

    b_u = tl.load(p_u, mask=mask_bk, other=0).to(TLTYPE)
    for _ in range(0, T):
        b_k = (tl.load(p_k, mask=mask_bk, other=0) * scale).to(TLTYPE)
        b_v = (tl.load(p_v, mask=mask_bv, other=0) * scale).to(TLTYPE)
        b_q = (tl.load(p_q, mask=mask_bk, other=0) * scale).to(TLTYPE)
        b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
        b_w = tl.exp(b_w).to(TLTYPE)
        b_kv = b_k[None, :] * b_v[:, None]
        b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        b_h = b_h * b_w[None, :]
        b_h += b_kv
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
        p_q += -K if REVERSE else K
        p_k += -K if REVERSE else K
        p_o += -V if REVERSE else V
        p_v += -V if REVERSE else V
        p_w += -K if REVERSE else K

    if STORE_FINAL_STATE:
        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)

# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_rwkv6_bwd_kernel_dq(
    # B: B, H: H, T: T, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    k,  # key [B, H, T, K]
    v,  # value [B, H, T, V]
    w,  # log gate [B, H, T, K]
    u,  # bonus [B, H, K] or [H, K]
    do,  # gradient of output [B, H, T, V]
    dq,  # gradient of query [NV, B, H, T, K]
    dq_aux,  # gradient of query_aux [NV, B, H, T, K]

    # initial hidden state initialization [B, H, K, V]
    h0,

    s_k_h,  # stride size: T * K
    s_v_h,  # stride size: T * V

    scale: tl.constexpr,  # K ** -0.5
    B: tl.constexpr,  # B
    H: tl.constexpr,  # H
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction,
    U_2D: tl.constexpr,  # whether u is 2D
    TLTYPE: tl.constexpr,  # data type
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_u = i_bh if not U_2D else i_bh % H
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if REVERSE else 0)
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if REVERSE else 0)
    p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if REVERSE else 0)
    p_dq_aux = dq_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if REVERSE else 0)
    p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if REVERSE else 0)
    p_u = u + i_u * K + tl.arange(0, BK) + i_k * BK

    mask_bk = i_k * BK + tl.arange(0, BK) < K
    mask_bv = i_v * BV + tl.arange(0, BV) < V
    mask_kv = mask_bv[:, None] & mask_bk[None, :]
    b_u = tl.load(p_u, mask=mask_bk, other=0).to(TLTYPE)
    b_h = tl.zeros([BV, BK], dtype=TLTYPE)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(TLTYPE)

    for _ in range(0, T):
        b_k = (tl.load(p_k, mask=mask_bk, other=0) * scale).to(TLTYPE)
        b_v = (tl.load(p_v, mask=mask_bv, other=0) * scale).to(TLTYPE)
        b_kv = b_k[None, :] * b_v[:, None]
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(TLTYPE)
        b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
        b_w = tl.exp(b_w).to(TLTYPE)
        h_q = b_h * b_do[:, None]
        b_dq = tl.sum(h_q + b_kv * b_u[None, :] * b_do[:, None], axis=0)
        b_dq *= scale
        b_dq_aux = tl.sum(h_q, axis=0)
        b_h = b_h * b_w[None, :]
        b_h += b_kv
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
        tl.store(p_dq_aux, b_dq_aux.to(p_dq_aux.dtype.element_ty), mask=mask_bk)
        p_k += -K if REVERSE else K
        p_do += -V if REVERSE else V
        p_v += -V if REVERSE else V
        p_w += -K if REVERSE else K
        p_dq += -K if REVERSE else K
        p_dq_aux += -K if REVERSE else K

@triton.jit
def fused_recurrent_rwkv6_bwd_kernel_dkv(
    # B: B, H: H, T: T, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    q,  # query [B, H, T, K]
    k,  # key [B, H, T, K]
    v,  # value [B, H, T, V]
    w,  # log gate [B, H, T, K]
    u,  # bonus [B, H, K] or [H, K]

    do,  # gradient of output [B, H, T, V]
    dk,  # gradient of key [NV, B, H, T, K]
    dk_aux,  # auxiliary key-gradient buffer [NV, B, H, T, K]
    dv,  # gradient of value [NK, B, H, T, V]
    dh0,  # gradient of the initial hidden state [B, H, K, V]

    s_k_h,  # stride size: T * K
    s_v_h,  # stride size: T * V

    scale: tl.constexpr,  # K ** -0.5
    B: tl.constexpr,  # B
    H: tl.constexpr,  # H
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction,
    U_2D: tl.constexpr,  # whether u is 2D
    TLTYPE: tl.constexpr,  # data type
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_u = i_bh if not U_2D else i_bh % H
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_dk_aux = dk_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)

    mask_bk = i_k * BK + tl.arange(0, BK) < K
    mask_bv = i_v * BV + tl.arange(0, BV) < V
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
    if USE_INITIAL_STATE:
        p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_dh = tl.load(p_dh0, mask=mask_kv, other=0).to(TLTYPE)
    else:
        b_dh = tl.zeros([BV, BK], dtype=TLTYPE)

    p_u = u + i_u * K + tl.arange(0, BK) + i_k * BK
    b_u = tl.load(p_u, mask=mask_bk, other=0).to(TLTYPE)

    for _ in range(T - 1, -1, -1):
        b_q = (tl.load(p_q, mask=mask_bk, other=0) * scale).to(TLTYPE)
        b_k = (tl.load(p_k, mask=mask_bk, other=0) * scale).to(TLTYPE)
        b_v = (tl.load(p_v, mask=mask_bv, other=0) * scale).to(TLTYPE)
        b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(TLTYPE)
        b_dkv = (b_q[:, None] * b_do[None, :]).to(TLTYPE)
        b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
        tl.store(p_dk_aux, b_dk.to(p_dk_aux.dtype.element_ty), mask=mask_bk)
        b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], axis=1)
        b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], axis=0)
        b_dk *= scale
        b_dv *= scale
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
        b_dh = b_dh * tl.exp(b_w)[:, None].to(TLTYPE)
        b_dh += b_dkv

        p_q += K if REVERSE else -K
        p_k += K if REVERSE else -K
        p_v += V if REVERSE else -V
        p_w += K if REVERSE else -K
        p_do += V if REVERSE else -V
        p_dk += K if REVERSE else -K
        p_dk_aux += K if REVERSE else -K
        p_dv += V if REVERSE else -V

    if USE_INITIAL_STATE:
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv)

class FusedRecurrentRWKV6Function(torch.autograd.Function):

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False, u_2d=True, training=True):
        # alias
        q = r
        B, H, T, K, V = *q.shape, v.shape[-1]

        BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

        final_state = q.new_empty(B, H, K, V) if output_final_state else None

        o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
        grid = (NV, NK, B * H)
        fused_recurrent_rwkv6_fwd_kernel[grid](
            q, k, v, w, u, o, initial_state, final_state,
            k.stride(1),
            v.stride(1),
            scale,
            B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=final_state is not None,
            REVERSE=reverse, U_2D=u_2d,
            TLTYPE=tl.float16 if q.dtype == torch.float16 else tl.float32
        )

        o = o.sum(0)
        if initial_state is not None:
            initial_state = initial_state.clone()
        if training:
            ctx.save_for_backward(q, k, v, w, u, initial_state, o)
            ctx.scale = scale
            ctx.reverse = reverse
            ctx.u_2d = u_2d
        # we do not need the gradient of the final state from the next chunk,
        # similar to truncated BPTT
        if final_state is not None:
            final_state = final_state.detach()
        return o.to(q.dtype), final_state

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dht=None):
        q, k, v, w, u, initial_state, o = ctx.saved_tensors
        B, H, T, K, V = *q.shape, v.shape[-1]
        scale = ctx.scale
        u_2d = ctx.u_2d
        TLTYPE = tl.float16 if q.dtype == torch.float16 else tl.float32
        TORCHTYPE = torch.float16 if q.dtype == torch.float16 else torch.float32

        BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 1
        dq = q.new_empty(NV, B, H, T, K, dtype=TORCHTYPE)
        dq_aux = torch.empty_like(dq)
        grid = (NV, NK, B * H)

        fused_recurrent_rwkv6_bwd_kernel_dq[grid](
            k, v, w, u, do, dq, dq_aux, initial_state,
            q.stride(1),
            v.stride(1),
            scale,
            B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            REVERSE=ctx.reverse, U_2D=u_2d,
            TLTYPE=TLTYPE,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dq = dq.sum(0).to(q)
        dq_aux = dq_aux.sum(0)

        BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

        dk = q.new_empty(NV, B, H, T, K, dtype=TORCHTYPE)
        dk_aux = q.new_empty(NV, B, H, T, K, dtype=TORCHTYPE)
        dv = q.new_empty(NK, B, H, T, V, dtype=TORCHTYPE)
        dh0 = (torch.zeros_like(initial_state) + (dht if dht is not None else 0.)) if initial_state is not None else dht
        grid = (NV, NK, B * H)
        fused_recurrent_rwkv6_bwd_kernel_dkv[grid](
            q, k, v, w, u, do, dk, dk_aux, dv, dh0,
            q.stride(1),
            v.stride(1),
            scale,
            B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages,
            USE_INITIAL_STATE=((initial_state is not None) or (dht is not None)),
            REVERSE=ctx.reverse,
            U_2D=u_2d,
            TLTYPE=TLTYPE,  # HERE! CHANGE TO tl.float16
        )
        dk = dk.sum(0).to(k)
        dv = dv.sum(0).to(v)
        dk_aux = dk_aux.sum(0)

        dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k * scale)[:, :, 0:-1]
        dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
        dw = chunk_global_reversed_cumsum(dw).to(w)

        if u_2d:
            du = ((do * scale * v ).sum(-1)[..., None] * k * (scale**2) * q).sum([0, -2]).to(u)
        else:
            du = ((do * v * scale).sum(-1)[..., None] * k * (scale**2) * q).sum(-2).to(u)
        return dq, dk, dv, dw, du, None, dh0 if initial_state is not None else None, None, None, None, None

def fused_recurrent_rwkv6(
    r: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    scale: float = -1.0,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    reverse: bool = False,
    training: bool = True,
    causal: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        r (torch.Tensor):
            reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        w (torch.Tensor):
            data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
        u (torch.Tensor):
            bonus of shape `(H, K)` or `(B, H, K)` for each head.
        scale (Optional[float]):
            Scale factor for the RWKV6 attention scores.
            If set to `-1.0` (the default), it falls back to `1 / sqrt(K)`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
    """
    if scale == -1.0:
        scale = r.shape[-1] ** -0.5
    u_2d = u.dim() == 2
    o, final_state = FusedRecurrentRWKV6Function.apply(
        r, k, v, w, u, scale, initial_state, output_final_state, reverse, u_2d, training)
    return o, final_state

if __name__ == '__main__':
    import torch.nn.functional as F
    device = 'cuda'
    B = 4
    H = 4
    T = 1024
    D = 128
    dtype = torch.float16
    scale = -1.0
    use_h = False
    torch.manual_seed(42)

    def get_err_ratio(x, y):
        err = (x-y).flatten().square().mean().sqrt().item()
        base = (x).flatten().square().mean().sqrt().item()
        return err / base

    # if dtype == torch.float16 and scale == 1.0:
    #     return
    torch.manual_seed(42)
    atol = 1e-3 if dtype == torch.float else 1e-2

    q = torch.randn(B, H, T, D, device=device).to(dtype).requires_grad_(True)
    k = torch.randn(B, H, T, D, device=device).to(dtype).requires_grad_(True)
    v = torch.randn(B, H, T, 2*D, device=device).to(dtype).requires_grad_(True)
    w = F.logsigmoid(torch.randn(B, H, T, D, device=device)).to(dtype).requires_grad_(True)
    u = torch.randn(H, D, device=device).to(dtype).requires_grad_(True)

    do = torch.rand_like(v, device=device)
    h = torch.randn(B, H, D, 2*D, device=device, dtype=dtype, requires_grad=True)

    tri_o, _ = fused_recurrent_rwkv6(q, k, v, w, u, scale=scale, initial_state=h if use_h else None, output_final_state=use_h)
    tri_o.backward(do)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None
    tri_dw, w.grad = w.grad.clone(), None
    tri_du, u.grad = u.grad.clone(), None
    if use_h:
        tri_dh, h.grad = h.grad.clone(), None
    print(tri_dv, tri_dk)
uniartisan commented 4 days ago
# -*- coding: utf-8 -*-

# code adapted from
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

from typing import Optional

import torch
import triton
import triton.language as tl

from fla.utils import contiguous
import pdb

# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
    configs=[
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=0, num_warps=8),
        # triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=0, num_warps=2),
        triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=0, num_warps=2),
        # Good config for fp8 inputs.
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=0, num_warps=4)
    ],
    key=['M', 'N', 'K'],
)
@triton.heuristics({
    'HAS_INPUT': lambda args: args['input'] is not None,
    'HAS_ALPHA': lambda args: args['alpha'] is not None,
    'HAS_BETA': lambda args: args['beta'] is not None
})
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a,
    b,
    c,
    input,
    alpha,
    beta,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `s_am` is how much to increase `a`
    # by to get the element one row down (A has M rows).
    s_ap,
    s_am,
    s_ak,
    s_bp,
    s_bk,
    s_bn,
    s_cp,
    s_cm,
    s_cn,
    # Meta-parameters
    #BP: tl.constexpr,
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
    G: tl.constexpr,
    ACTIVATION: tl.constexpr,
    HAS_INPUT: tl.constexpr,
    HAS_ALPHA: tl.constexpr,
    HAS_BETA: tl.constexpr
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    NM, NN = tl.num_programs(1), tl.num_programs(2)
    i_p, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_m, i_n = tl.swizzle2d(i_m, i_n,  NM, NN, G)
    tl.static_print("i_p:", i_p)

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `p_a` is a block of [BM, BK] pointers
    # `p_b` is a block of [BK, BN] pointers
    # See above `Pointer Arithmetic` section for details
    o_am = (i_m * BM + tl.arange(0, BM)) % M
    o_bn = (i_n * BN + tl.arange(0, BN)) % N
    o_k = tl.arange(0, BK)

    # here !
    p_a = a + i_p * s_ap + (o_am[:, None] * s_am + o_k[None, :] * s_ak)
    p_b = b + i_p * s_bp + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)

    b_acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0).to(tl.float32)
        b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0).to(tl.float32)
        # We accumulate along the K dimension.
        b_acc += tl.dot(b_a, b_b, allow_tf32=False)
        # Advance the ptrs to the next K block.
        p_a += BK * s_ak
        p_b += BK * s_bk

    o_cm = i_m * BM + tl.arange(0, BM)
    o_cn = i_n * BN + tl.arange(0, BN)
    mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)

    b_c = b_acc
    #You can fuse arbitrary activation functions here
    #while the b_acc is still in FP32!
    if ACTIVATION == "leaky_relu":
        b_c = leaky_relu(b_c)
    if HAS_ALPHA:
        b_c *= tl.load(alpha)
    if HAS_INPUT:
        # add the batch-dimension offset here as well
        p_i = input + i_p * s_cp + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]
        b_i = tl.load(p_i, mask=mask, other=0.0).to(tl.float32)
        if HAS_BETA:
            b_i *= tl.load(beta)
        b_c += b_i

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    p_c = c + i_p * s_cp + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]

    tl.store(p_c, b_c.to(c.dtype.element_ty))

# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)

@contiguous
def addmm(
    input: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    inplace: Optional[bool] = False
) -> torch.Tensor:
    #assert a.shape[2] == b.shape[1], 'Incompatible dimensions (A: {}x{}x{}, B: {}x{}x{})'.format(*a.shape, *b.shape)

    P, M, K = a.shape
    _, K, N = b.shape
    # Allocates output.
    c = a.new_zeros(P, M, N)
    print(c.shape, c.dtype)
    print(P)

    def grid(meta): return (triton.cdiv(P, 1), triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, input, alpha, beta,
        M, N, K,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        c.stride(0), c.stride(1), c.stride(2),
        ACTIVATION=None,
    )
    return c
import time
torch.manual_seed(342)
a = torch.randn((3,4,3),device='cpu', dtype=torch.float16)
b = torch.randn((3,3,4),device='cpu', dtype=torch.float16)
c = torch.randn((3,4,4),device='cpu', dtype=torch.float16).uniform_(-1, 1)

xx = addmm(c, a, b)
print(xx.shape,xx)
d = a@b +c
print(xx - d)
The script above was run with export TRITON_DEBUG=1.

This code suffers from the same problem: only the first batch of data is written along the batch dimension, while everything works fine when simulated with NumPy.
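
For completeness, the NumPy simulation mentioned above is along these lines (our reconstruction, meant to be appended to the script; a, b, c and xx are the tensors it defines):

import numpy as np

an, bn, cn = (t.cpu().float().numpy() for t in (a, b, c))
# expected per-batch result: C[p] = A[p] @ B[p] + c[p]
expected = np.einsum('pmk,pkn->pmn', an, bn) + cn
print(np.abs(xx.cpu().float().numpy() - expected).max())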