ROCm / triton


RMS Norm achieving poor memory bandwidth on MI300 #422

Closed anupambhatnagar closed 3 months ago

anupambhatnagar commented 8 months ago

The RMS norm implementation below achieves memory bandwidth well below the peak available on MI300. The results can be reproduced using the code below.

When benchmarking BabelStream on the same device I achieved much better performance. Any insights on how to achieve better perf on RMS norm will be appreciated. Thank you!
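
For reference, RMS norm scales each row by the reciprocal of its root mean square, optionally followed by an elementwise weight. The following is a minimal pure-Python sketch of that definition, handy for checking a kernel on small inputs. The helper name `rms_norm_row` is hypothetical and not part of the code in this issue.

```python
import math


def rms_norm_row(x, weight=None, eps=1e-6):
    """Normalize one row: x * rsqrt(mean(x^2) + eps), optionally times weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    if weight is None:
        return [v * rstd for v in x]
    return [v * rstd * w for v, w in zip(x, weight)]


if __name__ == "__main__":
    out = rms_norm_row([3.0, 4.0], eps=0.0)
    # With eps = 0 the normalized row has a mean square of 1 (up to rounding).
    print(out, sum(v * v for v in out) / len(out))
```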

RMS norm implementation

import torch
import triton
import triton.language as tl

if hasattr(tl, "libdevice"):
    tl_math = tl.libdevice
else:
    tl_math = tl.math

@triton.jit
def _rms_norm_kernel(
    x_ptr,
    h1_ptr,
    w_ptr,
    eps,
    stride,
    N_COLS,
    BLOCK_SIZE: tl.constexpr,
    INCLUDE_WEIGHT: tl.constexpr,
):
    row = tl.program_id(0)
    x_ptr += row * stride
    h1_ptr += row * stride

    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        a = tl.load(
            x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last"
        ).to(tl.float32)
        _mean += a * a
    rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < N_COLS
        a = tl.load(
            x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
        ).to(tl.float32)
        if INCLUDE_WEIGHT:
            w = tl.load(w_ptr + cols, mask=mask)
            tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
        else:
            tl.store(h1_ptr + cols, a * rstd, mask=mask)

def _rms_norm_forward(x, attn_norm_weights, eps):
    if not x.is_contiguous():
        raise ValueError("data must be contiguous")
    if attn_norm_weights is not None:
        if not attn_norm_weights.is_contiguous():
            raise ValueError("weights must be contiguous")
    out = torch.empty_like(x)
    x_arg = x.reshape(-1, x.shape[-1])
    M, N = x_arg.shape
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    BLOCK_SIZE = max(BLOCK_SIZE, 128)
    BLOCK_SIZE = min(BLOCK_SIZE, 4096)
    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    _rms_norm_kernel[(M,)](
        x_arg,
        out,
        attn_norm_weights,
        eps,
        x_arg.stride(0),
        N,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        INCLUDE_WEIGHT=attn_norm_weights is not None,
    )

    return out

lines = ["triton-jit"]

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[512 * i for i in range(2, 32)],
        line_arg="provider",
        line_vals=lines,
        line_names=lines,
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("orange", "-"),
            ("yellow", "-"),
            ("red", "-"),
            ("pink", "-"),
            ("yellow", "-"),
        ],
        ylabel="GB/s",
        plot_name="rms-norm",
        args={"M": 4096, "dtype": torch.bfloat16},
    )
)

def bench_rms_norm(M, N, dtype, provider, eps=1e-6, device="cuda"):
    # create data
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    # allocate on the requested device instead of hard-coding "cuda"
    weights = torch.rand(w_shape, dtype=dtype, device=device)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)

    # utility functions
    if provider == "triton-jit":
        y_fwd = lambda: _rms_norm_forward(x, weights, eps)  # noqa E731
    else:
        raise RuntimeError()

    gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6  # noqa E731
    quantiles = [0.5, 0.2, 0.8]

    with torch.inference_mode():
        ms, min_ms, max_ms = triton.testing.do_bench(
            y_fwd, rep=500, quantiles=quantiles
        )

    return gbps(ms), gbps(max_ms), gbps(min_ms)

def main() -> None:
    bench_rms_norm.run(save_path=".", print_data=True)

if __name__ == "__main__":
    main()
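
To see which launch configuration each benchmarked N ends up with, the heuristics in `_rms_norm_forward` can be replayed in plain Python. This is a sketch only: `launch_config` and the local `next_power_of_2` are hypothetical stand-ins (the latter mimics `triton.next_power_of_2` for positive inputs), and the default `element_size=2` assumes bfloat16 as in the benchmark.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def launch_config(n_cols, element_size=2):
    """Return (BLOCK_SIZE, num_warps, loop iterations per pass) for one row."""
    max_fused = 65536 // element_size
    block = min(max_fused, next_power_of_2(n_cols))
    block = max(block, 128)
    block = min(block, 4096)
    num_warps = min(max(block // 256, 1), 8)
    iters = -(-n_cols // block)  # ceiling division
    return block, num_warps, iters


if __name__ == "__main__":
    for n in (1024, 4096, 4608, 8192, 15872):
        print(n, launch_config(n))
```

Because BLOCK_SIZE is capped at 4096, every N above 4096 takes more than one iteration in each of the kernel's two loops. Whether that explains the shape of the bandwidth curve is not established here.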

RMS Norm memory bandwidth achieved in GB/sec.

N = number of columns of the input matrix x in the code above.
rms-norm:
          N   Mem BW (GB/s)
0    1024.0   242.845380
1    1536.0   406.980213
2    2048.0   580.727436
3    2560.0   765.160535
4    3072.0   902.283824
5    3584.0  1039.517351
6    4096.0  1249.839195
7    4608.0  1438.787001
8    5120.0  1705.244163
9    5632.0  1801.149440
10   6144.0  1825.396106
11   6656.0  1759.866815
12   7168.0  1702.653270
13   7680.0  1706.342602
14   8192.0  1539.900528
15   8704.0  1536.557202
16   9216.0  1539.256920
17   9728.0  1569.662726
18  10240.0  1594.398303
19  10752.0  1591.708676
20  11264.0  1618.314765
21  11776.0  1635.566654
22  12288.0  1651.720727
23  12800.0  1562.869454
24  13312.0  1615.739393
25  13824.0  1559.854407
26  14336.0  1594.316138
27  14848.0  1552.484248
28  15360.0  1574.615740
29  15872.0  1521.845394
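
The GB/s column follows the benchmark's `gbps` lambda, which counts one read and one write of x per call. The sketch below restates that arithmetic and relates two rows of the table to the BabelStream Copy figure reported further down (4582436.884 MB/s). The function names are hypothetical. Since the kernel also loads the weights and reads x in both loops, actual memory traffic may be higher than this model assumes.

```python
def effective_gbps(m_rows, n_cols, element_size, ms):
    """Bandwidth in GB/s, counting one read and one write of the M x N input."""
    bytes_moved = 2 * m_rows * n_cols * element_size
    return bytes_moved / ms * 1e-6  # bytes per ms -> GB/s


def fraction_of_baseline(gbps, baseline_mb_per_s):
    """Ratio of a measured GB/s figure to a BabelStream figure given in MB/s."""
    return gbps / (baseline_mb_per_s * 1e-3)


if __name__ == "__main__":
    babelstream_copy = 4582436.884  # MB/s, from the BabelStream output
    for n, gbps in ((6144, 1825.396106), (15872, 1521.845394)):
        print(n, round(fraction_of_baseline(gbps, babelstream_copy), 3))
```

By this measure, the best row (N = 6144) reaches a little under 40% of the BabelStream Copy rate and the last row (N = 15872) about a third.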

BabelStream Performance

BabelStream
Version: 3.4
Implementation: HIP
Running kernels 100 times
Precision: float
Array size: 1073.7 MB (=1.1 GB)
Total size: 3221.2 MB (=3.2 GB)
elements per lane 4
chunks per block 1
block count 65536
Using HIP device
Driver: 60032830
pciBusID: 8
Validation failed on sum. Error 0.000548607
Sum was 1031357.625 but should be 1030792.125
Function    MB/s        Min (MB/s)  Max         Average     Std. Dev.
Read        5078256.332 3294948.498 5078256.332 4646183.747 475212.302
Write       5120960.761 4347262.408 5120960.761 4950686.092 129090.976
Copy        4582436.884 4404328.357 4582436.884 4494574.336 31689.560
Mul         4570321.508 3736024.556 4570321.508 4454265.825 153658.684
Add         4578054.322 3842006.551 4578054.322 4304480.187 269786.931
Triad       4565575.586 3732177.621 4565575.586 4350104.930 247397.217
Dot         3316880.248 2676907.156 3316880.248 2975013.389 138793.670

bertmaher commented 8 months ago

Definitely curious what’s going on here. This is less than half of peak bandwidth; by contrast, on A100 we get essentially 100% of peak bandwidth.

scxiao commented 8 months ago

@bertmaher Not sure whether you already figured this out. If you change the kernel signature to:

@triton.jit
def _rms_norm_kernel(
    x_ptr,
    h1_ptr,
    w_ptr,
    eps,
    stride,
    N_COLS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INCLUDE_WEIGHT: tl.constexpr,
):

i.e., making N_COLS a tl.constexpr, there is a pretty big performance boost:

22  12288.0  2247.299783
23  12800.0  2223.278530
24  13312.0  2266.673025
25  13824.0  2274.271469
26  14336.0  2324.921587
27  14848.0  2327.270964
28  15360.0  2374.492701
29  15872.0  2360.028753

though we still have room to improve.
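
To put a number on the boost, the rows above can be divided by the matching rows of the original table. This is a small sketch using only figures quoted in this thread. One trade-off to keep in mind: Triton compiles a separate specialization for each distinct `tl.constexpr` value, so a workload with many different N values will trigger more compilations.

```python
# GB/s with N_COLS as a runtime argument (original table), keyed by N.
before = {
    12288: 1651.720727, 12800: 1562.869454, 13312: 1615.739393,
    13824: 1559.854407, 14336: 1594.316138, 14848: 1552.484248,
    15360: 1574.615740, 15872: 1521.845394,
}
# GB/s with N_COLS declared as tl.constexpr (table above), keyed by N.
after = {
    12288: 2247.299783, 12800: 2223.278530, 13312: 2266.673025,
    13824: 2274.271469, 14336: 2324.921587, 14848: 2327.270964,
    15360: 2374.492701, 15872: 2360.028753,
}

speedup = {n: after[n] / before[n] for n in before}

if __name__ == "__main__":
    for n, s in speedup.items():
        print(n, round(s, 2))
```

For these sizes the speedup falls between roughly 1.36x and 1.55x.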

anupambhatnagar commented 8 months ago

@scxiao I made the change you suggested along with some others, and bandwidth is now slightly over 3 TB/s. Analyzing the LLVM IR, I found that the loads and stores are not being vectorized at the LLVM IR layer. We are working on a fix that we hope to push upstream. https://github.com/ROCmSoftwarePlatform/triton/pull/445

jerryyin commented 3 months ago

Closing due to inactivity

sreenidhi707 commented 1 month ago

@anupambhatnagar I know this is a closed issue, but are the BabelStream results that you pasted above from this repo? https://github.com/UoB-HPC/BabelStream I don't see the Read and Write kernels that you have listed in the results. Is this a modified version of BabelStream that you are running?