pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Cannot convert to ONNX after QAT #121893

Closed misaka-1 closed 2 months ago

misaka-1 commented 3 months ago

πŸ› Describe the bug

I try to fuse the model before QAT, and some errors occur.

```python
"""Swift Parameter-free Attention Network for Efficient Super-Resolution"""
import torch
import torch.nn as nn


# Conv3XC, SPAB, conv_layer and Conv3XC_QAT come from the SPAN implementation
# and are not shown here.
class SPAN(nn.Module):

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 feature_channels=48,
                 upscale=4,
                 bias=True
                 ):
        super(SPAN, self).__init__()

        self.in_channels = num_in_ch
        self.out_channels = num_out_ch

        self.conv_1 = Conv3XC(self.in_channels, feature_channels, gain1=2, s=1)
        self.block_1 = SPAB(feature_channels, bias=bias)
        self.block_2 = SPAB(feature_channels, bias=bias)
        self.block_3 = SPAB(feature_channels, bias=bias)
        self.block_4 = SPAB(feature_channels, bias=bias)
        self.block_5 = SPAB(feature_channels, bias=bias)
        self.block_6 = SPAB(feature_channels, bias=bias)

        self.conv_cat = conv_layer(feature_channels * 4, feature_channels, kernel_size=1, bias=True)
        self.conv_2 = Conv3XC(feature_channels, feature_channels, gain1=2, s=1)
        self.end = conv_layer(feature_channels, self.out_channels * (upscale ** 2), kernel_size=3)
        self.upsampler = nn.PixelShuffle(upscale)
        # self.upsampler = pixelshuffle_block(feature_channels, self.out_channels, upscale_factor=upscale)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.f_cat = nn.quantized.FloatFunctional()

    def forward(self, x):
        x = self.quant(x)

        out_feature = self.conv_1(x)

        out_b1, out_feature_2 = self.block_1(out_feature)
        out_b2, out_b1_2 = self.block_2(out_b1)
        out_b3, out_b2_2 = self.block_3(out_b2)

        out_b4, out_b3_2 = self.block_4(out_b3)
        out_b5, out_b4_2 = self.block_5(out_b4)
        out_b6, out_b5_2 = self.block_6(out_b5)

        out_b6_2 = self.conv_2(out_b6)
        out = self.conv_cat(self.f_cat.cat([out_feature, out_b6_2, out_b1, out_b5_2], 1))
        out = self.end(out)
        output = self.upsampler(out)
        output = torch.clamp(output, min=0.0, max=255.0)
        output = self.dequant(output)

        return output

    def fuse_model(self):
        # Replace every reparameterizable Conv3XC with a single plain conv
        # (Conv3XC_QAT) holding the collapsed kernel RK and bias RB, so the
        # QAT observers see one conv instead of the multi-branch training block.
        for name, module in self.named_children():
            if isinstance(module, SPAB):
                for n, m in module.named_children():
                    if isinstance(m, Conv3XC):
                        RK, RB = m.rep_param()
                        conv = Conv3XC_QAT(m.eval_conv.in_channels, m.eval_conv.out_channels, m.eval_conv.kernel_size,
                                           m.eval_conv.stride, m.eval_conv.padding)
                        conv.block.weight.data = RK
                        conv.block.bias.data = RB
                        setattr(module, n, conv)
            elif isinstance(module, Conv3XC):
                RK, RB = module.rep_param()
                conv = Conv3XC_QAT(module.eval_conv.in_channels, module.eval_conv.out_channels,
                                   module.eval_conv.kernel_size,
                                   module.eval_conv.stride, module.eval_conv.padding)
                conv.block.weight.data = RK
                conv.block.bias.data = RB
                setattr(self, name, conv)


model = SPAN(1, 1, upscale=2, feature_channels=48)
# model.eval()
inputs = torch.randn(1, 1, 16, 16)

# Eager-mode QAT flow: fuse -> prepare_qat -> run/calibrate -> convert -> export
model.fuse_model()
backend = 'fbgemm'
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(model, inplace=True)

# sr_inputs = (torch.rand(1, 1, 256, 256).cuda(),)
sr = model(inputs)
model.eval()
saved_model = torch.quantization.convert(model, inplace=False)
# int_sr = saved_model(inputs)
# Convert to TorchScript
# scripted_model = torch.jit.script(saved_model)
onnx_filename = "imageEn-SPAN-test.onnx"
dynamic_axes = {'input': {2: 'height', 3: 'width'}, 'output': {2: 'height', 3: 'width'}}
torch.onnx.export(saved_model,
                  inputs,
                  onnx_filename,
                  input_names=['input'],
                  output_names=['output'],
                  verbose=True,
                  do_constant_folding=True,
                  dynamic_axes=dynamic_axes,
                  export_params=True,
                  opset_version=13)
```
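For reference, `Conv3XC_QAT` is not defined in the snippet. From the way `fuse_model()` uses it, it presumably wraps a single plain `Conv2d` named `block` whose weight and bias can be overwritten with the reparameterized `RK` and `RB`; a minimal sketch under that assumption:

```python
import torch.nn as nn

# Hypothetical sketch of Conv3XC_QAT, inferred from its use in fuse_model():
# a thin wrapper around one plain Conv2d named `block`, so that
# conv.block.weight / conv.block.bias can receive the collapsed parameters.
class Conv3XC_QAT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.block = nn.Conv2d(in_channels, out_channels, kernel_size,
                               stride=stride, padding=padding, bias=True)

    def forward(self, x):
        return self.block(x)
```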

In the step `sr = model(inputs)`, the size of x is (1, 1, 16, 16). After converting and exporting to ONNX, the size of x is (tensor(1), tensor(1), tensor(16), tensor(16)).
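One way to narrow this down (a sketch, reusing `saved_model` and `inputs` from the snippet above) is to run the converted INT8 model in eager mode before exporting, which separates `convert()` failures from exporter failures:

```python
# If this already fails or gives wrong shapes, the problem is in the
# quantization convert step rather than in torch.onnx.export.
saved_model.eval()
with torch.no_grad():
    int_sr = saved_model(inputs)
print(int_sr.shape)  # expected torch.Size([1, 1, 32, 32]) for upscale=2
```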

Versions

Python version: Python 3.9.13
`pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118`

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

thiagocrepaldi commented 2 months ago

torch.onnx.export does not support QAT, and there is no plan to support it.

Please try the new ONNX exporter and reopen this issue if it also doesn't work for you: see the quick torch.onnx.dynamo_export API tutorial.
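For reference, a minimal sketch of the suggested dynamo-based exporter (available since torch 2.1), reusing `saved_model` and `inputs` from the snippet above:

```python
import torch

# torch.onnx.dynamo_export returns an export object that can be saved to disk.
export_output = torch.onnx.dynamo_export(saved_model, inputs)
export_output.save("imageEn-SPAN-test.onnx")
```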