quic / aimet

AIMET is a library that provides advanced quantization and compression techniques for trained neural network models.
https://quic.github.io/aimet-pages/index.html

KeyError when I export the QuantizationSimModel to torchscript format #2461

Open · xiexiaozheng opened this issue 1 year ago

xiexiaozheng commented 1 year ago

AIMET version: 1.28. PyTorch version: 1.13.1. My model contains subtraction operations, and when I export it in TorchScript format with

onnx_export_args = aimet_torch.onnx_utils.OnnxExportApiArgs(opset_version=11)
quantsim.export(path="/data/segmentation/MaskFormer/quantize/exported_model/W8A8/aimet/",
                filename_prefix='QAT_model',
                dummy_input=dummy_input,
                onnx_export_args=onnx_export_args,
                export_to_torchscript=True,
                propagate_encodings=True)
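
For context, a minimal sketch of the setup assumed by the snippet above (the placeholder model and input shape are illustrative, not the actual MaskFormer configuration):

import torch
from aimet_torch.quantsim import QuantizationSimModel

# Placeholder model standing in for the actual segmentation network; the
# input shape below is illustrative only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)).eval()
dummy_input = torch.randn(1, 3, 512, 512)
quantsim = QuantizationSimModel(model, dummy_input=dummy_input,
                                default_param_bw=8, default_output_bw=8)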

the program encounters an error raised from aimet_torch/torchscript_utils.py. The specific log is:

File "/data/anaconda3/envs/aimet1.28/lib/python3.8/site-packages/aimet_torch/quantsim.py", line 428, in export
    self.export_torch_script_model_and_encodings(path, filename_prefix, model_to_export, self.model,
  File "/data/anaconda3/envs/aimet1.28/lib/python3.8/site-packages/aimet_torch/quantsim.py", line 460, in export_torch_script_model_and_encodings
    torchscript_utils.get_node_to_io_tensor_names_map(original_model, trace, dummy_input)
  File "/data/anaconda3/envs/aimet1.28/lib/python3.8/site-packages/aimet_torch/torchscript_utils.py", line 243, in get_node_to_io_tensor_names_map
    assert op_type_map[type(node.module)] == node.node_type
KeyError: <class 'aimet_torch.elementwise_ops.Subtract'>
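
For reference, a minimal sketch of the pattern that triggers this error: a model that uses aimet_torch.elementwise_ops.Subtract in place of a functional subtraction (the layer names and shapes here are illustrative, not taken from the actual model):

import torch
from aimet_torch import elementwise_ops

class SubtractModel(torch.nn.Module):
    """Illustrative model containing an elementwise Subtract module."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # AIMET's module wrapper for x - y, used so the op can be quantized
        self.sub = elementwise_ops.Subtract()

    def forward(self, x, y):
        return self.sub(self.conv(x), self.conv(y))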
quic-klhsieh commented 1 year ago

@xiexiaozheng, thanks for bringing this up. Until now, we have mainly focused on making the ONNX export flow robust; the TorchScript export flow has not been as well maintained.

To understand the context behind this issue: during TorchScript export we use op_type_map from torchscript_utils.py, a dictionary that maps torch module types to the string torch.jit.trace uses to represent each module type.
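
As a hedged sketch of the workaround this implies, you could register the missing module type in op_type_map before exporting. The string value below is a guessed placeholder; the correct value is whatever node_type the trace parsing produces for the subtraction node, which you would need to confirm against your own trace:

from aimet_torch import elementwise_ops, torchscript_utils

# Assumed workaround: add an entry for Subtract so the module-to-node matching
# stays in sync. 'sub' is a guessed placeholder; verify the actual string that
# torch.jit.trace uses for subtraction in your torch version.
torchscript_utils.op_type_map[elementwise_ops.Subtract] = 'sub'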

The logic within get_node_to_io_tensor_names_map() tries to line up the modules encountered in a forward pass of the torch model 1:1 with the nodes seen in the torch.jit.trace graph. If the model contains a module with no mapping in op_type_map, the 1:1 matching goes out of sync.

So in general, for a model to be exportable via TorchScript, op_type_map needs a mapping for every torch module type in the model. Would you be able to survey your model and check which module types are not captured in op_type_map? And are you familiar enough with torch.jit.trace to traverse the trace graph and see which strings torch trace uses for the different module types? One way to do that is sketched below.
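
As a minimal sketch (reusing the illustrative SubtractModel from the earlier comment; substitute your own model and dummy input):

import torch

# Trace the model and print the kind string of every node in the inlined
# graph; comparing these against op_type_map shows which ops lack an entry.
model = SubtractModel().eval()
dummy_input = (torch.randn(1, 3, 32, 32), torch.randn(1, 3, 32, 32))
trace = torch.jit.trace(model, dummy_input)
for node in trace.inlined_graph.nodes():
    print(node.kind())   # e.g. 'aten::sub' for the Subtract module's op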