pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

fpx test failing on main #792

Closed msaroufim closed 1 month ago

msaroufim commented 2 months ago

This started failing, likely because of a breaking change in the upstream CPU inductor code.

python test/dtypes/test_fpx.py TestFpxTensorCoreAQTLayout.test_to_scaled_tc_fpx_compile_ebits_3_mbits_2_device_cpu

See these logs https://github.com/pytorch/ao/actions/runs/10668284826/job/29567645270?pr=748

gau-nernst commented 2 months ago

Actually, we never compile to_scaled_tc_fpx() anywhere, though that might change in the future, so skipping or removing the test is also an option. But I think this likely indicates a regression upstream. Will try to isolate the issue.

gau-nernst commented 2 months ago

The offending code comes from _f32_to_fpx_unpacked().

This reproduces the same error:

import torch
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked

torch.compile(_f32_to_fpx_unpacked)(torch.randn(128), 3, 2)

The line that first triggers the error seems to be this one:

https://github.com/pytorch/ao/blob/e15e50987de92b891118aeb79a99a968205891fb/torchao/prototype/custom_fp_utils.py#L113
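A simplified sketch of the pattern at that line (hypothetical, not the exact torchao code): the fp32 bit pattern is reinterpreted as int32 and the mantissa is shifted right by MBITS_F32 - mbits bits, which is 23 - 2 = 21 here, matching the constant in the codegen diff below.

```python
import torch

# Hypothetical simplification of the offending pattern: bit-cast float32
# to int32, then do an integer right shift to drop low mantissa bits.
MBITS_F32 = 23  # number of mantissa bits in float32
mbits = 2       # target mantissa bits (ebits=3, mbits=2 in the failing test)

def truncate_mantissa(x: torch.Tensor) -> torch.Tensor:
    bits = x.view(torch.int32)          # reinterpret float32 bits as int32
    return bits >> (MBITS_F32 - mbits)  # integer shift; works fine in eager

out = truncate_mantissa(torch.ones(4))
```

In eager mode the shift operates on an int32 tensor as expected; the failure only appears once inductor's CPU codegen vectorizes it as a float, as shown in the diff below.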

Comparing the codegen between 2.4 and nightly, this is the difference (note that MBITS_F32 - mbits = 21)

// 2.4
auto tmp16 = static_cast<int32_t>(21);
auto tmp17 = decltype(tmp4)(tmp4 >> tmp16);

// nightly
auto tmp22 = static_cast<int32_t>(21);
auto tmp23 = c10::convert<float>(tmp22);
auto tmp24 = at::vec::Vectorized<float>(tmp23);
auto tmp25 = tmp6 >> tmp24;

I'm suspecting c10::convert<float> and at::vec::Vectorized<float> are causing the issue, since we can't do a bit-shift on a float. Not sure why this is happening.

Update: I can reproduce it with a minimal example:

import torch

def f(x):
    return x.view(torch.int32) >> 2

torch.compile(f)(torch.ones(16, 16))

If the input is int32, the codegen is correct -> no error. I think we can raise this in core?
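For reference, the same function is well-behaved in eager mode, which confirms the bug lives in the compiled CPU codegen rather than in the op semantics; a quick eager-only sanity check (the expected value below is just the bit pattern of 1.0f shifted right):

```python
import torch

def f(x):
    # bit-cast float32 -> int32, then integer right shift
    return x.view(torch.int32) >> 2

# Eager mode: 1.0f has bit pattern 0x3F800000, so the result
# should be 0x3F800000 >> 2 elementwise.
out = f(torch.ones(16, 16))
```

Only the torch.compile'd float-input path miscompiles on nightly; with an int32 input, the generated code keeps the shift integral and runs correctly, per the report above.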

gau-nernst commented 2 months ago

Another test is also failing. See https://github.com/pytorch/ao/actions/runs/10670269945/job/29573978093

test/dtypes/test_fpx.py::TestFpxTensorCoreAQTLayout::test_from_scaled_tc_fpx_compile_ebits_3_mbits_2_device_cpu

Seems like rounding errors.
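If the mismatch really is only rounding-level, the usual remedy is to compare compiled and eager outputs with a tolerance rather than bit-exact equality; a minimal sketch (the tensors and tolerances here are illustrative, not the actual test values):

```python
import torch

# Simulate a tiny compiled-vs-eager numerical drift.
eager_out = torch.tensor([0.1, 0.2, 0.3])
compiled_out = eager_out + 1e-6

# assert_close passes as long as the difference is within rtol/atol;
# it raises AssertionError otherwise.
torch.testing.assert_close(compiled_out, eager_out, rtol=1e-4, atol=1e-4)
```

This is how PyTorch's own test suite typically handles small numerical differences between backends.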

jerryzh168 commented 1 month ago

this is fixed now?

gau-nernst commented 1 month ago

Yes, this is fixed. Feel free to re-open the issue if that's not the case.