ROCm / triton

Development repository for the Triton language and compiler

bfloat16 casting issue in triton_mm - error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast #359

Closed. jataylo closed this issue 1 year ago.

jataylo commented 1 year ago

@jayfurmanek @zhanglx13 @binarman cc: @dllehr-amd @jithunnair-amd

A few PyTorch unit tests are failing with the message:

error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
Failed to emit LLVM IR
Translate to LLVM IR failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Aborted (core dumped)

After investigation I have narrowed this down to any torch workload that enables Triton gemms for matmul with bfloat16 tensors.

Here is a Triton reproducer with a passing fp16 matmul and a failing bf16 matmul.

Reproducer:

import torch
import math
import random
from torch import empty_strided
import triton.language as tl
import triton
from torch._dynamo.testing import rand_strided

@triton.jit
def triton_fn(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = False
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = tl.float16
    BLOCK_M : tl.constexpr = 32
    BLOCK_N : tl.constexpr = 32
    BLOCK_K : tl.constexpr = 32

    A = arg_A
    B = arg_B

    M = 8
    N = 2
    K = 2
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 2
    stride_ak = 1
    stride_bk = 8
    stride_bn = 1

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + (8*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)

@triton.jit
def triton_bf_fn(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = False
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = tl.bfloat16
    BLOCK_M : tl.constexpr = 32
    BLOCK_N : tl.constexpr = 32
    BLOCK_K : tl.constexpr = 32

    A = arg_A
    B = arg_B

    M = 8
    N = 8
    K = 2
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 2
    stride_ak = 1
    stride_bk = 8
    stride_bn = 1

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + (8*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)

from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
#import torch._inductor.kernel.mm_common

#fp16
arg0_1 = rand_strided((8, 2), (2, 1), device='cuda:0', dtype=torch.float16)
arg1_1 = rand_strided((2, 8), (8, 1), device='cuda:0', dtype=torch.uint8)
buf0 = empty_strided((8, 8), (8, 1), device='cuda', dtype=torch.float16)
stream0 = get_cuda_stream(0)
triton_fn[(1,)](arg0_1, arg1_1, buf0, stream=stream0)
print("fp16 kernel completed")

#bf16
arg0_1 = rand_strided((8, 2), (2, 1), device='cuda:0', dtype=torch.bfloat16)
arg1_1 = rand_strided((2, 8), (8, 1), device='cuda:0', dtype=torch.int8)
buf0 = empty_strided((8, 8), (8, 1), device='cuda', dtype=torch.bfloat16)
stream0 = get_cuda_stream(0)
triton_bf_fn[(1,)](arg0_1, arg1_1, buf0, stream=stream0)
print("bf16 kernel completed")

For additional context, PyTorch's Triton matmul codegen is here: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm.py#L31
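As an illustrative sketch (assumed, not taken from the failing unit tests themselves), one way to reach this code path from PyTorch is to compile a bf16 matmul with max-autotune, which lets Inductor consider its Triton mm templates and generate kernels of the same shape as the reproducer above:

import torch

# Illustrative only: mode="max-autotune" allows Inductor to autotune its
# Triton gemm templates for this matmul instead of always using the ATen op.
@torch.compile(mode="max-autotune")
def mm(a, b):
    return a @ b

a = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)
b = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
out = mm(a, b)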

binarman commented 1 year ago

I was not able to reproduce this particular error on MI100 and MI210 with the current ToT of Triton. Could you give more details on your environment?

P.S. I've tried changing int8 to uint8 and got a different error:

loc("./fail.py":135:21): error: 'llvm.uitofp' op result #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'i16'
Pass execution failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Aborted (core dumped)
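For clarity, that variation amounts to a one-line change in the bf16 launch of the reproducer above, swapping the dtype of the B operand (illustrative only):

arg1_1 = rand_strided((2, 8), (8, 1), device='cuda:0', dtype=torch.uint8)  # was torch.int8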
jataylo commented 1 year ago

@binarman Thank you for taking a look at this. I've just confirmed that this reproducer passes with ToT Triton; I had been using the Aug 29th commit of Triton from our PyTorch 2.1 branch.

Once we merge https://github.com/ROCmSoftwarePlatform/triton/pull/354 and https://github.com/ROCmSoftwarePlatform/triton/pull/296, I will bump our Triton pin forward to resolve this.