triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Fusing multi-input activation function #984

Open jmc128 opened 1 year ago

jmc128 commented 1 year ago

I'm trying to implement a single kernel that performs a matrix multiplication followed by a non-unary activation function f(x, y) = sigmoid(x)*tanh(y). In a neural network library this would be implemented as follows:

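import torch

A, B = torch.randn(64, 32), torch.randn(32, 128)  # example shapes
C = torch.matmul(A, B)
C_left, C_right = torch.chunk(C, 2, dim=1)  # split into left and right halves
C_final = torch.sigmoid(C_left) * torch.tanh(C_right)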

One simplification that could possibly be helpful is that any partitioning of the columns of C is fine (for example applying sigmoid to even columns and tanh to odd columns). My initial attempt at writing this in triton took advantage of this by trying to partition at the block level. In this case, a triton implementation would be identical to the standard matmul kernel up until the activation is applied, at which point I'd need to split the [BLOCK_SIZE_M, BLOCK_SIZE_N] accumulator into two partitions, each with shape [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]. As far as I understand, this is not currently possible with triton, so this direction is a dead end.
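The even/odd column partition mentioned above can be emulated on the CPU to check the math. The NumPy sketch below uses a hypothetical helper, `even_odd_activation`, that takes one accumulator tile of shape [BM, BN], reshapes it to [BM, BN // 2, 2], and applies sigmoid to the even columns and tanh to the odd columns. It is not Triton code. The thread notes that Triton lacked indexing ops at the time; to my knowledge newer releases add `tl.reshape` and `tl.split`, which may express the same reshape-then-split inside a kernel, but verify this against the documentation of the installed version.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def even_odd_activation(acc):
    """Apply sigmoid to the even columns and tanh to the odd columns of an
    accumulator tile of shape [BM, BN], returning shape [BM, BN // 2]."""
    bm, bn = acc.shape
    assert bn % 2 == 0, "BN must be even"
    # Row-major reshape: pairs[m, j, s] == acc[m, 2 * j + s]
    pairs = acc.reshape(bm, bn // 2, 2)
    even, odd = pairs[..., 0], pairs[..., 1]
    return sigmoid(even) * np.tanh(odd)


acc = np.arange(12, dtype=np.float32).reshape(2, 6) / 6.0
out = even_odd_activation(acc)
print(out.shape)  # (2, 3)
```

With this layout, output column j gates column 2j of A·B (through sigmoid) with column 2j+1 (through tanh); if a specific left/right split is required, B's columns can be reordered once ahead of time so that the pairs line up.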

My next attempt was to initialize two separate [BLOCK_SIZE_M, BLOCK_SIZE_N // 2] accumulators as well as two separate b_ptrs at the corresponding column offsets. At each iteration over K, I perform two separate dot operations. This gives a segmentation fault when implemented (see issue #787). I was able to make it work, but it required adding a redundant load operation, whereby the first operand to the dot operations had to be loaded twice even though this tensor can be shared between both dot operations:

for k in range(0, K, BLOCK_SIZE_K):
    a = tl.load(a_ptrs)

    b_left = tl.load(b_left_ptrs)
    accumulator_left += tl.dot(a, b_left)

    a = tl.load(a_ptrs) # without this line, program seg faults

    b_right = tl.load(b_right_ptrs)
    accumulator_right += tl.dot(a, b_right)
    # (advancing a_ptrs, b_left_ptrs and b_right_ptrs along K is omitted in this excerpt)

This ran correctly, but the performance is much worse than a naive implementation which runs a matmul kernel and an activation kernel in series.
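For reference, here is what the loop above computes, as a NumPy emulation on the CPU (not Triton code). The hypothetical function `dual_accumulator_matmul` reads each K-slice of A once and feeds it to two dot products, one per accumulator; `B_left` and `B_right` stand in for the column ranges addressed by `b_left_ptrs` and `b_right_ptrs`. The epilogue then combines both accumulators, which is the step that makes this a fused activation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def dual_accumulator_matmul(A, B, block_k):
    """Emulate the fused K loop: one shared A tile per step, two accumulators.

    B is split into left/right column halves, mirroring b_left_ptrs/b_right_ptrs.
    """
    M, K = A.shape
    half = B.shape[1] // 2
    B_left, B_right = B[:, :half], B[:, half:]
    acc_left = np.zeros((M, half), dtype=np.float32)
    acc_right = np.zeros((M, half), dtype=np.float32)
    for k in range(0, K, block_k):
        a = A[:, k:k + block_k]                  # load the shared operand once
        acc_left += a @ B_left[k:k + block_k]    # first dot
        acc_right += a @ B_right[k:k + block_k]  # second dot reuses the same tile
    # Epilogue: the stored value depends on both accumulators
    return sigmoid(acc_left) * np.tanh(acc_right)


rng = np.random.default_rng(0)
A = rng.standard_normal((16, 64)).astype(np.float32)
B = rng.standard_normal((64, 32)).astype(np.float32)
C_left, C_right = np.split(A @ B, 2, axis=1)
print(np.allclose(dual_accumulator_matmul(A, B, block_k=16),
                  sigmoid(C_left) * np.tanh(C_right), atol=1e-4))  # True
```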

I'm curious if anyone can think of a way to implement this efficiently with triton.

ptillet commented 1 year ago

Since Triton doesn't have indexing ops right now, I think your second approach is reasonable. Sorry for the segmentation fault. Can you retry now with the new triton-mlir backend?

jmc128 commented 1 year ago

Thanks for the reply. I've tested this with the new backend, and it no longer segfaults when the line that reloads A is removed. However, removing that line makes performance worse than keeping it. Either way, the fused kernel is still slower than running a matmul kernel and an activation kernel in series. After implementing this kernel in about a dozen different ways, I've concluded that any implementation that combines two separate accumulators incurs a significant performance penalty. Multiple dots/accumulators are fine as long as they are fully independent (e.g. stored separately to different locations), but if the kernel stores a value that depends on both accumulators, performance takes a hit. My best attempt at makeshift indexing has been to store the full block and immediately load back the two separate partitions, but this is no faster than using two separate kernels. I suppose I won't be able to beat the two-kernel approach until Triton supports indexing.

ptillet commented 1 year ago

Sorry for the delay. I think the dual accumulator approach could work, but you'll need to reduce the block size (and/or increase number of warps) to reduce register spilling. Do you have a repro script we could run?
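A rough way to see why a second accumulator increases register pressure: each fp32 accumulator tile of BLOCK_SIZE_M × BLOCK_SIZE_N is spread across the num_warps × 32 threads of a program, so the accumulators alone need about n_acc · BLOCK_SIZE_M · BLOCK_SIZE_N / (32 · num_warps) 32-bit registers per thread, against a per-thread cap of 255 on NVIDIA GPUs. The hypothetical helper below is a back-of-the-envelope sketch under the assumption of an even distribution; it ignores operand tiles, pointers, masks and indices, so real usage is higher.

```python
def accumulator_registers_per_thread(block_m, block_n, num_warps, n_acc=1):
    """Rough count of 32-bit registers per thread held by fp32 accumulator tiles.

    Ignores operand tiles, pointers, masks and indices, so real usage is higher.
    """
    threads_per_program = num_warps * 32
    return n_acc * block_m * block_n / threads_per_program


# One accumulator vs. two for a 128x128 tile with 8 warps
print(accumulator_registers_per_thread(128, 128, 8, n_acc=1))   # 64.0
print(accumulator_registers_per_thread(128, 128, 8, n_acc=2))   # 128.0
# Halving BLOCK_SIZE_N or doubling num_warps brings two accumulators back to 64
print(accumulator_registers_per_thread(128, 64, 8, n_acc=2))    # 64.0
print(accumulator_registers_per_thread(128, 128, 16, n_acc=2))  # 64.0
```

By this estimate, the first autotune config in the script below (BLOCK_SIZE_M=128, BLOCK_SIZE_N=256//2=128, num_warps=8) spends roughly 128 registers per thread on the two accumulators before anything else is counted.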

tlogn commented 5 months ago

@ptillet Hi there! I've also run into poor performance with dual accumulators. My code is below. Could you please take a look?

import torch

import triton
import triton.language as tl

def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256//2, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256//2, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128//2, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64//2, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128//2, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32//2, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32//2, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64//2, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
    ]

def get_autotune_config():
        return get_cuda_autotune_config()

@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def dual_matmul_fuse_bias_act_kernel(
        a_ptr, b_ptr, c_ptr, b_bias_ptr, c_bias_ptr,
        output_ptr,
        M, N, K,
        stride_am, stride_ak, 
        stride_bk, stride_bn, 
        stride_cm, stride_cn,
        ACTIVATION: tl.constexpr, 
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr, 
):
    '''
    output = (A x B) * act(A x C)
    '''
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # C is addressed with B's strides, so B and C must share the same layout
    c_ptrs = c_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    b_bias_ptrs = b_bias_ptr + offs_bn
    c_bias_ptrs = c_bias_ptr + offs_bn

    accumulator_1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):

        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        c = tl.load(c_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        accumulator_1 = tl.dot(a, b, accumulator_1)
        accumulator_2 = tl.dot(a, c, accumulator_2)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
        c_ptrs += BLOCK_SIZE_K * stride_bk

    b_bias = tl.load(b_bias_ptrs).to(tl.float32)
    c_bias = tl.load(c_bias_ptrs).to(tl.float32)
    accumulator_1 = accumulator_1 + b_bias
    accumulator_2 = accumulator_2 + c_bias
    if ACTIVATION == "gelu":
        accumulator_2 = gelu(accumulator_2)
    accumulator_1 = accumulator_1.to(tl.float16)
    accumulator_2 = accumulator_2.to(tl.float16)
    accumulator_1 = accumulator_1 * accumulator_2

    offs_outputm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_outputn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    output_ptrs = output_ptr + stride_cm * offs_outputm[:, None] + stride_cn * offs_outputn[None, :]
    output_mask = (offs_outputm[:, None] < M) & (offs_outputn[None, :] < N)
    tl.store(output_ptrs, accumulator_1, mask=output_mask)

@triton.jit
def gelu(x):
    return x * 0.5 * (1.0 + tl.math.erf(x / 1.41421356))

@triton.jit
def gelu_approximate(x):
    return x * 0.5 * (1.0 + tl.math.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

def dual_matmul_fuse_bias_act(a, b, c, b_bias, c_bias, act=""):
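    assert a.shape[1] == b.shape[0] == c.shape[0], "Incompatible dimensions"
    # The kernel reads C with B's strides, so both must match in shape and layout
    assert b.shape == c.shape and b.stride() == c.stride(), "B and C must match in shape and strides"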
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape

    output = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )

    dual_matmul_fuse_bias_act_kernel[grid](
        a, b, c, b_bias, c_bias, 
        output,
        M, N, K,
        a.stride(0), a.stride(1),  
        b.stride(0), b.stride(1),  
        output.stride(0), output.stride(1),  
        ACTIVATION=act,
    )
    return output

configs = []

configs.append(
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[128 * i for i in range(2, 33, 2)],
        line_arg="provider", 
        line_vals=["torch", "triton"], 
        line_names=["torch", "Triton"],
        styles=[("green", "-"), ("blue", "-")],
        ylabel="TFLOPS",
        plot_name="matmul-performance-fp16",
        args={"dtype": torch.float16},
    ))
@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, dtype):
    layer1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
    layer2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
    input = torch.randn(M, K, dtype=dtype, device=device)
    quantiles = [0.5, 0.2, 0.8]
    def torch_forward(input):
        return layer1(input) * torch.nn.functional.gelu(layer2(input))
    def triton_forward(input):
        return dual_matmul_fuse_bias_act(input, 
                                         layer1.weight.data.transpose(0, 1), 
                                         layer2.weight.data.transpose(0, 1), 
                                         layer1.bias.data, 
                                         layer2.bias.data,
                                         act="gelu")

    if provider == "torch":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_forward(input), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_forward(input), quantiles=quantiles)
    # Two GEMMs (A x B and A x C), each 2*M*N*K FLOPs
    perf = lambda ms: 2 * 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)

if __name__ == '__main__':
    dtype = torch.float16
    device = 'cuda'
    layer1 = torch.nn.Linear(512, 1024, dtype=dtype, device=device)
    layer2 = torch.nn.Linear(512, 1024, dtype=dtype, device=device)

    input = torch.randn(1024, 512, dtype=dtype, device=device)

    output_torch = layer1(input) * torch.nn.functional.gelu(layer2(input))
    output_triton = dual_matmul_fuse_bias_act(input, layer1.weight.data.transpose(0, 1), layer2.weight.data.transpose(0, 1), layer1.bias.data, layer2.bias.data, act="gelu")
    print(torch.allclose(output_torch, output_triton, atol=1e-2))
    print(torch.max(abs(output_torch-output_triton)))

    benchmark.run(print_data=True)
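The program-ID remapping at the top of `dual_matmul_fuse_bias_act_kernel` follows the grouped ordering from Triton's matrix multiplication tutorial. It can be sanity-checked in plain Python: the sketch below copies that arithmetic into a hypothetical helper, `grouped_pid`, and verifies (with another hypothetical helper, `check_coverage`) that every (pid_m, pid_n) tile is visited exactly once, including when num_pid_m is not a multiple of GROUP_SIZE_M.

```python
def grouped_pid(pid, num_pid_m, num_pid_n, group_size_m):
    """Mirror of the kernel's pid -> (pid_m, pid_n) mapping."""
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    group_size = min(num_pid_m - first_pid_m, group_size_m)  # last group may be short
    pid_m = first_pid_m + (pid % group_size)
    pid_n = (pid % num_pid_in_group) // group_size
    return pid_m, pid_n


def check_coverage(num_pid_m, num_pid_n, group_size_m):
    """True if the mapping is a bijection onto all output tiles."""
    tiles = [grouped_pid(p, num_pid_m, num_pid_n, group_size_m)
             for p in range(num_pid_m * num_pid_n)]
    expected = {(m, n) for m in range(num_pid_m) for n in range(num_pid_n)}
    return len(tiles) == len(set(tiles)) and set(tiles) == expected


print([grouped_pid(p, 4, 3, 2) for p in range(6)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
print(check_coverage(7, 5, 8), check_coverage(10, 3, 4))  # True True
```

Consecutive program IDs step through GROUP_SIZE_M row blocks for the same column block before moving on, so programs that run close together tend to read the same B and C tiles; the tutorial motivates this ordering as a way to improve L2 cache reuse.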