triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Racecheck Bug when tl.min used with tl.sum #4736

Open thumbe3 opened 2 months ago

thumbe3 commented 2 months ago
import torch
import numpy as np
import triton
import triton.language as tl

@triton.jit
def compute_min_distance_coord(input_ptr: tl.tensor,
                               coord_ptr: tl.tensor,
                               min_cord_idx_ptr: tl.tensor,
                               BLOCK_SIZE: tl.constexpr):
    # Each program handles BLOCK_SIZE input rows of 8 features each.
    pid = tl.program_id(0)
    offs_input_row = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_coord_row = tl.arange(0, 32)
    offs_coord_idxs = tl.arange(0, 8)

    offs_input = offs_input_row[:, None] * 8 + offs_coord_idxs[None, :]
    offs_coord = offs_coord_row[:, None] * 8 + offs_coord_idxs[None, :]
    input = tl.load(input_ptr + offs_input)  # [BLOCK_SIZE, 8]
    coord = tl.load(coord_ptr + offs_coord)  # [32, 8]

    # Shape: [BLOCK_SIZE, 32, 8]
    diff = input[:, None, :] - coord[None, :, :]
    dist_sq = diff * diff
    # Sum over the 8 features, then take the index of the nearest of the 32 coordinates.
    _, min_coord_idxs = tl.min(tl.sum(dist_sq, axis=-1), axis=-1, return_indices=True)
    tl.store(min_cord_idx_ptr + offs_input_row, min_coord_idxs.to(tl.int32))

input = torch.rand(1 << 20, 8, dtype=torch.float32).cuda(0)
coordinates = torch.rand(32, 8, dtype=torch.float32).cuda(0)
out_min_idxs = torch.zeros([1 << 20], dtype=torch.int32).cuda(0)

grid = lambda meta: (triton.cdiv(1<<20, meta['BLOCK_SIZE']),)
compute_min_distance_coord[grid](input, coordinates, out_min_idxs, BLOCK_SIZE=512)
torch.cuda.synchronize()

# Equivalent NumPy code to check correctness
input_np = input.cpu().numpy()
coord_np = coordinates.cpu().numpy()
diff_sq = np.square(input_np[:, None, :] - coord_np[None, :, :])
out_min_idxs_np = np.argmin(np.sum(diff_sq, axis=-1), axis=-1)
print(np.allclose(out_min_idxs_np, out_min_idxs.cpu().numpy()))

The code above computes the squared distance between each input row and each of 32 coordinates, and returns the index of the nearest coordinate for each input (the NumPy code at the end of the script may be easier to follow). When this script is run under the racecheck tool of compute-sanitizer (compute-sanitizer --tool=racecheck python script.py), the following output is shown:

========= Error: Race reported between Write access at compute_min_distance_coord+0x5ad20 in /usr/local/lib/python3.10/dist-packages/triton/language/standard.py:237
========= and Write access at compute_min_distance_coord+0x5ad20 in /usr/local/lib/python3.10/dist-packages/triton/language/standard.py:237 [6136 hazards]

The error seems to stem from standard.py, at a line that appears to be inside the min function.

I am not seeing a correctness issue with this code at the moment, but I have run into correctness issues with other kernels that use a similar combination of tl.sum and tl.min.
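The racecheck report points at the reduction code behind `tl.min(..., return_indices=True)`. As a rough mental model only (a plain-Python sketch, not Triton's actual implementation), a min-with-index reduction can be seen as combining (value, index) pairs pairwise in a tree, breaking ties toward the lower index. The helpers `argmin_combine` and `tree_argmin` are hypothetical names; NaN handling and the GPU thread and shared-memory layout are not modeled.

```python
import random
from typing import List, Tuple


def argmin_combine(v1: float, i1: int, v2: float, i2: int) -> Tuple[float, int]:
    """Combine two (value, index) candidates; on equal values keep the lower index.

    Assumption: left tie-breaking, similar in spirit to an argmin combine function.
    """
    lt = v1 < v2 or (v1 == v2 and i1 < i2)
    return (v1, i1) if lt else (v2, i2)


def tree_argmin(values: List[float]) -> Tuple[float, int]:
    """Pairwise (tree) reduction of (value, index) pairs, like a parallel reduction."""
    pairs = [(v, i) for i, v in enumerate(values)]
    while len(pairs) > 1:
        nxt = []
        for k in range(0, len(pairs) - 1, 2):
            nxt.append(argmin_combine(*pairs[k], *pairs[k + 1]))
        if len(pairs) % 2 == 1:
            nxt.append(pairs[-1])  # an odd element carries over to the next level
        pairs = nxt
    return pairs[0]


# Self-check on a small version of the issue's problem: 4 inputs, 32 coordinates, 8 features.
random.seed(0)
xs = [[random.random() for _ in range(8)] for _ in range(4)]
coords = [[random.random() for _ in range(8)] for _ in range(32)]
for x in xs:
    dist = [sum((a - b) ** 2 for a, b in zip(x, c)) for c in coords]
    _, idx = tree_argmin(dist)
    # min() with a key returns the first minimal index, matching left tie-breaking.
    assert idx == min(range(len(dist)), key=dist.__getitem__)
```

Because the combine is associative under this tie-breaking rule, the tree shape does not change the result; a real kernel may still have several threads holding and storing identical partial results.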

lijinpei commented 2 months ago

I have created a WIP patch, https://github.com/lijinpei/triton/commit/3fe20ba894bba0e142e2b9f4a24b75df5ddfd194, which removes the race reported for the provided script.py and causes no new failures in python/test/unit/language/test_core.py (the only failing case, 'python/test/unit/language/test_core.py::test_dot[1-128-128-64-4-False-False-chain-dot-ieee-float8e5-float32-1]', already fails on the main branch on my machine). Could you try the patch on the 'correctness issues with other kernels using similar combination tl.sum with tl.min', or help turn one of them into a unit test? I don't think the gatekeepers will accept the patch without a unit test.

Jokeren commented 2 months ago

We likely won't accept your solution even with a unit test. I don't see correctness issues.

Jokeren commented 2 months ago

> But I have faced correctness issues with other kernels using similar combination tl.sum with tl.min

If I understand correctly, the data race in this specific case doesn't cause correctness problems for you, so it would be better to share the code that shows the real issue.

A data race can be reported when multiple threads write the same value to the same location, which is fine in Triton.
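Why such a race can be harmless may be illustrated with a toy model (plain Python, not how a GPU actually schedules threads; `final_values` is a hypothetical helper): enumerate every ordering of unsynchronized writes to one shared slot. When all writers store the same value, every ordering ends in the same state, so a tool like racecheck may flag a hazard that has no effect on the result. When the values differ, the outcome depends on scheduling.

```python
from itertools import permutations
from typing import List, Optional, Set, Tuple

Write = Tuple[int, float]  # (thread id, value written) to a single shared slot


def final_values(writes: List[Write]) -> Set[Optional[float]]:
    """Return every possible final value of the slot over all write orderings."""
    outcomes: Set[Optional[float]] = set()
    for order in permutations(writes):
        slot: Optional[float] = None
        for _tid, value in order:
            slot = value  # the last writer wins
        outcomes.add(slot)
    return outcomes


# Replicated threads all store the same reduced value: the race is benign.
same = [(t, 0.25) for t in range(4)]
# Threads store different values: the final result depends on ordering.
different = [(t, float(t)) for t in range(4)]

print(final_values(same))       # a single possible outcome
print(final_values(different))  # four possible outcomes
```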

peterbell10 commented 2 months ago

Out of curiosity, I profiled the repro before and after the change, and I do see a small (~1%) speedup that reproduces consistently.

Jokeren commented 2 months ago

> Out of curiosity, I profiled the repro before and after the change, and I do see a small (~1%) speedup that reproduces consistently.

I think we need to run internal regression benchmarks instead of external ones.