alibaba / TinyNeuralNetwork

TinyNeuralNetwork is an efficient and easy-to-use deep learning model compression framework.

ViT PTQ Error #343

Closed hoangtv2000 closed 2 months ago

hoangtv2000 commented 2 months ago

Hi @zk1998, I have another issue with your repository. I implemented PTQ on a ViT by following the code in vit_post.py and got the error below:

[W lower_tuples.cpp:253] Warning: tuple appears in the op outputs, but this op does not forward tuples, unsupported kind: prim::CallMethod (function flattenOutputs)
Traceback (most recent call last):
  File "TinyNeuralNet/examples/quantization/specific/vit/vit_post.py", line 188, in <module>
    ptq_vit(args)
  File "TinyNeuralNet/examples/quantization/specific/vit/vit_post.py", line 178, in ptq_vit
    converter.convert()
  File "/data/hoangtv23/workspace_AIOT/model_compression_flow/TinyNeuralNet/examples/quantization/specific/vit/../../../../tinynn/converter/base.py", line 513, in convert
    self.init_lowered_module()
  File "/data/hoangtv23/workspace_AIOT/model_compression_flow/TinyNeuralNet/examples/quantization/specific/vit/../../../../tinynn/converter/base.py", line 301, in init_lowered_module
    torch._C._jit_pass_lower_all_tuples(graph)
RuntimeError: prim::TupleUnpack not matched to tuple construct

Here is the relevant part of the TorchScript graph of the model. I think the problem comes from prim::TupleUnpack, but I don't know how to fix it.

...
  %634 : Tensor = prim::CallMethod[name="forward"](%fake_quant_0.1, %input_0_f.1) # :0:0
  %637 : Tensor = prim::CallMethod[name="forward"](%fake_quant_1.1, %vit_vit_embeddings_cls_token.1) # :0:0
  %639 : (Tensor, Tensor?) = prim::CallMethod[name="unpack"](%_packed_params.1) # :0:0
  %Xq.1 : Tensor, %641 : Tensor? = prim::TupleUnpack(%639)
  %644 : NoneType = prim::CallMethod[name="forward"](%fake_quant_3.1, %Xq.1) # :0:0
  %647 : int = aten::size(%634, %646) # TinyNeuralNet/examples/quantization/specific/vit/out/vitwrapper_q.py:351:0
  %648 : Tensor = prim::NumToTensor(%647) # :0:0
  %651 : int = aten::Int(%648) # :0:0
...
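
My guess (not confirmed, just from reading the graph) is that the (Tensor, Tensor?) tuple is produced by the unpack call on a quantized linear layer's packed parameters, i.e. the weight/bias pair. An eager-mode analogy, assuming torch.ao.nn.quantized is available in your PyTorch version:

import torch.ao.nn.quantized as nnq

# A default quantized Linear; its packed params hold (weight, Optional[bias]).
qlinear = nnq.Linear(8, 8)
weight, bias = qlinear._weight_bias()  # (Tensor, Optional[Tensor]), the same kind of tuple as in the graph
print(weight.shape, None if bias is None else bias.shape)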

Can you show me how to resolve this error? Thank you so much.

peterjc123 commented 2 months ago

Hi, I reproduced the issue and figured out that some extra operations on the quantized weights lead to the error, which seems to be caused by the PyTorch codebase. Luckily, we can work around this problem by adding the keyword parameter QATQuantizer(..., config={..., 'extra_tracer_opts': {'eliminate_dead_graph': True}}), so that those operations are removed. (P.S. There are multiple QATQuantizer instances, and I think you should apply it to all of them.)
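
A minimal sketch of how that option slots into the quantizer construction (the model, dummy_input, and work_dir values below are placeholders I'm assuming, not taken from the script):

from tinynn.graph.quantization.quantizer import QATQuantizer

quantizer = QATQuantizer(
    model,        # the (sub)model being quantized
    dummy_input,  # example input used for tracing
    work_dir='out',
    config={
        # ... keep any existing config entries ...
        'extra_tracer_opts': {'eliminate_dead_graph': True},  # removes the extra dead ops
    },
)
qat_model = quantizer.quantize()

The same 'extra_tracer_opts' entry should go into the config of every QATQuantizer the script creates.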

peterjc123 commented 2 months ago

You may refer to #344 as well if I didn't make it clear in the previous comment.