pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

[Bug] ERR: subclass doesn't implement <function multi_head_attention_forward> #1103

Open · dgcnz opened this issue 1 week ago

dgcnz commented 1 week ago

Description

torchao.autoquant fails when applied to a model that uses torch.nn.MultiheadAttention. I think it might be related to https://github.com/pytorch/pytorch/issues/72186#issuecomment-1028326358.

Minimum reproducible program

import torch
import torchao

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.mha = torch.nn.MultiheadAttention(128, 4)  # embed_dim=128, num_heads=4, batch_first=False

    def forward(self, x):
        x, _ = self.mha(x, x, x)  # self-attention: query = key = value = x
        return x

model = Model()
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
inputs = (torch.randn(10, 32, 128),)  # (seq_len, batch, embed_dim)
out = model(*inputs)

Error Logs

ERR: subclass doesn't implement <function multi_head_attention_forward at 0x7a74c74148b0>
Traceback (most recent call last):
  File "$HOME/.conda/envs/cu124/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "$HOME/.conda/envs/cu124/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "$WORKSPACE/scripts/mre/mha_autoquant.py", line 16, in <module>
    out = model(*inputs)
  File "$HOME/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$HOME/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
    return inner()
  File "$HOME/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1769, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "$HOME/.conda/envs/cu124/lib/python3.10/site-packages/torchao/quantization/autoquant.py", line 720, in autoquant_prehook
    real_model.forward(*args, **kwargs)
  File "$WORKSPACE/scripts/mre/mha_autoquant.py", line 10, in forward
    x, _ = self.mha(x, x, x)
  File "$HOME/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$HOME/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "$HOME/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1368, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
TypeError: cannot unpack non-iterable NoneType object
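The traceback is consistent with the mechanism described in the linked PyTorch comment: F.multi_head_attention_forward checks all of its tensor arguments (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) for __torch_function__ overrides, and if any of them is a tensor subclass, the entire call is handed to that subclass. Inside nn.MultiheadAttention, out_proj is the only nn.Linear, so it is presumably the weight that autoquant swaps. The sketch below reproduces that failure mode without torchao. LinearOnlyWeight is a hypothetical stand-in, not torchao's class; it assumes a subclass that handles F.linear, prints an error for anything else, and returns None, which is enough to produce the same TypeError.

```python
import torch
import torch.nn.functional as F


class LinearOnlyWeight(torch.Tensor):
    """Hypothetical weight subclass that only supports F.linear.

    Illustrative assumption only, not torchao's implementation.
    """

    SUPPORTED = {F.linear}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in cls.SUPPORTED:
            # Run the op on the raw data without dispatching back into this class.
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        # Unsupported op: report it and fall through with None.
        print(f"ERR: subclass doesn't implement {func}")
        return None


torch.manual_seed(0)
# A fresh module is in training mode, so the MHA fast path is skipped.
mha = torch.nn.MultiheadAttention(128, 4)

# Swap out_proj's weight for the subclass, roughly what a weight-swapping quantizer does.
w = mha.out_proj.weight.detach().as_subclass(LinearOnlyWeight)
del mha.out_proj.weight  # drop the Parameter so a plain tensor attribute can replace it
mha.out_proj.weight = w

x = torch.randn(10, 32, 128)  # (seq_len, batch, embed_dim)

# F.linear is supported, so calling it directly with the subclass weight works.
y = F.linear(x, w)

# multi_head_attention_forward sees a subclass among its tensor arguments and
# routes the whole call to LinearOnlyWeight.__torch_function__, which returns None.
err = None
try:
    mha(x, x, x)
except TypeError as exc:
    err = str(exc)
print(err)
```

Running it prints an ERR line naming multi_head_attention_forward, after which the module call fails with the same "cannot unpack" TypeError, while the direct F.linear call on the same weight succeeds.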

Environment

(cu124) [dgcnz@dpc edge]$ pip freeze | grep torch
torch==2.6.0.dev20241013+cu124
torchao==0.7.0.dev20241017+cu124
torchvision==0.20.0.dev20241013+cu124
dgcnz commented 1 week ago

I should have checked the open PRs first: it looks like this is already being fixed in https://github.com/pytorch/ao/pull/977. I'm leaving this issue open in case anyone else runs into the same error; feel free to close it.
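Until a fix lands, one possible workaround (an assumption, not tested against autoquant itself) is to express the attention with plain nn.Linear layers plus F.scaled_dot_product_attention, so that a swapped-in weight subclass would only ever see F.linear calls. LinearSelfAttention and from_mha are hypothetical names. The sketch covers only the configuration in the repro (batch_first=False, biases on, equal q/k/v dims, no masks, no dropout), and it returns None instead of attention weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearSelfAttention(nn.Module):
    """Hypothetical nn.MultiheadAttention replacement built from nn.Linear + SDPA.

    Assumes batch_first=False, bias=True, equal q/k/v dims, no masks, no dropout.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    @classmethod
    def from_mha(cls, mha: nn.MultiheadAttention) -> "LinearSelfAttention":
        assert mha.in_proj_weight is not None and mha.in_proj_bias is not None
        new = cls(mha.embed_dim, mha.num_heads)
        with torch.no_grad():
            # in_proj_weight stacks the q, k and v projections along dim 0.
            weights = mha.in_proj_weight.chunk(3, dim=0)
            biases = mha.in_proj_bias.chunk(3, dim=0)
            for proj, w, b in zip((new.q_proj, new.k_proj, new.v_proj), weights, biases):
                proj.weight.copy_(w)
                proj.bias.copy_(b)
            new.out_proj.weight.copy_(mha.out_proj.weight)
            new.out_proj.bias.copy_(mha.out_proj.bias)
        return new

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # (seq_len, batch, embed_dim) -> (batch, heads, seq_len, head_dim)
        seq_len, batch, _ = t.shape
        return t.reshape(seq_len, batch, self.num_heads, self.head_dim).permute(1, 2, 0, 3)

    def forward(self, query, key, value):
        seq_len, batch, embed_dim = query.shape
        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))
        v = self._split_heads(self.v_proj(value))
        # Default SDPA scale is 1/sqrt(head_dim), matching nn.MultiheadAttention.
        out = F.scaled_dot_product_attention(q, k, v)
        # (batch, heads, seq_len, head_dim) -> (seq_len, batch, embed_dim)
        out = out.permute(2, 0, 1, 3).reshape(seq_len, batch, embed_dim)
        # Return None for the weights so `x, _ = attn(x, x, x)` still unpacks.
        return self.out_proj(out), None


torch.manual_seed(0)
mha = nn.MultiheadAttention(128, 4)
attn = LinearSelfAttention.from_mha(mha)
x = torch.randn(10, 32, 128)  # (seq_len, batch, embed_dim)
ref, _ = mha(x, x, x)
out, _ = attn(x, x, x)
```

The design choice is to keep every weight behind an nn.Linear, which is the module type weight-swapping quantizers generally target; whether autoquant then runs cleanly on this module is an untested assumption.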

StephenChou0119 commented 3 days ago

Same problem here.

jerryzh168 commented 3 days ago

Yeah, we are fixing this in https://github.com/pytorch/ao/pull/1141.