triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
13.27k stars 1.63k forks source link

Segmentation fault in triton==3.0.0 #4389

Open wyang20170113 opened 3 months ago

wyang20170113 commented 3 months ago
from typing import Optional
import torch

import triton
import triton.language as tl

@triton.jit
def _distance_bias_(diagram: tl.tensor,
                    lower_bound: tl.tensor,
                    upper_bound: tl.tensor,
                    weight: tl.tensor,
                    bias: tl.tensor,
                    c_ss: int,
                    ):
    lower_d_mask = diagram[:, :, None] > lower_bound[None, None, :]
    upper_d_mask = diagram[:, :, None] < upper_bound[None, None, :]
    d_mask = lower_d_mask * upper_d_mask
    d_mask = d_mask.to(diagram.dtype)
    tl.static_print(d_mask)
    tl.static_print(weight)

    # uncomment the following two line, work!
    # o = d_mask[:, :, :, None] * weight[None, None, :, :]
    # o = tl.sum(o, axis=2)

    # the following line will cause segmentation fault.
    o = tl.dot(d_mask, weight[None, :, :])

    tl.static_print(o)
    o = o * bias[None, None, :]
    o = tl.sum(o, axis=2)
    o = o / c_ss
    return o

@triton.jit
def distance_bias_fwd_triton(output_ptr,
                             diagram_ptr,
                             lower_bound_ptr,
                             upper_bound_ptr,
                             weight_ptr,
                             bias_ptr,
                             diagram_stride,
                             weight_stride,
                             num_rows, num_cols, num_bins, c_ss,
                             distance_default,
                             BLOCK_ROW_SIZE: tl.constexpr,
                             BLOCK_COL_SIZE: tl.constexpr,
                             BLOCK_BIN_SIZE: tl.constexpr,
                             BLOCK_SS_SIZE: tl.constexpr,
                             INF,
                             ):
    pid = tl.program_id(axis=0)
    block_row = tl.arange(0, BLOCK_ROW_SIZE)
    block_col = tl.arange(0, BLOCK_COL_SIZE)
    diagram_offset = block_row[:, None] * diagram_stride + block_col[None, :]
    diagram_mask = (block_row[:, None] < num_rows) & (block_col[None, :] < num_cols)
    diagram = tl.load(diagram_ptr + diagram_offset, mask=diagram_mask, other=INF)

    block_bins = tl.arange(0, BLOCK_BIN_SIZE)
    lower_bound = tl.load(lower_bound_ptr + block_bins, mask=(block_bins < num_bins), other=INF)
    upper_bound = tl.load(upper_bound_ptr + block_bins, mask=(block_bins < num_bins), other=INF)

    block_ss = tl.arange(0, BLOCK_SS_SIZE)
    weight_offset = block_bins[:, None] * weight_stride + block_ss[None, :]
    weight_mask = (block_bins[:, None] < num_bins) & (block_ss[None, :] < c_ss)
    weight = tl.load(weight_ptr + weight_offset, mask=weight_mask, other=0.0)

    bias = tl.load(bias_ptr + block_ss, mask=(block_ss < c_ss), other=0.0)
    # ou = _distance_bias_(diagram, lower_bound, upper_bound, weight, bias, c_ss, BLOCK_ROW_SIZE, BLOCK_COL_SIZE)
    ou = _distance_bias_(diagram, lower_bound, upper_bound, weight, bias, c_ss)

    tl.store(output_ptr + diagram_offset, ou, mask=diagram_mask)

