hongyuanyu / SPAN

Swift Parameter-free Attention Network for Efficient Super-Resolution

Cannot convert to ONNX after QAT #7

misaka-1 commented 5 months ago

I tried to fuse the model before QAT, and some errors occurred. Here is the code:

```python
import torch
import torch.nn as nn

# ARCH_REGISTRY, SPAB, Conv3XC, Conv3XC_QAT and conv_layer come from the
# repo's architecture utilities.


@ARCH_REGISTRY.register()
class SPAN(nn.Module):
    """Swift Parameter-free Attention Network for Efficient Super-Resolution."""

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 feature_channels=48,
                 upscale=4,
                 bias=True):
        super(SPAN, self).__init__()

        self.in_channels = num_in_ch
        self.out_channels = num_out_ch

        self.conv_1 = Conv3XC(self.in_channels, feature_channels, gain1=2, s=1)
        self.block_1 = SPAB(feature_channels, bias=bias)
        self.block_2 = SPAB(feature_channels, bias=bias)
        self.block_3 = SPAB(feature_channels, bias=bias)
        self.block_4 = SPAB(feature_channels, bias=bias)
        self.block_5 = SPAB(feature_channels, bias=bias)
        self.block_6 = SPAB(feature_channels, bias=bias)

        self.conv_cat = conv_layer(feature_channels * 4, feature_channels, kernel_size=1, bias=True)
        self.conv_2 = Conv3XC(feature_channels, feature_channels, gain1=2, s=1)
        self.end = conv_layer(feature_channels, self.out_channels * (upscale ** 2), kernel_size=3)
        self.upsampler = nn.PixelShuffle(upscale)
        # self.upsampler = pixelshuffle_block(feature_channels, self.out_channels, upscale_factor=upscale)

        # Quantization stubs mark the boundaries of the quantized region;
        # FloatFunctional makes the concatenation quantization-aware.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.f_cat = nn.quantized.FloatFunctional()

    def forward(self, x):
        x = self.quant(x)

        out_feature = self.conv_1(x)

        out_b1, out_feature_2 = self.block_1(out_feature)
        out_b2, out_b1_2 = self.block_2(out_b1)
        out_b3, out_b2_2 = self.block_3(out_b2)

        out_b4, out_b3_2 = self.block_4(out_b3)
        out_b5, out_b4_2 = self.block_5(out_b4)
        out_b6, out_b5_2 = self.block_6(out_b5)

        out_b6_2 = self.conv_2(out_b6)
        out = self.conv_cat(self.f_cat.cat([out_feature, out_b6_2, out_b1, out_b5_2], 1))
        out = self.end(out)
        output = self.upsampler(out)
        output = torch.clamp(output, min=0.0, max=255.0)
        output = self.dequant(output)

        return output

    def fuse_model(self):
        # Replace each re-parameterizable Conv3XC with a single fused conv
        # (Conv3XC_QAT) built from its re-parameterized weights, so QAT
        # observes the weights that will actually be deployed.
        for name, module in self.named_children():
            if isinstance(module, SPAB):
                for n, m in module.named_children():
                    if isinstance(m, Conv3XC):
                        RK, RB = m.rep_param()
                        conv = Conv3XC_QAT(m.eval_conv.in_channels, m.eval_conv.out_channels,
                                           m.eval_conv.kernel_size, m.eval_conv.stride,
                                           m.eval_conv.padding)
                        conv.block.weight.data = RK
                        conv.block.bias.data = RB
                        setattr(module, n, conv)
            elif isinstance(module, Conv3XC):
                RK, RB = module.rep_param()
                conv = Conv3XC_QAT(module.eval_conv.in_channels, module.eval_conv.out_channels,
                                   module.eval_conv.kernel_size, module.eval_conv.stride,
                                   module.eval_conv.padding)
                conv.block.weight.data = RK
                conv.block.bias.data = RB
                setattr(self, name, conv)

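# Driver: build the model, fuse, prepare for QAT, convert, and attempt
# the ONNX export.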
model = SPAN(1, 1, upscale=2, feature_channels=48)
# model.eval()
inputs = torch.randn(1, 1, 16, 16)

model.fuse_model()
backend = 'fbgemm'
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
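# prepare_qat swaps modules for fake-quant-aware versions; it expects the
# model in train mode, which is why model.eval() above is commented out.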
torch.quantization.prepare_qat(model, inplace=True)

# sr_inputs = (torch.rand(1, 1, 256, 256).cuda(),)
sr = model(inputs)
model.eval()
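# convert() replaces the fake-quant modules with eager-mode quantized ops
# (e.g. quantized::conv2d). The ONNX exporter generally has no mapping for
# these ops, which is the usual reason exports fail after QAT.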
saved_model = torch.quantization.convert(model, inplace=False)
# int_sr = saved_model(inputs)
# Convert to TorchScript
# scripted_model = torch.jit.script(saved_model)
onnx_filename = "imageEn-SPAN-test.onnx"
dynamic_axes = {'input': {2: 'height', 3: 'width'}, 'output': {2: 'height', 3: 'width'}}
torch.onnx.export(saved_model,
                  inputs,
                  onnx_filename,
                  input_names=['input'],
                  output_names=['output'],
                  verbose=True, 
                  do_constant_folding=True,
                  dynamic_axes=dynamic_axes,
                  export_params=True,
                  opset_version=13)
```

In the step `sr = model(inputs)`, the size of `x` is `(1, 1, 16, 16)`. After converting and exporting to ONNX, the size of `x` becomes `(tensor(1), tensor(1), tensor(16), tensor(16))`.
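
For reference, the `tensor(1), ..., tensor(16)` sizes are expected: `torch.onnx.export` traces the model, and during tracing the elements of `x.size()` show up as 0-dim tensors rather than Python ints, so that by itself is not the bug. The likelier cause of the export failure is the converted model's eager-mode quantized ops, which the ONNX exporter cannot map to ONNX operators. A common workaround is to export the *prepared* QAT model (still carrying `FakeQuantize` modules) rather than the converted one; at opset 13 the per-channel fake-quant ops used by the default `fbgemm` qconfig export as `QuantizeLinear`/`DequantizeLinear` pairs. A minimal sketch under that assumption (filename illustrative, training loop elided):

```python
import torch

# Sketch: export the QAT-prepared (fake-quant) model instead of the
# convert()-ed one. Assumes SPAN is defined as above.
model = SPAN(1, 1, upscale=2, feature_channels=48)
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# ... QAT fine-tuning would run here ...

model.eval()
model.apply(torch.quantization.disable_observer)  # freeze quantization ranges

torch.onnx.export(model,  # note: the prepared model, not the convert()-ed one
                  torch.randn(1, 1, 16, 16),
                  "imageEn-SPAN-qdq.onnx",  # illustrative filename
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {2: 'height', 3: 'width'},
                                'output': {2: 'height', 3: 'width'}},
                  opset_version=13)
```

The resulting QDQ graph can then be consumed by runtimes such as ONNX Runtime, which fuse the Q/DQ pairs into quantized kernels.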