ROCm / triton

Development repository for the Triton language and compiler

[PYTORCH UT] Inaccuracies in argmax/argmin UT #505

Closed by jataylo 3 months ago

jataylo commented 5 months ago

Problem Description

At TOT triton-mlir with the PyTorch nightly we are seeing inaccuracies in the following unit test: torchinductor.py::test_argmax_argmin_with_duplicates_dynamic_shapes_cuda. The test passes at a previous triton-mlir commit: https://github.com/ROCm/triton/commit/6aa01113db5aaedb99748cc439519c9ea562ab66

AssertionError: Tensor-likes are not equal!
Mismatched elements: 636 / 1028 (61.9%)
Greatest absolute difference: 530 at index (901,)
Greatest relative difference: inf at index (1,)

Initial triage: I was able to minify the UT so that it reproduces with argmax alone on a single tensor input (a standalone reproducer sketch follows the snippet below):

def test_argmax_argmin_with_duplicates(self):
    def fn(x):
        return (
            aten.argmax(x, 0),
        )

    # Non-persistent reduction
    t1 = torch.randint(8, size=(1028, 1028))
    self.common(fn, (t1,))
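
For reference, a standalone sketch of the same check outside the inductor test harness (my own, not from the issue), assuming a recent PyTorch nightly with torch.compile usable on ROCm:

import torch

def fn(x):
    return torch.argmax(x, 0)

# Small value range guarantees duplicate maxima, which is what the UT targets.
t1 = torch.randint(8, size=(1028, 1028), device="cuda")

expected = fn(t1)
actual = torch.compile(fn)(t1)

# Inductor is expected to match eager here; on the failing setup this assert
# trips with large index mismatches like the ones quoted above.
torch.testing.assert_close(actual, expected)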

Operating System

-

CPU

-

GPU

AMD Instinct MI250X

ROCm Version

ROCm 5.7.0

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

jataylo commented 5 months ago

Doing a little more triage on this to try to get a Triton-only reproducer and identify whether we are facing a numerical issue. If we can show the problem there, we may need some assistance on the Triton debug side. cc: @xiaohuguo2023
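
One possible way to pull out the kernel inductor generates (a sketch based on the stock torch.compile debug tooling, which I'm assuming behaves the same on ROCm): run the snippet below with TORCH_COMPILE_DEBUG=1 set in the environment, and the generated output_code.py, including the Triton kernel, should land under a torch_compile_debug/ directory for inspection.

import torch

def fn(x):
    return torch.argmax(x, 0)

t1 = torch.randint(8, size=(1028, 1028), device="cuda")
# Triggers inductor codegen; with TORCH_COMPILE_DEBUG=1 the emitted Triton
# kernel can then be rerun and checked in isolation.
torch.compile(fn)(t1)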

xiaohuguo2023 commented 5 months ago

I suspect test_core_amd.py::test_reduce1d may also be broken; I will run that test to confirm.

jataylo commented 4 months ago

@xiaohuguo2023 to verify whether the reduce1d tests are passing at the triton-pytorch commit; if so, we will need to build a harness around the Triton kernel to evaluate its results.
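
For reference, a minimal sketch of such a harness (my own, assuming tl.argmax is available and restricting to a 1D reduction that fits in one block; the 2D dim-0 reduction from the UT would need a more involved kernel). Ties should resolve to the lowest index, matching eager argmax:

import torch
import triton
import triton.language as tl

@triton.jit
def argmax_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Masked-off lanes get the smallest int32 so they can never win the argmax.
    x = tl.load(x_ptr + offsets, mask=mask, other=-2147483648)
    idx = tl.argmax(x, axis=0)
    tl.store(out_ptr, idx)

def triton_argmax_1d(x):
    out = torch.empty((), dtype=torch.int32, device=x.device)
    block_size = triton.next_power_of_2(x.numel())
    argmax_kernel[(1,)](x, out, x.numel(), BLOCK_SIZE=block_size)
    return out

# Duplicates are guaranteed by the small value range, mirroring the failing UT.
x = torch.randint(8, size=(1028,), dtype=torch.int32, device="cuda")
print("triton:", triton_argmax_1d(x).item(), "eager:", torch.argmax(x).item())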

jataylo commented 4 months ago

@xiaohuguo2023 note that this one is still failing with the upstream backend at commit https://github.com/openai/triton/commit/a9bc1a36470eefafe0e2ab2503b8698f1e89e7e3. I'll update shortly with instructions on how we can get the upstream backend working with inductor.

zhanglx13 commented 3 months ago

@jataylo Can you check if this is still failing with upstream?

zhanglx13 commented 3 months ago

Please reopen if it still fails

suxiangM commented 3 months ago

I have also encountered this issue. Is there a solution now? Only the following layouts encountered errors:

#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>

jataylo commented 3 months ago

Hi @suxiangM, no solution on the triton fork AFAIK, but these issues are not observed with the AMD backend of upstream Triton.

We have a PyTorch PR currently in review to switch to using openai/triton instead of our fork: https://github.com/pytorch/pytorch/pull/121801
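
As a quick sanity check when flipping between the fork and upstream (my own sketch, not part of that PR), something like this confirms which Triton build the environment actually resolves:

import torch
import triton

# Which Triton is installed and where it was imported from.
print("triton:", triton.__version__, triton.__file__)
# Confirms the ROCm/HIP build of PyTorch being paired with it.
print("torch:", torch.__version__, "hip:", torch.version.hip)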

suxiangM commented 3 months ago

ok! Thank you for your work and answer!!!