intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License
126 stars 36 forks source link

[PyTorch Upstream] triton kernel for "torch.var" backward got wrong result. #838

Closed etaf closed 5 months ago

etaf commented 5 months ago

We got a wrong result in an Inductor unit test when using the Triton-compiled kernel for the torch.var backward. I've built a standalone case for this error. The following case passes on CUDA (just replace every 'xpu' with 'cuda' to run the script on a CUDA device).

To reproduce this issue:

Clone PyTorch from https://github.com/pytorch/pytorch.git and build it with `export USE_XPU=1`. Before running the case, `export PYTORCH_ENABLE_XPU_FALLBACK=1`, then run the following case:

from ctypes import c_void_p, c_long
import torch
from torch.testing import make_tensor
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

# Module-level aliases that appear to be Inductor wrapper-codegen boilerplate.
# In the XPU script below, only `assert_size_stride` (used in `call`) and
# `async_compile` (used to compile the Triton kernel) are referenced; the other
# names are unused. Note that `empty_strided_cuda` is bound even though this
# script targets XPU, and `call` allocates with `torch.empty_strided` instead.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Compiles the Triton kernel source strings submitted via async_compile.triton(...).
async_compile = AsyncCompile()

# kernel path: /tmp/torchinductor_xinanlin/a7/ca7amv5rgi4gnjngf54letv74xv7dgec7oqer3b6tk2ivzbro5vw.py
# Source Nodes: [], Original ATen: [aten.mean, aten.mul, aten.sub]

