Xilinx / Vitis-AI

Vitis AI is Xilinx’s development stack for AI inference on Xilinx hardware platforms, including both edge devices and Alveo cards.
https://www.xilinx.com/ai
Apache License 2.0

AddScalar() does not work for quantization #1442

Open HyberionBrew opened 5 months ago


When the AddScalar functional is used in place of a normal add (so that the addition can be quantized), QatProcessor fails with the error shown below. Environment: PyTorch 2.0.0, Vitis AI 3.5. Minimal example to reproduce:


import torch
import torch.nn as nn

import pytorch_nndct.nn.modules.functional as q_functional
from pytorch_nndct import nn as nndct_nn
from pytorch_nndct import QatProcessor

class TestAddScalar(nn.Module):
    def __init__(self):
        super(TestAddScalar, self).__init__()
        self.quant_stub = nndct_nn.QuantStub()
        self.dequant_stub = nndct_nn.DeQuantStub()
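        # pytorch_nndct scalar add, used instead of a plain x + 1e-12 so it can be quantized
        self.add_scalar = q_functional.AddScalar()
        # self.constant = q_functional.Const(1e-12)
        self.clamp = q_functional.Clamp()  # not used in forward()
        # fully connected layer
        self.fc = nn.Linear(6, 1)

    def forward(self, x):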
        x = self.quant_stub(x)
        x = self.add_scalar(x, 1e-12)
        x = self.fc(x)
        x = self.dequant_stub(x)
        return x 

if __name__ == "__main__":
    model = TestAddScalar()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    inputs = torch.randn((10, 6), dtype=torch.float32, device=device)
    # The error below is raised inside this constructor, while the model is traced
    qat_processor = QatProcessor(model, inputs)
    quantized_model = qat_processor.trainable_model(allow_reused_module=True)
    deployable_model = qat_processor.to_deployable(quantized_model, "/workspace/")

Resulting error:
[VAIQ_ERROR][QUANTIZER_TORCH_NOT_A_MODULE]: Quantized operation(TestAddScalar::TestAddScalar/AddScalar[add_scalar]/ret.5) must be instance of "torch.nn.Module", please replace torch.add/+ with <class 'pytorch_nndct.nn.modules.functional.Add'>.The original source range is:
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/_tensor.py(1295): __torch_function__
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/parse/override_torch_function.py(52): __torch_function__
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/nn/modules/functional.py(62): forward
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py(1538): _call_impl
minimum_example_add_scalar.py(24): forward
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py(1538): _call_impl
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/parse/rich_in_out_helper.py(202): forward
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py(1538): _call_impl
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/jit/_trace.py(118): wrapper
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/jit/_trace.py(127): forward
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py(1501): _call_impl
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/torch/jit/_trace.py(1268): _get_trace_graph
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/utils/jit_utils.py(403): trace_and_get_graph_from_model
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/parse/trace_helper.py(130): _trace_graph_from_model
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/parse/trace_helper.py(75): _get_fw_graph_from_module
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/parse/trace_helper.py(104): build_torch_graph
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/parse/parser.py(78): __call__
/opt/vitis_ai/conda/envs/vitis-ai-pytorch-2.0.0/lib/python3.8/site-packages/pytorch_nndct/quantization/quant_aware_training.py(206): __init__
minimum_example_add_scalar.py(36): <module>
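The traceback passes from functional.py(62) in AddScalar.forward through __torch_function__ and ends in torch.jit tracing, which suggests that AddScalar performs a bare tensor add internally. The plain-PyTorch sketch below (no pytorch_nndct required) illustrates that mechanism under this assumption. PlainAddScalar and TestPlainAddScalar are hypothetical stand-ins, not Vitis AI classes. When traced with torch.jit.trace, the scalar add shows up as an ordinary aten::add node even though it runs inside an nn.Module wrapper.

```python
import torch
import torch.nn as nn


class PlainAddScalar(nn.Module):
    """Hypothetical stand-in for AddScalar: adds a Python scalar with a bare `+`."""

    def forward(self, x, scalar):
        # A bare `+` between a tensor and a float; the tracer records it as aten::add.
        return x + scalar


class TestPlainAddScalar(nn.Module):
    """Same layout as TestAddScalar above, minus the nndct stubs."""

    def __init__(self):
        super().__init__()
        self.add_scalar = PlainAddScalar()
        self.fc = nn.Linear(6, 1)

    def forward(self, x):
        x = self.add_scalar(x, 1e-12)
        return self.fc(x)


def traced_graph_text(model, example):
    """Trace `model` with torch.jit.trace and return its inlined graph as text."""
    traced = torch.jit.trace(model, example)
    # inlined_graph expands submodule calls, so ops inside add_scalar become visible.
    return str(traced.inlined_graph)


if __name__ == "__main__":
    graph = traced_graph_text(TestPlainAddScalar().eval(), torch.randn(10, 6))
    # Print only the aten::add line(s) produced by the scalar add.
    for line in graph.splitlines():
        if "aten::add(" in line:
            print(line.strip())
```

If AddScalar really lowers to such a node, the quantizer would see an add that is not wrapped in one of its own modules, which would be consistent with the QUANTIZER_TORCH_NOT_A_MODULE message. This is an interpretation of the traceback, not a confirmed root cause.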