def distance_bias_fwd(diagram: torch.Tensor,
                      lower_bound: torch.Tensor,
                      upper_bound: torch.Tensor,
                      weight: torch.Tensor,
                      bias: torch.Tensor,
                      INF: float,
                      no_bins: Optional[int] = None,
                      c_ss: Optional[int] = None,
                      ) -> torch.Tensor:
    assert diagram.is_cuda and lower_bound.is_cuda and upper_bound.is_cuda
    assert weight.is_cuda and bias.is_cuda
    num_rows, num_cols = diagram.shape
    if no_bins is None:
        no_bins = weight.shape[0]
    if c_ss is None:
        c_ss = weight.shape[1]
    distance_default = 10
    device, dtype = diagram.device, diagram.dtype
    o = torch.zeros_like(diagram, device=device, dtype=dtype)
    BLOCK_ROW_SIZE = triton.next_power_of_2(num_rows)
    BLOCK_COL_SIZE = triton.next_power_of_2(num_cols)
    grid = lambda meta: (triton.cdiv(num_rows, BLOCK_ROW_SIZE), triton.cdiv(num_cols, BLOCK_COL_SIZE))
    # grid = lambda meta: (1,)
    distance_bias_fwd_triton[grid](o,
                                   diagram,
                                   lower_bound,
                                   upper_bound,
                                   weight,
                                   bias,
                                   diagram.stride(0),
                                   weight.stride(0),
                                   num_rows, num_cols, no_bins, c_ss, distance_default,
                                   BLOCK_ROW_SIZE,
                                   BLOCK_COL_SIZE,
                                   triton.next_power_of_2(no_bins),
                                   triton.next_power_of_2(c_ss),
                                   INF,
                                   num_stages=4,
                                   )
    return o

if __name__ == '__main__':
    torch.manual_seed(42)
    dtype, device = torch.float32, 'cuda'
    inf = 1e8
    min_bin = 3.25
    max_bin = 20.75
    no_bins = 29
    c_ss = 16
    w = torch.randn((no_bins, c_ss), dtype=dtype, device=device, requires_grad=True)
    b = torch.randn((c_ss,), dtype=dtype, device=device, requires_grad=True)

    num_rows, num_cols = 16, 16
    # num_rows, num_cols = 8, 8
    bins = torch.linspace(min_bin, max_bin, no_bins,
                          dtype=dtype, device=device, requires_grad=False)
    squared_bins = bins ** 2
    upper_bins = torch.cat([squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1)

    d = 500 * torch.rand((num_rows, num_cols), dtype=dtype, device=device, requires_grad=False)

    oo = distance_bias_fwd(d, squared_bins, upper_bins, w, b, inf, no_bins, c_ss)

    # print(oo)

    # print(oo_w)
    # print(w)
    # max_diff = torch.max(torch.abs(oo_w - w))
    # print(max_diff)

    pass

Running the above code lead to following error: """ fp32[constexpr[8], constexpr[16], constexpr[64]] fp32[constexpr[64], constexpr[64]] fp32[constexpr[8], constexpr[16], constexpr[64]] Segmentation fault (core dumped) """ It seem that tl.dot is a problem since using the following two line to replace the tl.dot make the code work!!! o = d_mask[:, :, :, None] * weight[None, None, :, :] o = tl.sum(o, axis=2) I am curious what happened and why is that? Thanks very much and appreciate your comments.

plotfi commented 3 months ago

Looks like the insert_element in MMA16816SmemLoader::loadX4 is trying to insert at index 32 when it only has a vector of 4 elements when lowering the following:


    %72 = triton_gpu.local_load %68 : !tt.memdesc<16x16x32xf32, #shared> -> tensor<16x16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 32}>> loc(#loc54)
    %73 = triton_gpu.local_load %71 : !tt.memdesc<1x32x16xf32, #shared1> -> tensor<1x32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 32}>> loc(#loc56)
    %74 = tt.dot %72, %73, %cst, inputPrecision = tf32 : tensor<16x16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 32}>> * tensor<1x32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 32}>> -> tensor<16x16x16xf32, #mma> loc(#loc56)
plotfi commented 2 months ago

Crash is happening here:

https://github.com/triton-lang/triton/blob/main/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp#L415-L420

It is crashing because the canonWidth is 32 which goes out of the bounds of the retElems SmallVector that contains the 4 elements for the loadX4. I think if you were using bf16 it would probably be taking the ldmatrix path instead.

I am not sure how to fix this one.