triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

tl.dot with float16 not working on V100; works on A100 #3478

Open apd10 opened 6 months ago

apd10 commented 6 months ago

Minimal reproducing example:

import torch
import triton
import triton.language as tl

TYPE = torch.float32  # works
#TYPE = torch.float16   # does not work
@triton.autotune(
    configs = [
        triton.Config({'BLOCK_SIZE_M': 32})
    ],
    key = ['M']
)
@triton.jit
def simple_block_dot(a_ptr, b_ptr, c_ptr, M, stride_am, stride_an, stride_bm, stride_bn, stride_cm, stride_cn,
                     BLOCK_SIZE_M: tl.constexpr):
    pid = tl.program_id(axis=0)
    num = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid // num
    pid_n = pid % num

    # Tile pointers: each program reads/writes one BLOCK_SIZE_M x BLOCK_SIZE_M tile.
    a_addr = a_ptr + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]) * stride_am + (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[None, :]) * stride_an
    b_addr = b_ptr + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]) * stride_bm + (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[None, :]) * stride_bn
    c_addr = c_ptr + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]) * stride_cm + (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[None, :]) * stride_cn

    a = tl.load(a_addr)
    b = tl.load(b_addr)
    c = tl.dot(a, b)
    tl.store(c_addr, c)

def block_dot(a, b):
    assert a.shape == b.shape, "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    assert a.shape[0] == a.shape[1]

    M = a.shape[0]
    c = torch.empty((M, M), device = a.device, dtype = a.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M'])**2, )

    simple_block_dot[grid](a, b, c, M, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))

    return c

M = 32 # equal to block size for verification
a = torch.randn((M,M), device='cuda', dtype = TYPE)
b = torch.randn((M,M), device='cuda', dtype = TYPE)

c = block_dot(a,b)
print("a", a)
print("b", b)
print("triton", c)
print("torch", torch.matmul(a,b))

if torch.allclose(c, torch.matmul(a,b), atol=1e-2, rtol=1e-3):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")
The V100 machine has CUDA driver 12.0; ptxas, nvdisasm, and cuobjdump are all 12.3 (downloaded along with the Triton installation).
Other libraries:
conda list | grep cuda
cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.3.101                      0    nvidia
cuda-runtime              12.1.0                        0    nvidia
pytorch                   2.2.1           py3.12_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
The A100 machine has CUDA driver 12.2; ptxas, nvdisasm, and cuobjdump are all 12.3 (downloaded along with the Triton installation).
(mlpmixer) apd10@terminator8:~ $ conda list | grep cuda
cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvcc                 12.3.107                      0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.3.101                      0    nvidia
cuda-runtime              12.1.0                        0    nvidia
cudatoolkit               11.7.0              hd8887f6_10    nvidia
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
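
As a quick sanity check on either machine, something along these lines (just a sketch, not specific to this bug) prints the compute capability and the library versions involved; the V100 should report sm_70 and the A100 sm_80:

import torch
import triton

# V100 reports compute capability (7, 0) (sm_70); A100 reports (8, 0) (sm_80).
major, minor = torch.cuda.get_device_capability()
print(f"{torch.cuda.get_device_name()} -> sm_{major}{minor}")
print(f"torch {torch.__version__}, CUDA runtime {torch.version.cuda}, triton {triton.__version__}")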
jlebar commented 6 months ago

V100 is not supported per the README.

Jokeren commented 6 months ago

@apd10 What exactly is the error? Unfortunately, I don't have access to a V100.

apd10 commented 6 months ago

@Jokeren The output is partially correct and partially zeros: the leading columns match torch.matmul, but the trailing columns come back as all zeros (see below).

(latest) apd10@yogi:~  $ python3 a.py 
a tensor([[-8.2092e-02,  9.5215e-01,  1.0967e+00,  ..., -8.1641e-01,
         -2.6270e-01, -1.6172e+00],
        [-4.1089e-01,  2.1698e-02, -1.6387e+00,  ..., -1.3123e-01,
         -1.4072e+00,  1.9365e+00],
        [-1.5784e-01, -3.7573e-01,  4.2686e-03,  ..., -1.4424e-04,
         -3.9575e-01,  4.1333e-01],
        ...,
        [-1.7256e+00, -5.0635e-01,  4.1992e-01,  ...,  1.9072e+00,
         -4.9097e-01,  2.6196e-01],
        [ 2.2852e+00,  2.2363e+00, -2.3027e+00,  ..., -1.7900e+00,
         -5.7471e-01,  1.4600e-01],
        [-7.9297e-01,  6.9873e-01, -2.4976e-01,  ..., -8.4717e-01,
         -1.0605e+00, -1.0077e-01]], device='cuda:0', dtype=torch.float16)
b tensor([[-1.4580,  0.2881, -1.6797,  ..., -1.0205, -2.5254,  1.0303],
        [-0.5820,  0.2014,  0.7690,  ..., -0.3772,  0.7280,  0.0329],
        [-1.1582,  0.3062,  0.2717,  ...,  1.0732,  0.0172,  0.6890],
        ...,
        [ 1.3320,  0.2888, -0.7490,  ...,  1.5488, -0.2100, -0.3918],
        [-0.0621, -0.4038, -0.0346,  ...,  1.1074,  1.1211,  0.1573],
        [-2.1738,  0.1888, -0.5669,  ...,  2.0918,  0.3550, -1.6299]],
       device='cuda:0', dtype=torch.float16)
triton tensor([[ 2.6328,  4.1328,  4.3555,  ...,  0.0000,  0.0000,  0.0000],
        [-3.1406,  2.7148,  3.7207,  ...,  0.0000,  0.0000,  0.0000],
        [-2.9160,  1.7646,  2.2832,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-3.4668, -5.3555, -3.6055,  ...,  0.0000,  0.0000,  0.0000],
        [-3.0742,  3.8672,  3.8320,  ...,  0.0000,  0.0000,  0.0000],
        [-3.0527,  3.3574,  1.1143,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.float16)
torch tensor([[  2.6328,   4.1328,   4.3555,  ...,  -1.6572,  -2.3652,   7.0625],
        [ -3.1406,   2.7148,   3.7207,  ...,   8.2344, -10.3047,  -0.7998],
        [ -2.9160,   1.7646,   2.2832,  ...,   3.6875,  -1.4990,   1.2070],
        ...,
        [ -3.4668,  -5.3555,  -3.6055,  ...,   3.8242,   7.5625,   0.7969],
        [ -3.0742,   3.8672,   3.8320,  ..., -11.5078,   0.5552,   5.6914],
        [ -3.0527,   3.3574,   1.1143,  ...,  -1.3135,   1.6162,   7.4492]],
       device='cuda:0', dtype=torch.float16)
❌ Triton and Torch differ
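
In case it helps while this is investigated, one possible workaround (a sketch only, not verified on a V100) is to upcast the operands inside the kernel and downcast only at the store, since the fp32 path works in the repro above. The trade-off is losing fp16 Tensor Core throughput:

    # Replacing the dot and store at the end of simple_block_dot:
    a = tl.load(a_addr)
    b = tl.load(b_addr)
    c = tl.dot(a.to(tl.float32), b.to(tl.float32))  # accumulate in fp32, avoiding the fp16 MMA path
    tl.store(c_addr, c.to(tl.float16))              # cast back for the fp16 output tensor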
aliencaocao commented 6 months ago

@jlebar Per the README, the nightly build supports V100.

hforoughmand commented 5 months ago

Thank you for your response. The nightly build seems to fix the problem: I can no longer see any meaningful difference in the printed output, yet I still get the mismatch error below. I assume the remaining difference is due to fp16 precision.

triton_output_with_fp16_inputs=tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3984,  24.4531, -32.3438], 
        [  6.3555, -19.6094,  34.0938,  ...,  -5.8945,   5.2891,   6.8867],                                                       
        [-32.0625,   5.9492,  15.3984,  ..., -21.3906, -23.9844, -10.1328],                                                       
        ...,                                                                                                                      
        [ -5.7031,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],                                                       
        [ 25.5000,  24.3281,  -8.4688,  ..., -18.9375,  32.5312, -29.9219],                                                       
        [ -5.3477,   4.9844,  11.8906,  ...,   5.5898,   6.4023, -17.3125]],
       device='cuda:0', dtype=torch.float16)   
torch_output_with_fp16_inputs=tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3906,  24.4531, -32.3438],                         
        [  6.3516, -19.6094,  34.0938,  ...,  -5.8906,   5.2812,   6.8828],
        [-32.0625,   5.9531,  15.3984,  ..., -21.4062, -23.9844, -10.1328],
        ...,                                                     
        [ -5.7070,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
        [ 25.5000,  24.3438,  -8.4609,  ..., -18.9375,  32.5312, -29.9219],                                 
        [ -5.3477,   4.9805,  11.8828,  ...,   5.5859,   6.4023, -17.3125]],
       device='cuda:0', dtype=torch.float16)                                                                                      
❌ Triton and Torch differ                                       
triton_output_with_fp8_inputs=tensor([[-21.4375,  13.1719,   6.0352,  ...,  14.0859,  26.6875, -18.0469],  
        [ 10.0000,  37.0000,  -5.5664,  ..., -15.5000,   3.4609,  39.6562],
        [ 19.5625,  -3.0078, -20.0469,  ..., -16.1094, -60.1562, -10.9062],
        ...,
        [ 15.6562, -53.9375, -54.9375,  ..., -24.2656,  33.7500,  30.8438],
        [-13.3125,   0.7686, -24.3750,  ...,  -6.3984, -29.9375,  11.2188],
        [ 16.8594,  -4.5977,  19.8438,  ...,  -9.5859, -28.9688, -33.7500]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp8_inputs=tensor([[-21.4375,  13.1719,   6.0352,  ...,  14.0859,  26.6875, -18.0469],
        [ 10.0000,  37.0000,  -5.5664,  ..., -15.5000,   3.4609,  39.6562],
        [ 19.5625,  -3.0078, -20.0469,  ..., -16.1094, -60.1562, -10.9062],
        ...,
        [ 15.6562, -53.9375, -54.9375,  ..., -24.2656,  33.7500,  30.8438],
        [-13.3125,   0.7686, -24.3750,  ...,  -6.3984, -29.9375,  11.2188],
        [ 16.8594,  -4.5977,  19.8438,  ...,  -9.5859, -28.9688, -33.7500]],
       device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match

I also receive the following error, which could have a similar cause:

Traceback (most recent call last):
  File ".../05-layer-norm.py", line 376, in <module>
    test_layer_norm(1151, 8192, torch.float16)
  File ".../05-layer-norm.py", line 318, in test_layer_norm
    assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0)
AssertionError
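
When an allclose check like this fails, it can help to look at the actual error magnitude before concluding the kernel is wrong; something along these lines (dw_tri and dw_ref are the tensors from the tutorial, the helper name is just illustrative) separates a genuine bug, such as whole columns of zeros, from fp16 rounding noise:

def report_mismatch(tri, ref, atol=1e-2):
    # Print the worst absolute/relative deviations and how many elements exceed the tolerance.
    diff = (tri.float() - ref.float()).abs()
    rel = diff / ref.float().abs().clamp_min(1e-6)
    print(f"max abs diff = {diff.max().item():.4e}, max rel diff = {rel.max().item():.4e}")
    print(f"elements over atol: {(diff > atol).sum().item()} / {diff.numel()}")

report_mismatch(dw_tri, dw_ref)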