# Fused Triton reduction kernel computing
#     out = t * 0.016129032258064516 * (x - sum(x) / 125)
# where x is in_ptr0 (125 contiguous fp16 elements, indexed as in_ptr0 + r0)
# and t is the fp16 scalar in_ptr1[0].
#   * Pass 1 accumulates sum(x) into a float32 [XBLOCK, RBLOCK] buffer; lanes
#     past rnumel are masked out with rmask and reduced with tl.sum at the end.
#   * Pass 2 re-reads x, subtracts the mean, scales it, and stores to out_ptr1.
# Every load is followed by .to(tl.float32), so all arithmetic is float32. The
# result is narrowed to fp16 only by the store into the '*fp16' output pointer
# (signature index 2).
# The first loop uses eviction_policy='evict_last' and the second 'evict_first'.
# Presumably this keeps x cached for the second read; that is a hint only.
# 0.016129032258064516 == 1/62. It is presumably 2 / (N - 1) from torch.var's
# backward with N = 125; verify against the Inductor graph.
# NOTE: per the discussion in this issue, the XPU mismatch was traced to the
# eager float16 reference on XPU, not to this kernel: its output matches the
# CPU reference.
triton_red_fused_mean_mul_sub_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.reduction(
    size_hints=[1, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'xpu', 'constants': {3: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(3,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mean_mul_sub_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '1dfde875e27250dd32acadea6b97bcfcf07e09007a8b2bd1d6370ecadc217070'}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 125
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp3 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = _tmp3 + tmp2
        _tmp3 = tl.where(rmask, tmp4, _tmp3)
    tmp3 = tl.sum(_tmp3, 1)[:, None]
    tmp5 = tl.load(in_ptr1 + (0)).to(tl.float32)
    tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp9 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = 0.016129032258064516
        tmp8 = tmp6 * tmp7
        tmp10 = 125.0
        tmp11 = tmp3 / tmp10
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp9 - tmp12
        tmp14 = tmp8 * tmp13
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp14, rmask)
''', device_str='xpu')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _xpu_getCurrentRawStream as get_raw_stream

# Wait for the kernels submitted through async_compile.triton(...) to be ready.
# Presumably this also rebinds the module globals to the compiled kernel objects,
# which is why globals() is passed in; verify against AsyncCompile.
async_compile.wait(globals())
# The compile helper is not needed after this point.
del async_compile

def call(args):
    """Run the compiled fused kernel on ``[primals_1, tangents_1]``.

    ``args`` must be a list holding exactly those two XPU tensors. The list is
    emptied in place (``args.clear()``), so it no longer keeps the inputs alive.
    Any other variables the caller holds still do, which is why the
    ``__main__`` block can reuse them for the reference computation.

    Returns a new contiguous (5, 5, 5) float16 tensor on XPU device 0, written
    by ``triton_red_fused_mean_mul_sub_0``.
    """
    primals_1, tangents_1 = args
    args.clear()
    # The kernel hard-codes rnumel = 125 (= 5*5*5) and indexes in_ptr0 linearly,
    # so fail fast unless the input is exactly this contiguous shape.
    assert_size_stride(primals_1, (5, 5, 5), (25, 5, 1))
    # tangents_1 is a 0-d scalar; the kernel reads only in_ptr1[0].
    assert_size_stride(tangents_1, (), ())
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        # Output buffer: same contiguous layout as primals_1, stored as float16.
        buf1 = empty_strided((5, 5, 5), (25, 5, 1), device='xpu', dtype=torch.float16)
        # Source Nodes: [], Original ATen: [aten.mean, aten.mul, aten.sub]
        stream0 = get_raw_stream(0)
        # xnumel=1 and rnumel=125 match the constants hard-coded in the kernel.
        triton_red_fused_mean_mul_sub_0.run(primals_1, tangents_1, buf1, 1, 125, grid=grid(1), stream=stream0)
        # Drop this function's references to the inputs as early as possible.
        del primals_1
        del tangents_1
    return buf1

class Repro(torch.nn.Module):
    """Eager (un-fused) reference for the mean/sub/mul Triton kernel.

    ``forward`` returns
    ``(tangents_1 * 0.016129032258064516) * (primals_1 - mean(primals_1))``,
    where the mean is taken over all elements (``aten.mean.default``).
    Every op runs in the inputs' own dtype (float16 in this script), whereas
    the fused kernel computes in float32.
    """

    def forward(self, primals_1, tangents_1):
        # Scale the incoming gradient by the fixed constant (about 1/62).
        scaled_grad = torch.ops.aten.mul.Scalar(tangents_1, 0.016129032258064516)
        # Center the input on its global mean, then apply the scaled gradient.
        centered = torch.ops.aten.sub.Tensor(
            primals_1, torch.ops.aten.mean.default(primals_1)
        )
        return torch.ops.aten.mul.Tensor(scaled_grad, centered)

if __name__ == "__main__":
    # NOTE(review): rand_strided is imported but never used below.
    from torch._dynamo.testing import rand_strided
    # Inputs with the shapes that `call` asserts: a (5, 5, 5) fp16 tensor and a
    # 0-d fp16 scalar, both on XPU device 0.
    primals_1 = make_tensor((5, 5, 5), device='xpu:0', dtype=torch.float16)
    tangents_1 = make_tensor((), device='xpu:0', dtype=torch.float16)
    # `call` clears the list it is given, but primals_1/tangents_1 still refer to
    # the same tensors, so they remain usable for the eager reference below.
    actual = call([primals_1, tangents_1])
    mod = Repro()
    # Eager reference, computed in float16 on the same XPU device. Per the
    # discussion in this issue, this XPU eager result (not the Triton kernel) is
    # what disagreed with the CPU reference.
    ref = mod(primals_1, tangents_1) 
    print("\n====acutal:\n", actual)
    print("\n====ref:\n", ref)
    assert torch.allclose(actual, ref,  rtol=2e-3, atol=1e-5)
etaf commented 5 months ago

Hi @vlad-penkin, is there any update? Do you need my help reproducing the error?

ienkovich commented 5 months ago

The problem here is that different precisions are used to compute the actual and the reference results. The test passes for me if I add a conversion to float32 in the reference model. I add it because all computations in the Triton kernel are executed in float32, and we return to float16 only when storing the result. Here is the modified reference model:

class Repro(torch.nn.Module):
    """Float32-accurate eager reference for the mean/sub/mul Triton kernel.

    The Triton kernel in this issue loads every input with ``.to(tl.float32)``
    and narrows to float16 only on the final store. To mirror it, both inputs
    are upcast to float32, the graph
    ``(tangents * 0.016129032258064516) * (primals - mean(primals))`` is
    evaluated in float32, and the result is returned as float16.
    """

    def __init__(self):
        super().__init__()

    def forward(self, primals_1, tangents_1):
        x32, g32 = primals_1.to(torch.float32), tangents_1.to(torch.float32)
        aten = torch.ops.aten
        grad_scaled = aten.mul.Scalar(g32, 0.016129032258064516)
        mean_all = aten.mean.default(x32)
        deviation = aten.sub.Tensor(x32, mean_all)
        # Narrow back to float16 only at the end, as the kernel's store does.
        return aten.mul.Tensor(grad_scaled, deviation).to(torch.float16)
etaf commented 5 months ago

@ienkovich I'm afraid not. The reported issue only happens with fp16, not fp32. I've compared the output of the reproducer on CUDA and on XPU Triton: CUDA passes this case, but XPU produces very different numbers.

In case you need to run this script on CUDA:

from ctypes import c_void_p, c_long
import torch
from torch.testing import make_tensor
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

# Module-level aliases that appear to be Inductor wrapper-codegen boilerplate.
# In the CUDA script below, only `assert_size_stride` (used in `call`) and
# `async_compile` (used to compile the Triton kernel) are referenced; the other
# names are unused. `call` allocates its output with `torch.empty_strided`,
# not `empty_strided_cuda`.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Compiles the Triton kernel source strings submitted via async_compile.triton(...).
async_compile = AsyncCompile()

# kernel path: /tmp/torchinductor_xinanlin/a7/ca7amv5rgi4gnjngf54letv74xv7dgec7oqer3b6tk2ivzbro5vw.py
# Source Nodes: [], Original ATen: [aten.mean, aten.mul, aten.sub]

# CUDA build of the fused Triton reduction kernel. It is the same source as the
# XPU reproducer earlier in this issue, except for 'device_type' and device_str.
# It computes
#     out = t * 0.016129032258064516 * (x - sum(x) / 125)
# where x is in_ptr0 (125 contiguous fp16 elements) and t is the scalar in_ptr1[0].
#   * Pass 1 sums x into a masked float32 [XBLOCK, RBLOCK] accumulator.
#   * Pass 2 re-reads x, subtracts the mean, scales it, and stores to out_ptr1.
# All arithmetic is float32 (each load is followed by .to(tl.float32)). The
# value is narrowed to fp16 only by the store into the '*fp16' output pointer.
# 0.016129032258064516 == 1/62. It is presumably 2 / (N - 1) from torch.var's
# backward with N = 125; verify.
triton_red_fused_mean_mul_sub_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.reduction(
    size_hints=[1, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {3: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(3,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mean_mul_sub_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '1dfde875e27250dd32acadea6b97bcfcf07e09007a8b2bd1d6370ecadc217070'}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 125
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp3 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = _tmp3 + tmp2
        _tmp3 = tl.where(rmask, tmp4, _tmp3)
    tmp3 = tl.sum(_tmp3, 1)[:, None]
    tmp5 = tl.load(in_ptr1 + (0)).to(tl.float32)
    tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp9 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = 0.016129032258064516
        tmp8 = tmp6 * tmp7
        tmp10 = 125.0
        tmp11 = tmp3 / tmp10
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp9 - tmp12
        tmp14 = tmp8 * tmp13
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp14, rmask)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Wait for the kernels submitted through async_compile.triton(...) to be ready.
# Presumably this also rebinds the module globals to the compiled kernel objects,
# which is why globals() is passed in; verify against AsyncCompile.
async_compile.wait(globals())
# The compile helper is not needed after this point.
del async_compile

def call(args):
    """Run the compiled fused kernel on ``[primals_1, tangents_1]`` (CUDA).

    ``args`` must be a list holding exactly those two CUDA tensors. The list is
    emptied in place (``args.clear()``). Any other variables the caller holds
    still reference the inputs, which is why the ``__main__`` block can reuse
    them for the reference computation.

    Returns a new contiguous (5, 5, 5) float16 tensor on CUDA device 0, written
    by ``triton_red_fused_mean_mul_sub_0``.
    """
    primals_1, tangents_1 = args
    args.clear()
    # The kernel hard-codes rnumel = 125 (= 5*5*5) and indexes in_ptr0 linearly,
    # so fail fast unless the input is exactly this contiguous shape.
    assert_size_stride(primals_1, (5, 5, 5), (25, 5, 1))
    # tangents_1 is a 0-d scalar; the kernel reads only in_ptr1[0].
    assert_size_stride(tangents_1, (), ())
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # Output buffer: same contiguous layout as primals_1, stored as float16.
        buf1 = empty_strided((5, 5, 5), (25, 5, 1), device='cuda', dtype=torch.float16)
        # Source Nodes: [], Original ATen: [aten.mean, aten.mul, aten.sub]
        stream0 = get_raw_stream(0)
        # xnumel=1 and rnumel=125 match the constants hard-coded in the kernel.
        triton_red_fused_mean_mul_sub_0.run(primals_1, tangents_1, buf1, 1, 125, grid=grid(1), stream=stream0)
        # Drop this function's references to the inputs as early as possible.
        del primals_1
        del tangents_1
    return buf1

class Repro(torch.nn.Module):
    """Eager (un-fused) reference for ``triton_red_fused_mean_mul_sub_0``.

    ``forward`` evaluates
    ``(tangents_1 * 0.016129032258064516) * (primals_1 - mean(primals_1))``
    with the mean taken over every element. It uses the inputs' own dtype
    throughout, with no upcasting.
    """

    def __init__(self):
        super().__init__()

    def forward(self, primals_1, tangents_1):
        ops = torch.ops.aten
        grad_scale = ops.mul.Scalar(tangents_1, 0.016129032258064516)
        global_mean = ops.mean.default(primals_1)
        deviation = ops.sub.Tensor(primals_1, global_mean)
        return ops.mul.Tensor(grad_scale, deviation)

if __name__ == "__main__":
    # NOTE(review): rand_strided is imported but never used below.
    from torch._dynamo.testing import rand_strided
    # Inputs with the shapes that `call` asserts: a (5, 5, 5) fp16 tensor and a
    # 0-d fp16 scalar, both on CUDA device 0.
    primals_1 = make_tensor((5, 5, 5), device='cuda:0', dtype=torch.float16)
    tangents_1 = make_tensor((), device='cuda:0', dtype=torch.float16)
    # `call` clears the list it is given, but primals_1/tangents_1 still refer to
    # the same tensors, so they remain usable for the eager reference below.
    actual = call([primals_1, tangents_1])
    mod = Repro()
    # Eager float16 reference on the same device. According to this issue,
    # this comparison passes on CUDA.
    ref = mod(primals_1, tangents_1) 
    print("\n====acutal:\n", actual)
    print("\n====ref:\n", ref)
    assert torch.allclose(actual, ref,  rtol=2e-3, atol=1e-5)
etaf commented 5 months ago

Hi @ienkovich, here are my reproduced results:

CUDA output:

 tensor([[[-0.1647, -0.4136,  0.3462, -0.4700,  0.0137],
         [ 0.0350, -0.2401, -0.3352,  0.2201, -0.1735],
         [-0.4268,  0.1510,  0.1196,  0.3369, -0.3621],
         [ 0.3420, -0.0013, -0.1089,  0.3071, -0.2284],
         [-0.4058,  0.1348,  0.4043, -0.4297,  0.4067]],
...

CPU output:

 tensor([[[-0.1647, -0.4133,  0.3459, -0.4700,  0.0138],
         [ 0.0350, -0.2402, -0.3352,  0.2201, -0.1735],
         [-0.4265,  0.1511,  0.1196,  0.3369, -0.3621],
         [ 0.3420, -0.0013, -0.1089,  0.3071, -0.2284],
         [-0.4058,  0.1348,  0.4043, -0.4297,  0.4065]],

XPU output:

 tensor([[[-0.2322,  0.0094, -0.2374,  0.0491,  0.1621],
         [ 0.2800, -0.3711,  0.3311,  0.3181,  0.2388],
         [-0.2998,  0.1567, -0.3062, -0.0700,  0.1832],
         [-0.3623, -0.2556,  0.1976,  0.1676, -0.0994],
         [-0.2228, -0.2198,  0.3396, -0.2603,  0.3577]],
ienkovich commented 5 months ago

> @ienkovich I'm afraid not, The reported issue only happends on fp16, not f32. I've compared the ouput of the reproducer between CUDA and XPU triton. CUDA passes this case, but XPU gets a very different number.

If you look at the generated kernel, you can see that all loaded values are converted to float32 and all computations are performed in float32. The result matches your reference model when that model uses data converted to float32. I suspect that, without an explicit input conversion, the reference model performs its computations in float16, which doesn't match what the Triton kernel does.

Also, if I add CPU results to your original test case, the CPU results match Triton. So the problem must be in the XPU reference results:

if __name__ == "__main__":
    # NOTE(review): rand_strided is imported but never used below.
    from torch._dynamo.testing import rand_strided
    primals_1 = make_tensor((5, 5, 5), device='xpu:0', dtype=torch.float16)
    tangents_1 = make_tensor((), device='xpu:0', dtype=torch.float16)
    # `call` clears only the list it receives; primals_1/tangents_1 stay valid.
    actual = call([primals_1, tangents_1])
    mod = Repro()
    # Eager float16 reference on XPU (the result that turned out to be wrong).
    ref = mod(primals_1, tangents_1)
    # Same eager reference computed on CPU. Repro registers no parameters or
    # buffers, so mod.to('cpu') just returns the module; only the inputs move.
    ref_cpu = mod.to('cpu')(primals_1.to('cpu'), tangents_1.to('cpu'))
    # No assertion here: the three results are printed for visual comparison.
    print("\n====acutal:\n", actual)
    print("\n====ref:\n", ref)
    print("\n====ref cpu:\n", ref_cpu)

Output:

====acutal:
 tensor([[[-0.0660,  0.2119,  0.0709, -0.8403,  0.1710],
         [-0.2245, -0.2917, -0.2725,  0.6001, -0.7002],
         [-0.0047, -0.3091, -0.4287, -0.2703,  0.8569],
         [-0.5244, -0.0404,  0.3704,  0.8472, -0.7559],
         [ 1.0352, -0.9336, -0.6138,  0.5654, -0.5483]],

        [[ 0.2927, -0.3162,  0.9844, -0.3071,  0.6890],
         [ 1.0439, -0.0220, -0.7188,  0.3213, -0.5015],
         [-0.8149, -0.3459, -0.4236,  0.3909,  0.5610],
         [-0.6587,  0.2334,  0.4500, -0.3049, -0.3896],
         [-0.6709,  0.3069,  0.3704, -0.0302,  0.7832]],

        [[-0.2233, -0.8809,  0.5684,  0.0740,  0.5186],
         [ 0.9272, -0.4551, -0.3796, -0.7656, -0.7749],
         [-0.3162,  0.2037, -0.8506,  0.7300, -0.9990],
         [ 0.0168,  1.0098,  0.4836,  0.4080,  0.4919],
         [ 0.3040,  0.5327, -0.3145, -0.4133, -0.0609]],

        [[ 1.0664,  0.1619, -0.5278, -0.3601,  0.8691],
         [ 1.0117, -0.4053,  0.7720,  0.2467, -0.3552],
         [ 0.0147,  0.8657,  0.0975,  0.2639,  1.0518],
         [-0.4900,  0.7236, -0.8784, -0.2888, -0.8335],
         [-0.9746, -0.7646,  0.0832, -0.3225, -0.5317]],

        [[ 0.7603, -0.4673,  0.1016, -0.5981,  0.1721],
         [ 0.7490, -0.7026,  0.1823,  0.4194, -0.2561],
         [ 0.0188, -0.0343,  0.0311,  0.9751, -0.4236],
         [ 1.0801, -0.1947, -0.6699, -0.6299,  0.4253],
         [ 0.1414,  0.6558, -0.5605,  0.1476, -0.4797]]], device='xpu:0',
       dtype=torch.float16)

====ref:
 tensor([[[-0.1022,  0.1758,  0.0348, -0.8765,  0.1349],
         [-0.2607, -0.3279, -0.3086,  0.5640, -0.7368],
         [-0.0409, -0.3455, -0.4651, -0.3066,  0.8208],
         [-0.5610, -0.0767,  0.3342,  0.8110, -0.7920],
         [ 0.9990, -0.9702, -0.6499,  0.5298, -0.5845]],

        [[ 0.2566, -0.3525,  0.9482, -0.3433,  0.6533],
         [ 1.0078, -0.0583, -0.7554,  0.2852, -0.5376],
         [-0.8511, -0.3821, -0.4600,  0.3547,  0.5249],
         [-0.6948,  0.1973,  0.4138, -0.3413, -0.4260],
         [-0.7070,  0.2708,  0.3342, -0.0664,  0.7471]],

        [[-0.2595, -0.9175,  0.5322,  0.0378,  0.4824],
         [ 0.8911, -0.4915, -0.4160, -0.8022, -0.8110],
         [-0.3525,  0.1676, -0.8872,  0.6938, -1.0352],
         [-0.0194,  0.9736,  0.4475,  0.3721,  0.4558],
         [ 0.2678,  0.4968, -0.3506, -0.4497, -0.0971]],

        [[ 1.0303,  0.1257, -0.5640, -0.3965,  0.8330],
         [ 0.9756, -0.4414,  0.7358,  0.2106, -0.3916],
         [-0.0215,  0.8296,  0.0613,  0.2279,  1.0156],
         [-0.5264,  0.6875, -0.9146, -0.3250, -0.8696],
         [-1.0107, -0.8013,  0.0470, -0.3589, -0.5684]],

        [[ 0.7246, -0.5039,  0.0654, -0.6343,  0.1359],
         [ 0.7129, -0.7388,  0.1461,  0.3833, -0.2922],
         [-0.0174, -0.0705, -0.0051,  0.9395, -0.4600],
         [ 1.0449, -0.2310, -0.7061, -0.6665,  0.3892],
         [ 0.1053,  0.6196, -0.5967,  0.1114, -0.5161]]], device='xpu:0',
       dtype=torch.float16)

====ref cpu:
 tensor([[[-0.0660,  0.2120,  0.0710, -0.8403,  0.1711],
         [-0.2245, -0.2917, -0.2725,  0.6006, -0.7002],
         [-0.0047, -0.3093, -0.4290, -0.2705,  0.8569],
         [-0.5244, -0.0405,  0.3704,  0.8477, -0.7559],
         [ 1.0352, -0.9336, -0.6138,  0.5659, -0.5483]],

        [[ 0.2927, -0.3164,  0.9849, -0.3074,  0.6895],
         [ 1.0449, -0.0220, -0.7188,  0.3213, -0.5015],
         [-0.8149, -0.3462, -0.4238,  0.3909,  0.5615],
         [-0.6587,  0.2334,  0.4500, -0.3052, -0.3899],
         [-0.6709,  0.3069,  0.3704, -0.0302,  0.7837]],

        [[-0.2234, -0.8813,  0.5688,  0.0740,  0.5186],
         [ 0.9272, -0.4553, -0.3799, -0.7656, -0.7749],
         [-0.3164,  0.2039, -0.8506,  0.7305, -0.9990],
         [ 0.0168,  1.0098,  0.4836,  0.4082,  0.4919],
         [ 0.3040,  0.5332, -0.3145, -0.4136, -0.0609]],

        [[ 1.0664,  0.1619, -0.5278, -0.3604,  0.8691],
         [ 1.0117, -0.4053,  0.7720,  0.2466, -0.3555],
         [ 0.0147,  0.8662,  0.0975,  0.2642,  1.0518],
         [-0.4900,  0.7241, -0.8784, -0.2888, -0.8335],
         [-0.9746, -0.7646,  0.0833, -0.3228, -0.5317]],

        [[ 0.7607, -0.4673,  0.1016, -0.5981,  0.1721],
         [ 0.7495, -0.7026,  0.1824,  0.4194, -0.2561],
         [ 0.0188, -0.0343,  0.0311,  0.9756, -0.4238],
         [ 1.0811, -0.1947, -0.6699, -0.6299,  0.4253],
         [ 0.1415,  0.6558, -0.5605,  0.1476, -0.4797]]], dtype=torch.float16)
etaf commented 5 months ago

@ienkovich thanks for your kind explanation. I'm checking the difference between the XPU and CPU references.

etaf commented 5 months ago

Hi @ienkovich, I've root-caused it: the problem is in the XPU reference results. Sorry for opening this issue before checking the XPU reference. I'll make sure to do that before submitting a new issue next time.