triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

MatMul and blocksparse matmul give incorrect precision for some shapes. #1808

Open Qu-Xiangjun opened 1 year ago

Qu-Xiangjun commented 1 year ago

I am using Triton 2.0.0, PyTorch 2.0.0, Python 3.9.16, and CUDA 11.6 on a PC running CentOS release 7.4.1708 with an NVIDIA A100. I am using the matmul and blocksparse/matmul ops from https://github.com/openai/triton/tree/main/python/triton/ops . My test code is modeled on [test_matmul.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_matmul.py) and [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py).

I found a problem when comparing the Triton matmul with torch.matmul: the results differ under torch.allclose(atol = 1e-5, rtol = 0), as follows:

Matmul Test

The testing code is as follows:

import torch
import triton

M, N, K = 2048, 2048, 2048
torch.manual_seed(0)
a = torch.randn((M,K), device = 'cuda', dtype = torch.float16)
b = torch.randn((K,N), device = 'cuda', dtype = torch.float16)
# compute torch
torch_output = torch.matmul(a, b)
# compute triton
triton_output = triton.ops.matmul(a, b)

# compare
diff = torch.sum(torch.abs(triton_output - torch_output))
print("total difference: {:10f}".format(diff))

if(torch.allclose(triton_output, torch_output, atol = 1e-5, rtol = 0)):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

This code prints a total difference greater than 0.0, and torch.allclose returns False.

I then observed some characteristics:

  1. The difference grows as the shape grows. I guess it may be related to accumulated rounding error in the calculation. But when I run this code with M, K, N = 4096, 4096, 4096 on my machine, it passes ✅ the allclose check with diff = 0.000000. Is the problem also related to the shape? Only some shapes trigger it.

  2. Moreover, I tried some special data with shape M, N, K = 2048, 2048, 2048.

    • With a = torch.ones, b = torch.ones, the result always passes ✅. So in some cases the problem is not related to the shape.

    • With a = torch.ones, b = torch.randn, every row of the result matrix is the same, and the incorrect elements are also the same in every row.
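To put the tolerance used above in context, here is a minimal CPU-only sketch. It is not taken from the Triton or PyTorch sources; the helper names `to_fp16` and `fp16_ulp` are made up for illustration, and the "typical" output magnitude of about sqrt(K) is an assumption based on summing K products of independent standard normal values. It prints the gap between adjacent float16 values at a few magnitudes and compares it with atol = 1e-5.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def fp16_ulp(x: float) -> float:
    """Gap between fp16(x) and the next larger fp16 value (x positive and finite)."""
    bits = struct.unpack("<H", struct.pack("<e", x))[0]
    next_up = struct.unpack("<e", struct.pack("<H", bits + 1))[0]
    return next_up - to_fp16(x)

K = 2048
typical = K ** 0.5  # assumed typical |output| for randn inputs (hypothetical estimate)
atol = 1e-5

for magnitude in (1.0, typical, 2048.0):
    gap = fp16_ulp(magnitude)
    print(f"|x| ~ {magnitude:8.2f}: fp16 spacing = {gap:.7f}, spacing / atol = {gap / atol:.0f}")
```

Under these assumptions, the spacing near sqrt(2048) is several thousand times larger than atol, so with rtol = 0 a single last-bit difference between two float16 results is enough to make torch.allclose return False. It would also be consistent with the all-ones case passing: every partial sum up to 2048 is an exactly representable integer, so the order of accumulation cannot change the result.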

Blocksparse Matmul Test

The incorrect precision also appears in the blocksparse matmul function. The test code is as follows; it uses only the forward test from [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py):

import torch
import triton
import triton.ops

def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret

def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None,
                 dtype=torch.float32):
    if data is None:
        data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
    ref_ret = data
    ref_ret = ref_ret * alpha + beta
    ref_ret = ref_ret.half().to(dtype)
    if trans:
        ref_ret = ref_ret.t().requires_grad_()
    ref_ret = ref_ret.detach().requires_grad_()
    tri_ret = ref_ret.clone().detach().requires_grad_()
    return ref_ret, tri_ret

def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret

def test_blocksparsematmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
    seed = 0
    torch.manual_seed(seed)
    is_sdd = MODE == "sdd"
    is_dsd = MODE == "dsd"
    is_dds = MODE == "dds"
    do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK)
    do_mask = lambda x: mask_tensor(x, layout, BLOCK)

    # create inputs
    # create op
    a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
    b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
    c_shape = (Z, H, M, N)
    shape = {
        "sdd": (M, N),
        "dsd": (a_shape[2], a_shape[3]),
        "dds": (b_shape[2], b_shape[3]),
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))

    # create data
    a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE)
    b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE)
    dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE)

    # compute [torch]
    a_ref = do_mask(a_ref) if is_dsd else a_ref
    b_ref = do_mask(b_ref) if is_dds else b_ref
    a_ref.retain_grad()
    b_ref.retain_grad()
    c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                         b_ref.transpose(2, 3) if TRANS_B else b_ref)
    c_ref = do_sparsify(c_ref) if is_sdd else c_ref      
    # dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
    # c_ref.backward(dc_ref) 
    # da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
    # db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad

    # triton result
    a_tri = do_sparsify(a_tri) if is_dsd else a_tri
    b_tri = do_sparsify(b_tri) if is_dds else b_tri
    a_tri.retain_grad()
    b_tri.retain_grad()
    # op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
    c_tri = op(a_tri, b_tri)
    # dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
    # c_tri.backward(dc_tri)
    # da_tri = a_tri.grad
    # db_tri = b_tri.grad

    # compare
    print("--------------------------------------------------------------")
    # GFLOP/s: total floating-point operations * 1e-9, divided by the time in seconds
    perf = lambda ms: 2 * M * N * K * Z * H * 1e-9 / (ms * 1e-3)
    total_op = 2 * M * N * K * Z * H

    print('''MODE={}, Z={}, H={}, M={}, N={}, K={}, total_op={}. '''
            .format(MODE,Z, H, M, N, K, total_op))

    diff = torch.sum(torch.abs(c_ref - c_tri))
    print('total diff = {:.10f}'.format(diff))

    if(torch.allclose(c_ref, c_tri, atol = 1e-5, rtol = 0)):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

    ms, _, _ = triton.testing.do_bench(lambda: op(a_tri, b_tri), rep = 20)
    print('''Triton: GFLOPS: {:.3f}, time: {:.6f}ms.'''.format(perf(ms), ms))

    ms_torch, _, _ = triton.testing.do_bench(
        lambda: torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                            b_ref.transpose(2, 3) if TRANS_B else b_ref),
        rep = 20
    )
    print('''Torch: GFLOPS: {:.3f}, time: {:.6f}ms.'''.format(perf(ms_torch), ms_torch))

    return perf(ms), perf(ms_torch), diff

test_blocksparsematmul('dds', False, False, 32, torch.float16, Z = 1, H = 2, M = 64, K = 4096, N = 4096)

This code prints a total difference greater than 0.0, and torch.allclose returns False.


So what could be causing the incorrect precision, and how can it be solved?
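One candidate explanation, offered as a sketch rather than a confirmed diagnosis: two correct kernels can add the K products in a different order (for example, tile by tile versus one running sum), and floating-point addition is not associative. The pure-Python simulation below assumes float16 inputs and float32 accumulation. The functions `dot_sequential` and `dot_tiled` and the tile size of 32 are hypothetical and do not reproduce the actual Triton or cuBLAS reduction order.

```python
import math
import random
import struct

def f32(x: float) -> float:
    """Round to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def f16(x: float) -> float:
    """Round to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def dot_sequential(a, b):
    """One running float32 accumulator over the whole K dimension."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = f32(acc + x * y)  # the product of two fp16 values is exact here
    return acc

def dot_tiled(a, b, tile):
    """Float32 partial sums per tile, then a float32 sum of the partials."""
    total = 0.0
    for start in range(0, len(a), tile):
        partial = dot_sequential(a[start:start + tile], b[start:start + tile])
        total = f32(total + partial)
    return total

random.seed(0)
K = 4096
a = [f16(random.gauss(0.0, 0.1)) for _ in range(K)]
b = [f16(random.gauss(0.0, 0.1)) for _ in range(K)]

exact = math.fsum(x * y for x, y in zip(a, b))
seq = dot_sequential(a, b)
tiled = dot_tiled(a, b, tile=32)

print("sequential error vs exact:", abs(seq - exact))
print("tiled error vs exact:     ", abs(tiled - exact))
print("fp16 outputs:", f16(seq), f16(tiled))
```

Both orderings stay close to the exact sum, but they need not agree bit for bit. When the two float32 sums fall on opposite sides of a float16 rounding boundary, the final float16 outputs differ by one spacing step, which is far more than 1e-5. A comparison with a nonzero rtol, chosen for the dtype, is less sensitive to this effect.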

992355092 commented 10 months ago

I have also encountered this problem. Have you resolved it?

Qu-Xiangjun commented 10 months ago

Unfortunately, I didn't find the reason.

FelixSchoen commented 6 months ago

I faced a similar issue when playing around with blocksparse matrix multiplication. Here is my code:

    import torch
    import triton.ops

    device = torch.device("cuda")
    dtype = torch.float16

    # Parameters
    batch_size = 2
    head_size = 1
    sequence_length = 32
    d_model = 16
    block_size = 16

    use_int_tensors = False
    max_int_val = 20

    # Tensors
    if use_int_tensors:
        tensor_a = torch.randint(1, max_int_val, (head_size, batch_size, sequence_length, d_model), device=device,
                                 dtype=dtype)
        tensor_b = torch.randint(1, max_int_val, (head_size, batch_size, sequence_length, d_model), device=device,
                                 dtype=dtype)
    else:
        tensor_a = torch.rand((head_size, batch_size, sequence_length, d_model), device=device,
                              dtype=dtype)
        tensor_b = torch.rand((head_size, batch_size, sequence_length, d_model), device=device,
                              dtype=dtype)
    identity_matrix = torch.eye(sequence_length, device=device, dtype=dtype).unsqueeze(0).unsqueeze(0).repeat(
        (head_size, batch_size, 1, 1))
    sparsity_layout = torch.tensor([[[1, 1],
                                     [1, 1]],

                                    [[1, 1],
                                     [1, 1]]
                                    ])

    # Triton Matmul
    triton_matmul_sdd = triton.ops.blocksparse.matmul(sparsity_layout, block_size, "sdd", device)
    triton_result = triton_matmul_sdd(tensor_a, torch.transpose(tensor_b, -1, -2))
    triton_matmul_dsd = triton.ops.blocksparse.matmul(sparsity_layout, block_size, "dsd", device)
    triton_identity = triton_matmul_dsd(triton_result, identity_matrix)

    # Conventional Matmul
    conventional_result = torch.matmul(tensor_a, torch.transpose(tensor_b, -1, -2))
    conventional_identity = torch.matmul(conventional_result, identity_matrix)

    assert torch.allclose(conventional_result, conventional_identity)

    # Passes only up to max_int_val ~ 15 if computing with dtype float16 and integer values
    if not torch.allclose(triton_identity, conventional_identity):
        difference = torch.abs(conventional_identity - triton_identity)
        indices = torch.where(difference > 1e-6)

        # Print the differing values
        print("Conventional Identity at differing indices: ", conventional_identity[indices])
        print("Triton Identity at differing indices: ", triton_identity[indices])
        print("Number of differing values: ", len(conventional_identity[indices]))

In this code I generate two tensors containing only integer values (or floating-point values if use_int_tensors=False) and multiply them using torch.matmul and triton.ops.blocksparse.matmul (with a sparsity layout that corresponds to a regular "full" matrix multiplication). With float32 as the dtype I see numerical inaccuracies; with float16 everything seems to work fine.

Am I overlooking something here? For now I'll stick to float16; please let me know if there are other ways of achieving better precision!
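One factor that may be relevant to the float32 case, stated as an assumption and not confirmed anywhere in this thread: on Ampere GPUs such as the A100, a kernel that allows TF32 tensor-core math rounds float32 inputs to 10 explicit mantissa bits before multiplying, while torch.matmul uses full float32 unless TF32 is enabled. The CPU-only sketch below imitates that input rounding. `to_tf32` is a hypothetical helper and its rounding mode is a simplification, not the exact hardware behaviour.

```python
import math
import random
import struct

def f32(x: float) -> float:
    """Round to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_tf32(x: float) -> float:
    """Approximate TF32 input rounding: keep 10 of float32's 23 mantissa bits.

    Assumption: round half away from zero on the dropped 13 bits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x1000) & 0xFFFFE000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

random.seed(0)
K = 16  # inner dimension, matching d_model in the snippet above
a = [f32(random.random()) for _ in range(K)]
b = [f32(random.random()) for _ in range(K)]

full = math.fsum(x * y for x, y in zip(a, b))
tf32 = math.fsum(to_tf32(x) * to_tf32(y) for x, y in zip(a, b))
rel_err = abs(tf32 - full) / full

print("full-precision dot:", full)
print("TF32-rounded dot:  ", tf32)
print("relative error:    ", rel_err)
```

If this is the cause, the relative error would be on the order of 2^-11 per input, well above the default tolerances of torch.allclose for float32. That would fit the observation that float32 shows larger discrepancies than float16, whose 10 mantissa bits lose nothing in this rounding step.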