intel / torch-xpu-ops


glu_backward fp16 has accuracy issues compared with CPU output. #1054

Closed LuFinch closed 1 week ago

LuFinch commented 2 weeks ago

🐛 Describe the bug

Reproducer

import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase

class TestNNMethod(TestCase):
    def test_glu(self):
        for dt in [torch.float16]:
            input_cpu = torch.randn(4, 6, dtype=dt)
            input_dpcpp = input_cpu.to("xpu")
            m = nn.GLU()

            # CPU reference: run GLU forward and backward entirely in fp16
            input_cpu.requires_grad = True
            output_cpu = m(input_cpu)
            output_cpu.backward(torch.ones_like(output_cpu).to(dt))

            # XPU path: same computation on the xpu device
            input_dpcpp.requires_grad = True
            output_dpcpp = m(input_dpcpp)
            output_dpcpp.backward(torch.ones_like(output_dpcpp).to(dt).to("xpu"))
            self.assertEqual(input_cpu.grad, input_dpcpp.grad)

Running pytest test.py directly produces:

>           self.assertEqual(input_cpu.grad, input_dpcpp.grad)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 1 / 24 (4.2%)
E           Greatest absolute difference: 0.000244140625 at index (1, 3) (up to 1e-05 allowed)
E           Greatest relative difference: 0.0010671615600585938 at index (1, 3) (up to 0.001 allowed)

Status

This gap between XPU and CPU causes an IPEX UT failure. In IPEX 2.5, we override this op with the IPEX implementation.

However, I found that the glu_backward implementation in torch-xpu-ops is aligned with CUDA, and CUDA cannot pass this case either.

Not sure whether we should fix the kernel or just skip this UT (or relax its tolerance, as sketched below).
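
If we keep the UT rather than skip it, one option is to compare the fp16 gradients with tolerances wider than the float16 defaults (atol=1e-5, rtol=1e-3, as shown in the failure above). This is only a sketch; the helper name and tolerance values are illustrative, not existing project settings:

import torch

# Sketch: compare fp16 gradients with tolerances wide enough to absorb a few
# fp16 ulps (~2.4e-4 in the failure above). The helper name and the atol/rtol
# values are hypothetical choices, not an existing API or project setting.
def assert_fp16_grads_close(grad_cpu, grad_xpu, atol=1e-3, rtol=1e-2):
    torch.testing.assert_close(grad_xpu.cpu(), grad_cpu, atol=atol, rtol=rtol)

With atol=1e-3 the reported mismatch of 0.000244140625 would pass; whether loosening the check is acceptable is exactly the question above.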

Versions

...

daisyden commented 1 week ago

CUDA and XPU do not use an accumulate dtype for bfloat16 and float16, while CPU does. torch-xpu-ops will align with CUDA.

cuda:

void glu_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a_, scalar_t b_) -> scalar_t {
          const opmath_t a = a_;
          const opmath_t b = b_;
          const opmath_t one = opmath_t(1);
          const opmath_t sigmoid = one / (one + std::exp(-b));
          return a * sigmoid;
        });
      });
}

xpu:

template <typename scalar_t>
struct GluFunctor {
  using opmath_t = at::opmath_type<scalar_t>;
  scalar_t operator()(scalar_t a_, scalar_t b_) const {
    const opmath_t a = a_;
    const opmath_t b = b_;
    const opmath_t one = opmath_t(1);
    const opmath_t sigmoid = one / (one + std::exp(-b));
    return a * sigmoid;
  }
};

cpu:

void glu_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "glu_cpu", [&]() {
      const float float_one_val(1);
      const Vectorized<float> float_one_vec(float_one_val);
      cpu_kernel_vec(
          iter,
          [float_one_val](scalar_t a, scalar_t b) -> scalar_t {
            return float(a) * (float_one_val / (float_one_val + std::exp(-float(b))));
          },
          [float_one_vec](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto [b0, b1] = convert_to_float<scalar_t>(b);
            return convert_from_float<scalar_t>(
                a0 * (float_one_vec / (float_one_vec + b0.neg().exp())),
                a1 * (float_one_vec / (float_one_vec + b1.neg().exp())));
          });
    });
  } else {
    ...
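
For intuition, here is an illustrative Python sketch (not code from either backend, using the forward formula a * sigmoid(b) quoted above, with arbitrary random inputs rather than the failing test's data): keeping every intermediate in fp16 versus accumulating in fp32 and rounding once at the end can disagree by a few fp16 ulps, which is the size of the mismatch reported above (0.000244140625 = 2^-12).

import torch

# Illustrative only: evaluate a * sigmoid(b) for fp16 inputs two ways --
# with every intermediate rounded to fp16 (no accumulate dtype) versus
# computed in fp32 and rounded once at the end (the CPU kernel's approach).
torch.manual_seed(0)
a = torch.randn(10000, dtype=torch.float16)
b = torch.randn(10000, dtype=torch.float16)

out_fp16 = a * (1.0 / (1.0 + torch.exp(-b)))                 # each step rounds to fp16
out_fp32 = (a.float() * torch.sigmoid(b.float())).half()     # single rounding at the end

diff = (out_fp16 - out_fp32).abs().max()
print(diff)   # typically a small nonzero value on the order of a few fp16 ulps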

daisyden commented 1 week ago

Closing the issue as torch-xpu-ops will align with CUDA.