pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

[BUG] Float8Linear does not work with torch.inference_mode #643

Open leeeizhang opened 1 month ago

leeeizhang commented 1 month ago

FP8 Linear does not work for me:

  • torch == 2.4.0 + cu121
  • torchao == 0.4.0
  • cuda_arch == 8.9 (nvidia L40)
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

class FFN(nn.Module):
    def __init__(self, in_feature, hidden_feature, bias=True):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, hidden_feature, bias)
        self.fc2 = nn.Linear(hidden_feature, in_feature, bias)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        return x

bs, seq, dim = 32, 512, 1024

m = FFN(dim, dim * 4).cuda()
convert_to_float8_training(m)
# m = torch.compile(m)

x = torch.randn((bs, seq, dim), device="cuda")

with torch.inference_mode(mode=True):
    y = m(x)
/usr/local/lib/python3.10/dist-packages/torchao/ops.py:12: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  return torch.library.impl_abstract(f"{name}")(func)
Traceback (most recent call last):
  File "/root/erdos/ops/triton/t.py", line 28, in <module>
    y = m(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/erdos/ops/triton/t.py", line 14, in forward
    x = self.fc1(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 363, in forward
    output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 59, in forward
    input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_tensor.py", line 360, in __torch_dispatch__
    raise NotImplementedError(f"attempting to run {func}, this is not supported")
NotImplementedError: attempting to run aten.reshape.default, this is not supported
leeeizhang commented 1 month ago

It seems that Float8Linear cannot run in inference mode. I removed torch.inference_mode(), but it still does not work:

/usr/local/lib/python3.10/dist-packages/torchao/ops.py:12: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  return torch.library.impl_abstract(f"{name}")(func)
Traceback (most recent call last):
  File "/root/erdos/ops/triton/t.py", line 28, in <module>
    y = m(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/erdos/ops/triton/t.py", line 14, in forward
    x = self.fc1(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 363, in forward
    output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 60, in forward
    res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
RuntimeError: Unable to cast (tensor([[ 0.5145,  0.0516,  0.2376,  ...,  0.4223, -0.5424,  0.4097],
        [ 0.9510, -0.3045, -0.3444,  ..., -0.0056, -0.5410,  1.1299],
        [ 0.3378, -0.2371, -0.4324,  ..., -0.4220,  0.5146,  0.4283],
        ...,
        [ 0.5058,  0.1434,  0.4000,  ...,  0.0190,  0.5246,  0.2922],
        [-0.1201, -0.2883,  0.2411,  ...,  0.4197,  0.5214,  0.3386],
        [-0.4649, -0.6164,  0.3143,  ..., -0.3093, -0.0355,  0.4321]],
       device='cuda:0'), tensor(32.0000, device='cuda:0')) to Tensor
supriyar commented 1 month ago

cc @vkuzo, @drisspg

qingquansong commented 4 weeks ago

Hey @leeeizhang, I'm facing the same issue (error below) even with the latest changes, and it does not seem to be related to reshaping or inference mode. Did this get resolved on your side after the fix? Thank you! (I'm using torch 2.3.1, by the way.) The returned tensor tuple is coming from here. (emulate=True does not have this problem.) It seems that for every output type the call returns a tuple of two tensors: for fp16/fp32/bf16 the second tensor is a 0 scalar, and for an fp8 output type it is the scale. In all of these cases the result cannot be treated as a single tensor and cast. 🤔 This is a poor man's fix: https://github.com/pytorch/ao/pull/702 (a better solution is needed)

RuntimeError: Unable to cast (tensor([[ 0.5145,  0.0516,  0.2376,  ...,  0.4223, -0.5424,  0.4097],
        [ 0.9510, -0.3045, -0.3444,  ..., -0.0056, -0.5410,  1.1299],
        [ 0.3378, -0.2371, -0.4324,  ..., -0.4220,  0.5146,  0.4283],
        ...,
        [ 0.5058,  0.1434,  0.4000,  ...,  0.0190,  0.5246,  0.2922],
        [-0.1201, -0.2883,  0.2411,  ...,  0.4197,  0.5214,  0.3386],
        [-0.4649, -0.6164,  0.3143,  ..., -0.3093, -0.0355,  0.4321]],
       device='cuda:0'), tensor(32.0000, device='cuda:0')) to Tensor
leeeizhang commented 4 weeks ago

Try the torch nightly (2.5.0dev), which refactors torch._scaled_mm to return a single tensor instead of a tuple.
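
For code that has to run on both sides of that change, a minimal sketch of a tolerant unwrapping step might look like the following. It assumes only what this thread describes: older releases hand back a tuple whose first element is the matmul output, and newer ones hand back the output directly. The helper name unwrap_scaled_mm_result is hypothetical and is not part of torch or torchao.

```python
def unwrap_scaled_mm_result(result):
    """Return the matmul output from either return convention.

    Hypothetical helper, not a torch or torchao API.
    Assumption (from this thread): older torch returns a tuple whose
    first element is the output; newer torch returns the output itself.
    """
    if isinstance(result, tuple):
        # Older convention: (output, second_tensor). Keep only the output.
        return result[0]
    # Newer convention: the output is already a single object.
    return result
```

Passing the return value of the matmul call through this helper before casting would avoid the "Unable to cast (...) to Tensor" error on the tuple-returning versions, at the cost of discarding the second element.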

qingquansong commented 3 weeks ago

@leeeizhang Thank you very much!

vkuzo commented 3 weeks ago

Thanks for filing. I think we should make the version expectations clear in the README; reopening until we make that happen.
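
Until the README states this, a small guard in user code can make the expectation explicit. The sketch below is only an illustration: the (2, 5) minimum is an assumption taken from the nightly suggestion earlier in this thread, not a documented torchao requirement, and the function names are hypothetical.

```python
import re

# Assumption: torch 2.5 is the first series with the single-tensor return,
# based on the nightly (2.5.0dev) suggestion in this thread.
MIN_TORCH = (2, 5)


def parse_major_minor(version):
    """Extract (major, minor) from a version string such as '2.4.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version, minimum=MIN_TORCH):
    """Return True if the version's (major, minor) is at least `minimum`."""
    return parse_major_minor(version) >= minimum
```

Calling meets_minimum(torch.__version__) before convert_to_float8_training would let a script fail early with a clear message instead of hitting the errors shown above.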