Open uniartisan opened 1 week ago
# -*- coding: utf-8 -*-
# code adapted from
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
from typing import Optional
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
import pdb
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
    configs=[
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=0, num_warps=8),
        # triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=0, num_warps=2),
        triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=0, num_warps=2),
        # Good config for fp8 inputs.
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=0, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=0, num_warps=4)
    ],
    key=['M', 'N', 'K'],
)
@triton.heuristics({
    'HAS_INPUT': lambda args: args['input'] is not None,
    'HAS_ALPHA': lambda args: args['alpha'] is not None,
    'HAS_BETA': lambda args: args['beta'] is not None
})
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a,
    b,
    c,
    input,
    alpha,
    beta,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `s_am` is how much to increase `a`
    # by to get the element one row down (A has M rows).
    s_ap,
    s_am,
    s_ak,
    s_bp,
    s_bk,
    s_bn,
    s_cp,
    s_cm,
    s_cn,
    # Meta-parameters
    # BP: tl.constexpr,
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
    G: tl.constexpr,
    ACTIVATION: tl.constexpr,
    HAS_INPUT: tl.constexpr,
    HAS_ALPHA: tl.constexpr,
    HAS_BETA: tl.constexpr
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    NM, NN = tl.num_programs(1), tl.num_programs(2)
    i_p, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)
    tl.static_print("i_p:", i_p)
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `p_a` is a block of [BM, BK] pointers
    # `p_b` is a block of [BK, BN] pointers
    # See above `Pointer Arithmetic` section for details
    o_am = (i_m * BM + tl.arange(0, BM)) % M
    o_bn = (i_n * BN + tl.arange(0, BN)) % N
    o_k = tl.arange(0, BK)
    # Offset A and B by the batch index `i_p` along their batch strides.
    p_a = a + i_p * s_ap + (o_am[:, None] * s_am + o_k[None, :] * s_ak)
    p_b = b + i_p * s_bp + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)
    b_acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0).to(tl.float32)
        b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0).to(tl.float32)
        # We accumulate along the K dimension.
        b_acc += tl.dot(b_a, b_b, allow_tf32=False)
        # Advance the ptrs to the next K block.
        p_a += BK * s_ak
        p_b += BK * s_bk
    o_cm = i_m * BM + tl.arange(0, BM)
    o_cn = i_n * BN + tl.arange(0, BN)
    mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)
    b_c = b_acc
    # You can fuse arbitrary activation functions here
    # while b_acc is still in FP32!
    if ACTIVATION == "leaky_relu":
        b_c = leaky_relu(b_c)
    if HAS_ALPHA:
        b_c *= tl.load(alpha)
    if HAS_INPUT:
        # Offset the residual input by the batch index as well.
        p_i = input + i_p * s_cp + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]
        b_i = tl.load(p_i, mask=mask, other=0.0).to(tl.float32)
        if HAS_BETA:
            b_i *= tl.load(beta)
        b_c += b_i
    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    p_c = c + i_p * s_cp + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]
    tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask)
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)
@contiguous
def addmm(
    input: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    inplace: Optional[bool] = False
) -> torch.Tensor:
    # assert a.shape[2] == b.shape[1], 'Incompatible dimensions (A: {}x{}x{}, B: {}x{}x{})'.format(*a.shape, *b.shape)
    P, M, K = a.shape
    _, K, N = b.shape
    # Allocates output.
    c = a.new_zeros(P, M, N)
    print(c.shape, c.dtype)
    print(P)

    def grid(meta):
        return (triton.cdiv(P, 1), triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))

    matmul_kernel[grid](
        a, b, c, input, alpha, beta,
        M, N, K,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        c.stride(0), c.stride(1), c.stride(2),
        ACTIVATION=None,
    )
    return c
import time
torch.manual_seed(342)
a = torch.randn((3, 4, 3), device='cpu', dtype=torch.float16)
b = torch.randn((3, 3, 4), device='cpu', dtype=torch.float16)
c = torch.randn((3, 4, 4), device='cpu', dtype=torch.float16).uniform_(-1, 1)
xx = addmm(c, a, b)
print(xx.shape, xx)
d = a @ b + c
print(xx - d)
export TRITON_DEBUG=1
This code suffers from the same problem: only the first batch is written along the batch dimension, yet the same computation works fine when simulated with NumPy.
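For reference, here is a minimal NumPy cross-check of what each batch should contain (a sketch that reuses a, b, c and xx from the script above; the helper names are only for illustration):

import numpy as np

# For every batch p the kernel should compute C[p] = A[p] @ B[p] + input[p].
a_np, b_np, c_np = (t.float().numpy() for t in (a, b, c))
expected = np.stack([a_np[p] @ b_np[p] + c_np[p] for p in range(a_np.shape[0])])

# Per-batch max abs error against the Triton result; per the report, only batch 0
# should come out small, while the remaining batches stay at their unwritten zeros.
err = np.abs(expected - xx.float().numpy()).max(axis=(1, 2))
print(err)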
Description
In our RWKV6 implementation using Triton for CUDA, we've discovered a critical precision-related issue in the fused_recurrent_rwkv6_bwd_kernel_dkv kernel. When inputs are in fp16 and intermediate calculations are performed in fp32 before being cast back to fp16, the gradients for k and v become zero. Importantly, this issue does not occur when inputs are in bf16 and intermediate calculations are in fp32.

Affected Code
The issue occurs in the following kernel:
Problem Details
a. When D = 64, using tl.float16 for TLTYPE does not cause the issue.
b. When D = 128, the issue occurs even when using tl.float16 for TLTYPE.
When TLTYPE is set to tl.float32, intermediate calculations are performed in fp32, and the results are cast back to fp16 when stored to p_dk and p_dv. This leads to the gradients of k and v becoming zero.
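For readers unfamiliar with the pattern, here is a minimal schematic of the load, compute-in-fp32, cast-back-on-store flow that TLTYPE controls. This is an assumed illustration only, not the fla kernel itself; the kernel and argument names are hypothetical.

import triton
import triton.language as tl

@triton.jit
def cast_back_demo(x, y, scale, BD: tl.constexpr):
    o = tl.arange(0, BD)
    # Load fp16 inputs and promote them to fp32 (the TLTYPE = tl.float32 case).
    b_x = tl.load(x + o).to(tl.float32)
    # Intermediate arithmetic happens in fp32 ...
    b_y = b_x * scale
    # ... and the result is cast back to the output dtype (fp16 here) on the store,
    # which is where values that are tiny in fp32 can flush to zero.
    tl.store(y + o, b_y.to(y.dtype.element_ty))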
Reproduction Steps

a. Set TLTYPE = tl.float16 and D = 128.
b. Set TLTYPE = tl.float32 for both D = 64 and D = 128.

Expected Behavior
The gradients for k and v should be non-zero and match the results obtained from a native PyTorch implementation.

Actual Behavior
The gradients for k and v are all zeros when inputs are fp16 and intermediate calculations are fp32.
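One plausible mechanism, offered here only as an assumption, is underflow at the final cast: fp16 cannot represent magnitudes below roughly 6e-8, while bf16 keeps fp32's 8-bit exponent range, so tiny fp32 intermediates survive a cast to bf16 but flush to zero in fp16. A minimal sketch:

import torch

# Tiny fp32 values flush to zero when cast to fp16 but stay non-zero in bf16,
# because bf16 trades mantissa bits for fp32's exponent range.
x = torch.tensor([1e-4, 1e-7, 1e-10, 1e-20], dtype=torch.float32)
print(x.to(torch.float16))   # the 1e-10 and 1e-20 entries become exactly 0.0
print(x.to(torch.bfloat16))  # all entries remain non-zero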
Additional Information

The issue occurs when using tl.float16 for TLTYPE with fp16 inputs when D = 128.
The issue does not occur when using tl.float16 for TLTYPE with fp16 inputs when D = 64.
Setting TLTYPE to tl.float32 causes the issue for both D = 64 and D = 128.

Questions
Environment
We would greatly appreciate any insights or suggestions on how to resolve this precision-related issue while maintaining performance, especially considering the different behaviors observed with fp16 and bf16 inputs.