apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug][Pytorch FX] Can't load quantized linear layer to Relay #15155

Open fengzi0 opened 1 year ago

fengzi0 commented 1 year ago

When I load a PyTorch FX quantized model into TVM with the code below:

import torch
from torch.ao.quantization import get_default_qconfig_mapping, quantize_fx
import tvm
from tvm import relay

qconfig_mapping = get_default_qconfig_mapping()


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256)

    def forward(self, input):
        x = self.linear(input)
        return x


mm = MyModule()

# 3-D input with a leading batch dimension of 1
input = torch.randn((1, 900, 256))

# Post-training static quantization with FX graph mode
mm_prepared = quantize_fx.prepare_fx(mm, qconfig_mapping, (input,))
r = mm_prepared(input)  # calibration pass
mm_quantize = quantize_fx.convert_fx(mm_prepared)

# Trace the quantized model and import it into Relay
script_mm = torch.jit.trace(mm_quantize, (input,))
input_shapes_mm = [('input', tuple(input.shape))]
mod, params = relay.frontend.from_pytorch(script_mm, input_shapes_mm)

This gives:

The Relay type checker is unable to show the following types match:
  Tensor[(900), int32]
  Tensor[(256), int32]
In particular:
  dimension 0 conflicts: 900 does not match 256.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(256), int32]` does not match `Tensor[(900), int32]`
The Relay type checker is unable to show the following types match:
  Tensor[(900), float32]
  Tensor[(256), float32]
In particular:
  dimension 0 conflicts: 900 does not match 256.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(256), float32]` does not match `Tensor[(900), float32]`
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 4649, in from_pytorch
    outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 4025, in convert_operators
    self.record_output_type(relay_out)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 220, in record_output_type
    self.infer_type_with_prelude(output)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 168, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 161, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm.error.DiagnosticError: Traceback (most recent call last):
  6: TVMFuncCall
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  4: tvm::transform::Pass::operator()(tvm::IRModule) const
  3: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::DiagnosticContext::Render()
  File "/workspace/tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.

Version info:

>>> torch.__version__
'2.0.0+cu117'
>>> tvm.__version__
'0.11.1'

Is there something obvious that I should fix? Thanks!

masahi commented 1 year ago

As a workaround, you can remove the batch dimension from the input.
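A minimal sketch of that workaround, assuming "remove the batch dimension" means feeding a 2-D `(900, 256)` tensor instead of `(1, 900, 256)`. The helpers `drop_batch_dim` and `convert_without_batch_dim` are hypothetical names introduced here for illustration, not part of TVM or PyTorch, and this has not been verified against the versions reported above:

```python
def drop_batch_dim(shape):
    """Drop a leading batch dimension of size 1, e.g. (1, 900, 256) -> (900, 256)."""
    shape = tuple(shape)
    if len(shape) > 2 and shape[0] == 1:
        return shape[1:]
    return shape


def convert_without_batch_dim(full_shape=(1, 900, 256)):
    """Quantize the Linear model on a 2-D input and import it into Relay."""
    # Heavy imports are kept local so the shape helper can be used on its own.
    import torch
    from torch.ao.quantization import get_default_qconfig_mapping, quantize_fx
    from tvm import relay

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(256, 256)

        def forward(self, input):
            return self.linear(input)

    # 2-D input: (900, 256) instead of (1, 900, 256)
    input = torch.randn(drop_batch_dim(full_shape))

    # Same FX quantization flow as in the report, only the input rank differs
    prepared = quantize_fx.prepare_fx(MyModule(), get_default_qconfig_mapping(), (input,))
    prepared(input)  # calibration pass
    quantized = quantize_fx.convert_fx(prepared)

    traced = torch.jit.trace(quantized, (input,))
    input_shapes = [("input", tuple(input.shape))]
    return relay.frontend.from_pytorch(traced, input_shapes)
```

Calling `mod, params = convert_without_batch_dim()` runs the same steps as the original script with the rank-2 input. If downstream code expects the `(1, 900, 256)` layout, the batch dimension would need to be added back on the output by the caller.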