ROCm / triton

Development repository for the Triton language and compiler

[BUG] Discrepancy Between Triton JIT Computed Sum and Torch Sum #446

Closed · Ldpe2G closed this issue 4 months ago

Ldpe2G commented 8 months ago

Environment

Issue Description

I have encountered an issue where the sum computed using Triton's JIT does not match the sum computed by PyTorch's torch.sum. This discrepancy occurs when running the provided code snippet, which is intended to calculate the sum of a random tensor.

To Reproduce

Here is a minimal code snippet to reproduce the issue:

import torch
import triton
import triton.language as tl

@triton.jit
def _test_sum(
    Y,  # pointer to the output
    X,  # pointer to the input
    BLOCK_SIZE: tl.constexpr,
):
    cols = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + cols)
    sum = tl.sum(x)
    tl.store(Y, sum)

torch.manual_seed(0)
x = torch.rand(64, device="cuda", dtype=torch.float32)
y_triton = torch.zeros(1, device="cuda", dtype=torch.float32)
BLOCK_SIZE = 64
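# Launch a single program instance; its one block of BLOCK_SIZE elements covers the whole tensor.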
_test_sum[(1, )](y_triton, x, BLOCK_SIZE)

y_torch = torch.sum(x)

assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0), (y_triton, y_torch)

Expected Behavior

I would expect the y_triton tensor to be very close to the y_torch tensor, with differences within the tolerance set by torch.allclose (atol=1e-2, rtol=0).
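
For what it's worth, here is a minimal sketch of how the discrepancy could be quantified against a float64 reference (this assumes the reproducer above has already run, so x, y_triton, and y_torch are in scope; y_ref is only introduced here for illustration):

# Hypothetical check: compare both results against a float64 reference sum.
# Assumes x, y_triton, and y_torch from the reproducer above are in scope.
y_ref = torch.sum(x.double())
print("triton sum       :", y_triton.item())
print("torch sum        :", y_torch.item())
print("|triton - torch| :", (y_triton - y_torch).abs().item())
print("|triton - fp64|  :", (y_triton.double() - y_ref).abs().item())

With only 64 float32 values in [0, 1), differences caused purely by summation order should be orders of magnitude smaller than atol=1e-2, so a failure of the assertion points at the kernel rather than at floating-point reordering.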

Actual Behavior

However, the assertion fails, which indicates that the sums do not match within the specified tolerances. The actual values returned are:

Could you please look into this discrepancy?

Thank you for your assistance.

zhanglx13 commented 8 months ago

@Ldpe2G Thanks for reporting the issue. reduceOp is a known issue on Navi GPUs. We are working on it and will keep you posted.

zhanglx13 commented 4 months ago

@zhanglx13 asked: @joviliast What is the current status of the reduction op on Navi3?

joviliast commented 4 months ago

@zhanglx13 It should be supported now.

I tried to reproduce this case on https://github.com/ROCm/triton/tree/triton-mlir and on https://github.com/openai/triton/tree/main; the results are exactly the same across the different sizes.

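Here is a minimal sketch of how such a comparison across sizes can be run (it reuses the imports and the _test_sum kernel from the reproducer above; the list of sizes is only illustrative):

# Hypothetical sweep reusing the _test_sum kernel from the reproducer above.
for size in (64, 128, 256, 512, 1024, 2048):
    x = torch.rand(size, device="cuda", dtype=torch.float32)
    y_triton = torch.zeros(1, device="cuda", dtype=torch.float32)
    _test_sum[(1, )](y_triton, x, size)  # size is passed as the BLOCK_SIZE constexpr
    y_torch = torch.sum(x)
    print(size, y_triton.item(), y_torch.item(),
          torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0))

The sizes stay at powers of two because a single program instance loads the whole tensor with tl.arange(0, BLOCK_SIZE), which requires a power-of-two range.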

zhanglx13 commented 4 months ago

Thank you, Illia. @Ldpe2G I'll close this one, as we are merging in the upstream Triton. For further issues, feel free to open tickets at https://github.com/openai/triton/issues