ROCm / triton

Development repository for the Triton language and compiler
MIT License

`AssertionError("cannoe reassign constxpr full_range in the loop")` Failing PyTorch UTs #494

Closed jataylo closed 5 months ago

jataylo commented 8 months ago

Problem Description

Seeing ~15 PyTorch UT failures at TOT triton-mlir reporting this error, previously hidden by https://github.com/ROCm/triton/issues/412

FAILED [0.1121s] test_torchinductor.py::CudaTests::test_bucketize_add_autotune_cuda
FAILED [0.1734s] test_torchinductor.py::CudaTests::test_bucketize_computed_offsets_cuda
FAILED [0.0533s] test_torchinductor.py::CudaTests::test_bucketize_cuda - torc...
FAILED [0.0458s] test_torchinductor.py::CudaTests::test_bucketize_default_kwargs_cuda
FAILED [0.0462s] test_torchinductor.py::CudaTests::test_bucketize_int_cuda - ...

GPU

AMD Instinct MI250X

ROCm Version

ROCm 5.7.0

Steps to Reproduce

Use the rocm/pytorch-nightly:latest image and TOT triton-mlir: https://github.com/ROCm/triton/commit/6aa01113db5aaedb99748cc439519c9ea562ab66

Reproducer:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch import device, empty, empty_strided
import triton
import triton.language as tl

# THIS KERNEL IS DEFINED IN TRITON_HELPERS - https://github.com/ROCm/pytorch/blob/rocm6.0_internal_testing/torch/_inductor/triton_helpers.py
# We can potentially add workarounds here if there is no way to solve the issue.
@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]
    """
    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)
    full_range = OFFSETS_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        mask = mid < OFFSETS_SIZE
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound
        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)
        full_range = (full_range + 1) // 2
    return low

# THIS KERNEL IS AUTOMATICALLY GENERATED BY INDUCTOR
@triton.jit
def triton_fn(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = bucketize_binary_search(tmp0, in_ptr1, tl.int32, True, 10, [XBLOCK])
    tl.store(out_ptr0 + (x0), tmp1, None)

arg0_1 = empty_strided((64, 64), (64, 1), device='cuda:0', dtype=torch.int64)
arg1_1 = empty_strided((10, ), (1, ), device='cuda:0', dtype=torch.int32)
buf0 = empty((64, 64), device='cuda', dtype=torch.int32)
test = triton.compile(triton_fn, signature="*i64,*i32,*i32,*i32", constants={"XBLOCK": 4096})
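
To isolate the failing pattern from the bucketize logic, here is a minimal, hypothetical kernel (not part of the original report) that hits the same assertion on the affected compiler: full_range starts out as a plain Python int derived from a compile-time constant, so reassigning it inside the while loop is what the code generator rejects.

@triton.jit
def minimal_trigger(out_ptr, N: tl.constexpr, BLOCK: tl.constexpr):
    # Loop-carried Triton tensor: reassigning this inside the loop is allowed.
    steps = tl.zeros([BLOCK], dtype=tl.int32)
    # Plain Python int derived from a constexpr: reassigning this inside the
    # while loop is what trips "cannoe reassign constxpr full_range in the loop".
    full_range = N + 1
    while full_range > 1:
        steps = steps + 1
        full_range = (full_range + 1) // 2
    tl.store(out_ptr + tl.arange(0, BLOCK), steps)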

Traceback

Traceback (most recent call last):
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1222, in ast_to_ttir
    generator.visit(fn.parse())
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 299, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 372, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 294, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 424, in visit_Assign
    values = self.visit(node.value)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1020, in visit_Call
    return self.call_JitFunction(fn, args, kws)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 988, in call_JitFunction
    generator.visit(fn.parse())
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 299, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 372, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 294, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/conda/envs/py_3.8/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/tmp/triton/python/triton/compiler/code_generator.py", line 759, in visit_While
    assert _is_triton_tensor(loop_defs[name]), f'cannoe reassign constxpr {name} in the loop'
AssertionError: cannoe reassign constxpr full_range in the loop
jataylo commented 8 months ago

This was passing at the November commit https://github.com/ROCmSoftwarePlatform/triton/commit/e8a35b3968780e48df1374482d56cc6cdbb9e351

xiaohuguo2023 commented 8 months ago

Using rocm/pytorch-nightly:latest with a fresh container and only a freshly cloned triton mounted, I get the error below on MI250X; it seems to break all the Triton tutorials:
(py_3.8) root@hyd-7c-ZT09-02:/home/test/triton/python/tutorials# python 03-matrix-multiplication.py
Traceback (most recent call last):
  File "03-matrix-multiplication.py", line 156, in <module>
    import triton.language as tl
  File "/home/test/triton/python/triton/language/__init__.py", line 4, in <module>
    from . import math
  File "/home/test/triton/python/triton/language/math.py", line 5, in <module>
    from . import core
  File "/home/test/triton/python/triton/language/core.py", line 8, in <module>
    from .._C.libtriton.triton import ir
ModuleNotFoundError: No module named 'triton._C.libtriton'

With rocm/pytorch:latest, there is no issue at all.

xiaohuguo2023 commented 8 months ago

This should fix the above tests and the reproducer: https://github.com/xiaohuguo2023/tritontest/blob/main/reproducer_bs.py and https://github.com/xiaohuguo2023/pytorch/tree/pt-inductorUT-fix
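
For illustration only, one possible workaround direction is sketched below; this is an assumption about the general shape of such a fix, not necessarily what reproducer_bs.py or the pt-inductorUT-fix branch actually does. Since OFFSETS_SIZE is known at compile time, the search can run for a precomputed number of steps (a hypothetical NUM_STEPS parameter equal to ceil(log2(OFFSETS_SIZE + 1)), supplied by the caller), so only Triton tensors are reassigned inside the loop.

import triton
import triton.language as tl

@triton.jit
def bucketize_binary_search_bounded(
    values,          # 1D tensor of values to bucketize
    offsets_ptr,     # pointer to the sorted bucket boundaries
    indexing_dtype,
    right,           # bool: if true, use intervals closed on the left
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,     # tuple/list of block shape
    NUM_STEPS: tl.constexpr,  # hypothetical: ceil(log2(OFFSETS_SIZE + 1)), computed by the caller
):
    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)
    # The loop bound is a compile-time constant and only the tensors low/high
    # are reassigned inside the loop, so no constexpr value is mutated.
    for step in range(NUM_STEPS):
        mid = (high + low) // 2
        mask = mid < OFFSETS_SIZE
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound
        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)
    return low

Running exactly NUM_STEPS = ceil(log2(OFFSETS_SIZE + 1)) iterations performs the same body executions as the original while loop, so the result is unchanged; for the reproducer above (OFFSETS_SIZE = 10) that is 4 steps.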

jataylo commented 8 months ago

Thanks @xiaohuguo2023, with this workaround we pass these tests.

Please keep us in the loop on any findings from the investigation into why this only fails for us now when it used to pass, so we can decide on the best upstream strategy. If we adopt this in PyTorch, we would need conditional implementations for ROCm/NV unless there is a fix at the Triton level.

xiaohuguo2023 commented 8 months ago

The latest upstream OpenAI Triton has changed the triton.compile interface:

in triton-mlir: def compile(fn, **kwargs):

In upstream openai: def compile(src, target=None, options=None):
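
For reference, a rough sketch of what the reproducer's compile call would look like under each interface; the ASTSource argument names below are an assumption based on the upstream interface at the time and should be checked against the installed Triton version.

# triton-mlir interface, as used in the reproducer above:
compiled = triton.compile(triton_fn, signature="*i64,*i32,*i32,*i32", constants={"XBLOCK": 4096})

# Newer upstream interface: compile() takes a source object rather than the
# JIT function itself, e.g. an ASTSource wrapper (argument names assumed):
from triton.compiler import ASTSource
src = ASTSource(fn=triton_fn, signature="*i64,*i32,*i32,*i32", constants={"XBLOCK": 4096})
compiled = triton.compile(src)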

jataylo commented 7 months ago

https://github.com/xiaohuguo2023/pytorch/commit/dd960611bca349db97e701d14754b83a97c0b8f0

@xiaohuguo2023 to submit a PR with the binary-search change.

jataylo commented 6 months ago

@xiaohuguo2023 note that this one is PASSING with the upstream backend at commit https://github.com/openai/triton/commit/a9bc1a36470eefafe0e2ab2503b8698f1e89e7e3.