apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

PyTorch quantized group conv2d raises an error when converting to TVM #7878

Closed eleflea closed 3 years ago

eleflea commented 3 years ago

Hi, I found a bug in Relay. I used PyTorch to quantize a grouped convolution model, and the relay.frontend.from_pytorch function reported an error. Details follow.

import tvm
from tvm import relay
import torch
from torch import nn
from torch import quantization

GROUPS = 4

class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()
        self.dequant = quantization.DeQuantStub()
        self.gconv = nn.Conv2d(12, 24, 3, groups=GROUPS, bias=False)

    def forward(self, x):
        x = self.quant(x)
        return self.dequant(self.gconv(x))

net = Net()

net.eval()
net.qconfig = torch.quantization.get_default_qconfig('fbgemm')
net = torch.quantization.prepare(net, inplace=False)
net = torch.quantization.convert(net, inplace=False)

inp = torch.randn(1, 12, 32, 32)
script_module = torch.jit.trace(net, inp).eval()

input_name = "input"  # the input name can be arbitrary for the PyTorch frontend
input_shapes = [(input_name, (1, 12, 32, 32))]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

target = tvm.target.cuda()
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build_module.build(mod, target=target, params=params)
print('finish')

It raises:

The Relay type checker is unable to show the following types match.
In particular dimension 0 conflicts: 72 does not match 24.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(24), float32]` does not match `Tensor[(72), float32]`
Traceback (most recent call last):
  File "bug.py", line 33, in <module>
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
  File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 3238, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 2662, in convert_operators
    self.record_output_type(relay_out)
  File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 222, in record_output_type
    self.infer_type_with_prelude(output)
  File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 170, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 163, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/ir/transform.py", line 127, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm.error.DiagnosticError: Traceback (most recent call last):
  [bt] (6) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(TVMFuncCall+0x5b) [0x7fd44602622b]
  [bt] (5) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(+0x80b06a) [0x7fd4454b306a]
  [bt] (4) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule) const+0xcd) [0x7fd4454b24ad]
  [bt] (3) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1b7) [0x7fd4454b1c27]
  [bt] (2) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(+0x1195018) [0x7fd445e3d018]
  [bt] (1) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(tvm::DiagnosticContext::Render()+0x199) [0x7fd44545ecc9]
  [bt] (0) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(+0x7b5d22) [0x7fd44545dd22]
  File "/home/eleflea/code/tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.

With GROUPS = 1 it works, so I think the problem is related to quantized group conv2d. Looking forward to your reply! Thank you.

masahi commented 3 years ago

@anijain2305 Is group conv supported by QNN? The following line that does the multiplication seems incorrect for the group conv case. Here the weight shape is (24, 3, 3, 3), so multiplying 24 * 3 gives 72, while the weight scale shape is (24,); that mismatch is the error reported above. https://github.com/apache/tvm/blob/813136401a11a49d6c15e6013c34dd822a5c4ff6/src/relay/qnn/op/convolution.cc#L81
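
To make the numbers concrete, here is a minimal sketch of the arithmetic in plain Python. It does not call TVM. The helper `conv2d_weight_shape` is a hypothetical name introduced for illustration, and the `O * I` product is only my reading of the multiplication on the linked line, not a copy of the checker's code.

```python
def conv2d_weight_shape(in_channels, out_channels, kernel_size, groups):
    """Hypothetical helper: the weight shape nn.Conv2d uses, (O, I // groups, kH, kW)."""
    return (out_channels, in_channels // groups, kernel_size, kernel_size)


# Values from the report: nn.Conv2d(12, 24, 3, groups=4)
weight_shape = conv2d_weight_shape(12, 24, 3, groups=4)

# Per-channel quantization stores one scale per output channel.
per_channel_scale_len = weight_shape[0]

# Assumption: the type relation multiplies the O and I dimensions
# when groups != 1, which is what the error message suggests.
multiplied_len = weight_shape[0] * weight_shape[1]

print(weight_shape)           # (24, 3, 3, 3)
print(per_channel_scale_len)  # 24
print(multiplied_len)         # 72
```

The two lengths, 24 and 72, are the same pair that appears in "dimension 0 conflicts: 72 does not match 24".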

@eleflea For now you can use per-tensor quantization to work around this problem (the error happens when you use per-channel weight quantization, which is what get_default_qconfig('fbgemm') selects). You can force per-tensor quantization with get_default_qconfig('qnnpack'), for example.
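
A quick way to see the difference between the two configs is to build each one's weight observer and look at its quantization scheme. This is a sketch that assumes a PyTorch build where `torch.quantization.get_default_qconfig` behaves as in the report above; the exact observer classes may differ between PyTorch versions.

```python
import torch
from torch import quantization

# Instantiate the weight observer of each default qconfig.
fbgemm_weight = quantization.get_default_qconfig('fbgemm').weight()
qnnpack_weight = quantization.get_default_qconfig('qnnpack').weight()

# The qscheme shows whether weights are quantized per channel or per tensor.
print('fbgemm :', fbgemm_weight.qscheme)
print('qnnpack:', qnnpack_weight.qscheme)
```

In the reproduction script, the workaround would then be the single change `net.qconfig = torch.quantization.get_default_qconfig('qnnpack')`, with the rest left as is.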

tqchen commented 3 years ago

Thanks for asking the question. The community uses the forum for troubleshooting and discussions; please open a new discussion topic on https://discuss.tvm.apache.org/, where more people will be able to watch and answer the questions.