ROCm / triton

Development repository for the Triton language and compiler
MIT License

[PYTORCH] tl.reduce error (num_warps=2): couldn't allocate output register for constraint 'v' #367

Closed jataylo closed 1 year ago

jataylo commented 1 year ago

@jayfurmanek @zhanglx13 @alefimov-amd cc: @jithunnair-amd @dllehr-amd

This error is observed with top-of-tree (TOT) triton-mlir and breaks some Hugging Face models along with some PyTorch unit tests. It is likely to block us from moving the PyTorch commit forward.

Experiments: this seems to occur specifically with num_warps=2. If we change the reproducer's triton.compile line to num_warps=4, the error is not observed:

test = triton.compile(triton_fn, signature="*i1,*fp32,*i1,i32,i32", constants={"XBLOCK": 1, "RBLOCK": 128}, num_warps=4)
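To narrow down which warp counts trigger the failure, the experiment above could be automated. The following is a minimal sketch and not part of the original report: `sweep_num_warps` and `compile_fn` are hypothetical names, and `compile_fn` is assumed to wrap the reproducer's triton.compile call and to raise a Python exception when compilation fails. If the backend aborts the process instead of raising, each candidate would need to run in its own subprocess.

```python
from typing import Callable, Dict, Iterable, Optional


def sweep_num_warps(
    compile_fn: Callable[[int], object],
    candidates: Iterable[int] = (1, 2, 4, 8),
) -> Dict[int, Optional[str]]:
    """Try each num_warps value; map it to None on success or to the error text."""
    results: Dict[int, Optional[str]] = {}
    for num_warps in candidates:
        try:
            compile_fn(num_warps)
            results[num_warps] = None
        except Exception as exc:  # assumption: compile failures surface as exceptions
            results[num_warps] = str(exc)
    return results


# Hypothetical usage with the reproducer's kernel (requires Triton and a GPU):
# results = sweep_num_warps(
#     lambda n: triton.compile(
#         triton_fn,
#         signature="*i1,*fp32,*i1,i32,i32",
#         constants={"XBLOCK": 1, "RBLOCK": 128},
#         num_warps=n,
#     )
# )
```

Keeping the compile call behind a callable means the same helper can be reused for other parameters from the dump below, such as num_stages.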

Full parameters (failing):

python triton_repro.py 
is_cuda: False
is_hip: True
warp_size: 64
context: <triton._C.libtriton.triton.ir.context object at 0x7f0425eb9a70>
constants: {'XBLOCK': 1, 'RBLOCK': 128}
num_warps: 2
num_ctas: 1
num_stages: 2
waves_per_eu: 0
enable_warp_specialization: False
enable_persistent: False
extern_libs: {}
debug: False
optimize_epilogue: False
error: couldn't allocate output register for constraint 'v'

Reproducer:

import torch
from torch import empty_strided, device

import triton
import triton.language as tl
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

@triton.jit
def _any_combine(a, b):
    return a | b

@triton.jit
def tl_any(a, dim):
    return tl.reduce(a, dim, _any_combine)

@triton.jit
def triton_fn(in_out_ptr0, in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp3 = tl.full([XBLOCK, RBLOCK], 0, tl.int1)
    _tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.int1)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0)
        tmp1 = tl.math.isinf(tmp0).to(tl.int1)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = _tmp3 | tmp2
        _tmp3 = tl.where(rmask, tmp4, _tmp3)
        tmp5 = tmp1 == 0
        tmp6 = tmp5 == 0
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp9 = _tmp8 | tmp7
        _tmp8 = tl.where(rmask, tmp9, _tmp8)
    tmp3 = tl_any(_tmp3.to(tl.int8), 1)[:, None].to(tl.int1)
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)
    tmp8 = tl_any(_tmp8.to(tl.int8), 1)[:, None].to(tl.int1)
    tmp10 = tmp8 == 0
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None)

from torch._dynamo.testing import rand_strided
arg2_1 = rand_strided((16, 8), (8, 1), device='cuda:0', dtype=torch.float32)
buf0 = empty_strided((16, ), (1, ), device='cuda', dtype=torch.bool)
stream0 = get_cuda_stream(0)
buf1 = empty_strided((), (), device='cuda', dtype=torch.bool)
buf4 = empty_strided((), (), device='cuda', dtype=torch.bool)

# Source Nodes: [all_2, any_2, isinf, isinf_2, logical_not], Original ATen: [aten.all, aten.any, aten.isinf, aten.logical_not]
test = triton.compile(triton_fn, signature="*i1,*fp32,*i1,i32,i32", constants={"XBLOCK": 1, "RBLOCK": 128}, num_warps=2)

Note: @jayfurmanek mentioned that this commit previously fixed a similar error: https://github.com/ROCmSoftwarePlatform/triton/pull/347/commits/a41f13adcd4ae623d69f77290008fa1d7b10b490

This PyTorch commit may be relevant: https://github.com/pytorch/pytorch/commit/94d306fd4594a642ac70f3f5d2a9eb4910587774. It goes over some of the accumulator logic for aten.any, but it is pretty old at this point.

jayfurmanek commented 1 year ago

It does look like int8 again:

tmp3 = tl_any(_tmp3.to(tl.int8), 1)[:, None].to(tl.int1)

and then casting to int1?

jayfurmanek commented 1 year ago

The previous fix promoted to int32, since it's free and more correct anyway, but maybe there are gaps in that fix...

jataylo commented 1 year ago

Yeah, the upstream comment mentions that they cast to tl.int8 because tl.reduce doesn't support tl.int1, and then cast back to tl.int1 afterwards: https://github.com/pytorch/pytorch/commit/94d306fd4594a642ac70f3f5d2a9eb4910587774#diff-ea3834aeda2e22f49dd5ba0becd456b449b10d17dfe64c34e89072701c82873dR1285
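The cast, reduce, cast-back pattern described here can be illustrated without Triton. The following is a plain-Python sketch of the semantics only: `any_reduce_via_int8` is a hypothetical name, and Python integers masked to 8 bits stand in for tl.int8 values. It is not the Inductor or Triton implementation.

```python
from functools import reduce
from operator import or_
from typing import Sequence


def any_reduce_via_int8(flags: Sequence[bool]) -> bool:
    """Emulate `tl_any(x.to(tl.int8), dim).to(tl.int1)` on a 1-D sequence."""
    # Widen each 1-bit flag to an 8-bit integer (0 or 1).
    widened = [int(f) & 0xFF for f in flags]
    # Combine with bitwise OR, the same operation as _any_combine in the reproducer.
    combined = reduce(or_, widened, 0)
    # Narrow back to a single truth value.
    return combined != 0
```

Because every widened value is 0 or 1, the OR-reduction gives the same answer at any integer width, which is consistent with the earlier remark that promoting to int32 does not change the result.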

jayfurmanek commented 1 year ago

Ah ok that makes sense