triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Latest nightly triton causes my custom fused attention kernel to output incorrect results. #4310

Open chengzeyi opened 1 month ago

chengzeyi commented 1 month ago

Hello, and thank you for all your great work on this awesome project! I am currently building a new deep learning acceleration framework with Triton, but I have run into a problem and hope you can help.

I am currently working with nightly torch, installed with:

pip3 install --pre -U torch==2.5.0.dev20240629 --index-url https://download.pytorch.org/whl/nightly/cu124

This version of PyTorch ships with a fairly new version of triton:

pip3 list | grep triton
pytorch-triton                     3.0.0+dedb7bdf33

This triton can be installed with pip3 install --pre -U pytorch_triton==3.0.0+dedb7bdf33 --index-url https://download.pytorch.org/whl/nightly/cu124

OK, so I have a fairly complex fused attention kernel. It is written as a TorchInductor TritonTemplate and is expected to deliver very good performance. I have extracted the generated code from the torch_compile_debug folder and added a test assertion to check whether the output is correct.

# AOT ID: ['1_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()

# kernel path: /tmp/torchinductor_cheng/gv/cgvvk5jpi2re4luzz4h27s6ixcdnsbkws5xbrhudtdr6s7vaztv3.py
# Source Nodes: [scaled_dot_product_attention], Original ATen: [aten._scaled_dot_product_flash_attention]
# scaled_dot_product_attention => flash_attention_lowering
triton_tem_fused__scaled_dot_product_flash_attention_0 = async_compile.triton('triton_tem_fused__scaled_dot_product_flash_attention_0', '''
import triton
import triton.language as tl

@triton.jit
def num_threads():
    return tl.extra.cuda.num_threads()

@triton.jit
def maximum(a, b):

    x = a if a > b else b

    return x

@triton.jit
def maximum_(a, b):

    x = tl.maximum(a, b)

    return x

@triton.jit
def add(a, b):

    x = a + b

    return x

@triton.jit
def _attn_fwd_inner(

                    acc_0,
                    q_0,

                    l_i, m_i,  #
                    K_block_ptr, V_block_ptr,  #

                    start_m, #
                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
                    STAGE: tl.constexpr,  #
                    N_CTX_Q: tl.constexpr, N_CTX_K: tl.constexpr,  #
                    TILES: tl.constexpr,  #
                    EVEN_N: tl.constexpr,  #
                    allow_tf32: tl.constexpr,
):
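    # qk_scale folds the softmax scale (sm_scale = 0.125) together with log2(e),
    # since exp2 is used below instead of exp: 0.125 * log2(e) ~= 0.18033688.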
    qk_scale: tl.constexpr = tl.full([1], 0.18033688, dtype=acc_0.dtype)

    # range of values handled by this stage
    if STAGE == 1:
        if BLOCK_N <= BLOCK_M:
            lo, hi = 0, start_m * BLOCK_M
        else:
            lo, hi = 0, start_m // (BLOCK_N // BLOCK_M) * BLOCK_N
    elif STAGE == 2:
        if BLOCK_N <= BLOCK_M:
            lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
            lo = tl.multiple_of(lo, BLOCK_M)
        else:
            lo, hi = start_m // (BLOCK_N // BLOCK_M) * BLOCK_N, (start_m + 1) * BLOCK_M
            lo = tl.multiple_of(lo, BLOCK_N)
    # causal = False
    else:
        lo, hi = 0, N_CTX_K
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))

    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # -- compute qk ----
        if BLOCK_DMODEL == BLOCK_K:
            k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")
            K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
            qk = tl.dot(q_0, k, allow_tf32=allow_tf32, out_dtype=acc_0.dtype)
        else:

            k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")
            K_block_ptr = tl.advance(K_block_ptr, (BLOCK_K, 0))

            qk = tl.dot(q_0, k, allow_tf32=allow_tf32, out_dtype=acc_0.dtype)

        v_0 = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option="zero")
        V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_K))

        tl.debug_barrier()

        if BLOCK_DMODEL != BLOCK_K:
            K_block_ptr = tl.advance(K_block_ptr, (-TILES * BLOCK_K, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, -TILES * BLOCK_K))

        if EVEN_N:
            if STAGE == 2:
                offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_n = tl.arange(0, BLOCK_N)
                mask = offs_m[:, None] >= (start_n + offs_n[None, :])

                qk = qk * qk_scale

                qk = tl.where(mask, qk, tl.full([1], -float("inf"), dtype=qk.dtype))
            else:

                qk = qk * qk_scale

        else:
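            # EVEN_N is False for this template (N_CTX_K = 77), so key columns past
            # N_CTX_K are masked to -inf to discard the zero-padded K/V tiles.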
            offs_n = tl.arange(0, BLOCK_N)
            mask = (start_n + offs_n[None, :]) < N_CTX_K
            if STAGE == 2:
                offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
                mask = mask & (offs_m[:, None] >= (start_n + offs_n[None, :]))

            qk = qk * qk_scale

            qk = tl.where(mask, qk, tl.full([1], -float("inf"), dtype=qk.dtype))
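        # Online-softmax update: m_ij is the new running row max; alpha rescales the
        # previous accumulator and normalizer so earlier contributions stay consistent
        # with the new maximum before the new probabilities are added.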
        m_ij = maximum_(m_i, tl.reduce(qk, 1, maximum))

        qk = qk - m_ij[:, None]

        m_i_m_ij = m_i - m_ij

        alpha = libdevice.exp2(m_i_m_ij.to(tl.float32)).to(acc_0.dtype)

        acc_0 = acc_0 * alpha[:, None]

        p = libdevice.exp2(qk.to(tl.float32)).to(acc_0.dtype)

        l_ij = tl.reduce(p, 1, add)

        # -- update m_i and l_i

        l_i = l_i * alpha + l_ij

        m_i = m_ij

        p = p.to(v_0.dtype)

        # -- update output accumulator --

        acc_0 += tl.dot(p, v_0, allow_tf32=allow_tf32, out_dtype=acc_0.dtype)

    return (

        acc_0,

        l_i, m_i,
    )

import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.template(
    num_stages=2,
    num_warps=4,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=128), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused__scaled_dot_product_flash_attention_0', 'backend_hash': '3F29DF07E3050BE376A892068BE32F30ED6E1FF9E144E14CA43765EE086B7020', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_tem_fused__scaled_dot_product_flash_attention_0(arg_Q, arg_K, arg_V, out_ptr0):
    sm_scale : tl.constexpr = 0.125
    STAGE : tl.constexpr = 1
    TILES : tl.constexpr = 1
    EVEN_N : tl.constexpr = False
    NUM_STAGES : tl.constexpr = 2
    IS_QUANTIZED : tl.constexpr = False
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float16
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 128
    BLOCK_K : tl.constexpr = 64
    BLOCK_DMODEL : tl.constexpr = 64
    HAS_INLINE_ASM : tl.constexpr = True
    ENABLE_FAST_MATH : tl.constexpr = False
    CUDA_ARCH : tl.constexpr = 890
    CUDA_VERSION : tl.constexpr = 12040
    USE_FAST_ACCUM : tl.constexpr = False
    MAX_NUM_IMPRECISE_ACC : tl.constexpr = None
    Q = arg_Q
    K = arg_K
    V = arg_V

    Z = 2
    H = 10
    N_CTX_Q = 4096
    N_CTX_K = 77
    D = 64

    stride_qz = 2621440
    stride_qh = 262144
    stride_qm = 64
    stride_qk = 1

    stride_kz = 49280
    stride_kh = 4928
    stride_kn = 64
    stride_kk = 1

    stride_vz = 49280
    stride_vh = 4928
    stride_vk = 64
    stride_vn = 1

    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
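    # Each program handles one BLOCK_M tile of query rows (axis 0) for one
    # (batch, head) pair enumerated by off_hz (axis 1).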
    off_z = off_hz // H
    off_h = off_hz % H
    # off_z = off_z.to(tl.int64)
    # off_h = off_h.to(tl.int64)

    q_offset = off_z * stride_qz + off_h * stride_qh
    k_offset = off_z * stride_kz + off_h * stride_kh
    v_offset = off_z * stride_vz + off_h * stride_vh
    # o_offset = off_z * stride_qz + off_h * stride_qh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX_Q, D),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(N_CTX_K, D),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_K),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(D, N_CTX_K),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(0, 1),
    )

    # O_block_ptr = tl.make_block_ptr(
    #     base=Out + o_offset,
    #     shape=(N_CTX_Q, D),
    #     strides=(stride_om, stride_on),
    #     offsets=(start_m * BLOCK_M, 0),
    #     block_shape=(BLOCK_M, BLOCK_DMODEL),
    #     order=(1, 0),
    # )

    if BLOCK_DMODEL == BLOCK_K:
        q_0 = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")
    else:

        q_0 = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")

    # initialize pointer to m and l

    acc_0 = tl.zeros([BLOCK_M, BLOCK_K], dtype=ACC_TYPE)

    m_i = tl.full([BLOCK_M], -float("inf"), dtype=acc_0.dtype)
    l_i = tl.full([BLOCK_M], 1.0, dtype=acc_0.dtype)
    # load scales
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        (

            acc_0,

            l_i, m_i,
        ) = _attn_fwd_inner(

                            acc_0,
                            q_0,

                            l_i, m_i, K_block_ptr, V_block_ptr,  #

                            start_m, #
                            BLOCK_M, BLOCK_DMODEL, BLOCK_N, BLOCK_K,  #
                            4 - STAGE, N_CTX_Q, N_CTX_K,  #
                            TILES,  #
                            EVEN_N,  #
                            allow_tf32=ALLOW_TF32,
                            )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for the compiler to schedule the
        # two loops independently
        tl.debug_barrier()
        (

            acc_0,

            l_i, m_i,
        ) = _attn_fwd_inner(

                            acc_0,
                            q_0,

                            l_i, m_i, K_block_ptr, V_block_ptr,  #

                            start_m, #
                            BLOCK_M, BLOCK_DMODEL, BLOCK_N, BLOCK_K,  #
                            2, N_CTX_Q, N_CTX_K,  #
                            TILES,  #
                            EVEN_N,  #
                            allow_tf32=ALLOW_TF32,
                            )

    # epilogue
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_z = off_hz // H
    off_h = off_hz % H
    # offs_m = offs_m.to(tl.int64)
    # off_z = off_z.to(tl.int64)
    # off_h = off_h.to(tl.int64)

    idx_m = offs_m[None, None, :, None]
    idx_z = tl.full([1, 1, 1, 1], off_z, dtype=idx_m.dtype)
    idx_h = tl.full([1, 1, 1, 1], off_h, dtype=idx_m.dtype)

    acc_0 = acc_0 / l_i[:, None]

    acc_0 = acc_0[None, None, :, :]
    idx_d = tl.arange(0 * BLOCK_K, 1 * BLOCK_K)[None, None, None, :]
    mask = (idx_z < Z) & (idx_h < H) & (idx_m < N_CTX_Q) & (idx_d < D)
    acc = acc_0

    xindex = idx_d + (64*idx_m) + (262144*idx_h) + (2621440*idx_z)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, acc.shape)), acc, mask)
''', device_str='cuda')
meta0 = {'sm_scale': 0.125, 'STAGE': 1, 'TILES': 1, 'EVEN_N': False, 'NUM_STAGES': 2, 'IS_QUANTIZED': False, 'GROUP_M': 8, 'EVEN_K': True, 'ALLOW_TF32': False, 'ACC_TYPE': 'tl.float16', 'B_PROLOGUE_CAST_TYPE': None, 'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_DMODEL': 64, 'HAS_INLINE_ASM': True, 'ENABLE_FAST_MATH': False, 'CUDA_ARCH': 890, 'CUDA_VERSION': 12040, 'USE_FAST_ACCUM': False, 'MAX_NUM_IMPRECISE_ACC': None}

import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

async_compile.wait(globals())
del async_compile

def fused_attention_grid(b, h, s, d, meta):
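    # ceil(seq_len / BLOCK_M) programs along axis 0 (query tiles), batch * heads along axis 1.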
    from torch._inductor.utils import ceildiv as cdiv
    return (cdiv(s, meta["BLOCK_M"]), b * h, 1)

def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (2, 10, 4096, 64), (2621440, 262144, 64, 1))
    assert_size_stride(arg1_1, (2, 10, 77, 64), (49280, 4928, 64, 1))
    assert_size_stride(arg2_1, (2, 10, 77, 64), (49280, 4928, 64, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((2, 10, 4096, 64), (2621440, 262144, 64, 1), torch.float16)
        # Source Nodes: [scaled_dot_product_attention], Original ATen: [aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_tem_fused__scaled_dot_product_flash_attention_0.run(arg0_1, arg1_1, arg2_1, buf0, grid=fused_attention_grid(2, 10, 4096, 64, meta0), stream=stream0)
        del arg0_1
        del arg1_1
        del arg2_1
    return (buf0, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((2, 10, 4096, 64), (2621440, 262144, 64, 1), device='cuda:0', dtype=torch.float16)
    arg1_1 = rand_strided((2, 10, 77, 64), (49280, 4928, 64, 1), device='cuda:0', dtype=torch.float16)
    arg2_1 = rand_strided((2, 10, 77, 64), (49280, 4928, 64, 1), device='cuda:0', dtype=torch.float16)
    fn = lambda: call([arg0_1, arg1_1, arg2_1])
    return print_performance(fn, times=times, repeat=repeat)

def check_output():
    from torch._dynamo.testing import rand_strided
    arg0_1 = rand_strided((2, 10, 4096, 64), (2621440, 262144, 64, 1), device='cuda:0', dtype=torch.float16)
    arg1_1 = rand_strided((2, 10, 77, 64), (49280, 4928, 64, 1), device='cuda:0', dtype=torch.float16)
    arg2_1 = rand_strided((2, 10, 77, 64), (49280, 4928, 64, 1), device='cuda:0', dtype=torch.float16)
    out = call([arg0_1, arg1_1, arg2_1])[0]

    sdpa_out = torch.ops.aten.scaled_dot_product_attention(arg0_1, arg1_1, arg2_1)
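    # With rtol=1e10 the relative term dominates, so this check mainly catches
    # inf/nan and grossly wrong values rather than small numerical differences.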
    torch.testing.assert_close(out, sdpa_out, rtol=1e10, atol=1e-2)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
    check_output()

After running it, the console shows that the assertion fails and the output tensor contains inf values. Strangely, if I use tl.device_print to inspect the values of the final acc_0 right before tl.store, the values look fine.
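
For reference, the kind of device-side printing described above can be reproduced with a minimal, self-contained sketch like the one below (this is not the original kernel; the kernel name, shapes, and launch parameters are illustrative assumptions). It prints an accumulator immediately before tl.store, analogous to the check on acc_0 mentioned above.

import torch
import triton
import triton.language as tl

@triton.jit
def debug_store_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.load(x_ptr + offs)
    # Print the values that are about to be written back; if these look fine
    # but the stored tensor contains inf, the problem is downstream of the math.
    tl.device_print("acc", acc)
    tl.store(out_ptr + offs, acc)

x = torch.randn(128, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)
debug_store_kernel[(1,)](x, out, BLOCK=128)
torch.cuda.synchronize()  # flush device-side prints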

However, if I switch to an older triton, for example:

pip3 list | grep triton
pytorch-triton                     3.0.0+a9bc1a3647

This version of triton can be installed with pip3 install --pre -U pytorch_triton==3.0.0+a9bc1a3647 --index-url https://download.pytorch.org/whl/nightly/cu124. It is not very old, though.

With this older version of triton, the output tensor does not contain inf, is reasonably correct, and does not trigger the assertion failure.

Currently I have no idea why this happens. I am using an RTX 4090. Below is the output of torch.utils.collect_env. Do you have any idea?

PyTorch version: 2.5.0.dev20240629+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Mar 22 2024, 16:50:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 560.38
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      39 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             24
On-line CPU(s) list:                0-23
Vendor ID:                          GenuineIntel
Model name:                         13th Gen Intel(R) Core(TM) i7-13700KF
CPU family:                         6
Model:                              183
Thread(s) per core:                 2
Core(s) per socket:                 12
Socket(s):                          1
Stepping:                           1
BogoMIPS:                           6835.19
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization:                     VT-x
Hypervisor vendor:                  Microsoft
Virtualization type:                full
L1d cache:                          576 KiB (12 instances)
L1i cache:                          384 KiB (12 instances)
L2 cache:                           24 MiB (12 instances)
L3 cache:                           30 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] onnx==1.14.0
[pip3] onnxruntime==1.15.0
[pip3] open-clip-torch==2.23.0
[pip3] pytorch-lightning==1.9.4
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240629+cu124
[pip3] torchao==0.3.1
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==1.2.1
[conda] Could not collect
ThomasRaoux commented 1 month ago

cc: @bertmaher , I'm guessing this is using the new triton 3.0 branch? Not sure how to tell which triton hash is being used there.

chengzeyi commented 1 month ago

> cc: @bertmaher , I'm guessing this is using the new triton 3.0 branch? Not sure how to tell which triton hash is being used there.

Thank you for your quick response. This should be the upcoming 3.0 version of triton. The commit hash is in the packaged wheel file name.

chengzeyi commented 1 month ago

Another hint: if I use fp32 dot accumulation instead of fp16, the output is correct.
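
For context, here is a minimal, self-contained sketch of what that change amounts to (the kernel, shapes, and tolerances are illustrative assumptions, not the actual template): keep the fp16 inputs but request fp32 accumulation from tl.dot, casting back to fp16 only when storing.

import torch
import triton
import triton.language as tl

@triton.jit
def dot_fp32_acc_kernel(a_ptr, b_ptr, c_ptr,
                        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])  # fp16 tile
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])  # fp16 tile
    # Accumulate in fp32 instead of fp16; in the full attention template above,
    # switching the dot out_dtype / accumulator to fp32 like this gives correct output.
    acc = tl.dot(a, b, out_dtype=tl.float32)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.float16))

a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 128, device="cuda", dtype=torch.float16)
c = torch.empty(128, 128, device="cuda", dtype=torch.float16)
dot_fp32_acc_kernel[(1,)](a, b, c, M=128, N=128, K=64)
torch.testing.assert_close(c, (a.float() @ b.float()).half(), rtol=1e-2, atol=1e-2)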