Open — misaka-1 opened this issue 8 months ago
I tried to fuse the model before QAT, and some errors occurred: `
class SPAN(nn.Module):
    """Swift Parameter-free Attention Network for Efficient Super-Resolution.

    This variant is wrapped in ``QuantStub`` / ``DeQuantStub`` and uses a
    ``FloatFunctional`` for its concatenation so that it can be prepared for
    eager-mode quantization-aware training (QAT).

    Args:
        num_in_ch: Number of channels of the input image (input of ``conv_1``).
        num_out_ch: Number of channels of the output image.
        feature_channels: Width of the intermediate feature maps.
        upscale: Spatial upscaling factor applied by the final ``PixelShuffle``.
        bias: Forwarded to every ``SPAB`` block.
    """

    def __init__(self, num_in_ch, num_out_ch, feature_channels=48, upscale=4, bias=True):
        super(SPAN, self).__init__()
        self.in_channels = num_in_ch
        self.out_channels = num_out_ch

        self.conv_1 = Conv3XC(self.in_channels, feature_channels, gain1=2, s=1)
        self.block_1 = SPAB(feature_channels, bias=bias)
        self.block_2 = SPAB(feature_channels, bias=bias)
        self.block_3 = SPAB(feature_channels, bias=bias)
        self.block_4 = SPAB(feature_channels, bias=bias)
        self.block_5 = SPAB(feature_channels, bias=bias)
        self.block_6 = SPAB(feature_channels, bias=bias)

        # Fuses the four feature maps concatenated in forward() back down to
        # `feature_channels` channels.
        self.conv_cat = conv_layer(feature_channels * 4, feature_channels, kernel_size=1, bias=True)
        self.conv_2 = Conv3XC(feature_channels, feature_channels, gain1=2, s=1)

        # `end` emits out_channels * upscale**2 maps, which PixelShuffle
        # rearranges into `out_channels` maps that are `upscale` times larger.
        self.end = conv_layer(feature_channels, self.out_channels * (upscale ** 2), kernel_size=3)
        self.upsampler = nn.PixelShuffle(upscale)

        # Quantization entry/exit points, plus a FloatFunctional so the
        # concatenation in forward() can be observed/quantized during QAT.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.f_cat = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Super-resolve ``x``.

        The input goes through ``QuantStub`` first and the result through
        ``DeQuantStub`` last. Before dequantization the output is clamped
        to ``[0, 255]``.
        """
        x = self.quant(x)

        out_feature = self.conv_1(x)

        # Each SPAB returns a pair; only block_6's second output is used.
        out_b1, _ = self.block_1(out_feature)
        out_b2, _ = self.block_2(out_b1)
        out_b3, _ = self.block_3(out_b2)
        out_b4, _ = self.block_4(out_b3)
        out_b5, _ = self.block_5(out_b4)
        out_b6, out_b5_2 = self.block_6(out_b5)

        out_b6_2 = self.conv_2(out_b6)

        # Concatenate along dim 1 (4 * feature_channels, matching conv_cat).
        out = self.conv_cat(self.f_cat.cat([out_feature, out_b6_2, out_b1, out_b5_2], 1))
        out = self.end(out)
        output = self.upsampler(out)
        output = torch.clamp(output, min=0.0, max=255.0)
        return self.dequant(output)

    def fuse_model(self):
        """Replace every ``Conv3XC`` with an equivalent single-conv ``Conv3XC_QAT``.

        Two kinds of ``Conv3XC`` modules are replaced:

        * those that are direct children of this model;
        * those that are direct children of an ``SPAB`` child.

        Each one is collapsed with its ``rep_param()`` kernel/bias pair. The
        geometry (in/out channels, kernel size, stride, padding) is copied
        from the module's ``eval_conv``.
        """

        def _rep_to_qat(conv3xc):
            # Collapse the layer's branches into one kernel/bias pair, then
            # load it into a plain conv with the same geometry.
            rep_kernel, rep_bias = conv3xc.rep_param()
            ref = conv3xc.eval_conv
            qat_conv = Conv3XC_QAT(ref.in_channels, ref.out_channels, ref.kernel_size,
                                   ref.stride, ref.padding)
            qat_conv.block.weight.data = rep_kernel
            qat_conv.block.bias.data = rep_bias
            return qat_conv

        # Iterate over snapshots: children are replaced via setattr inside
        # the loops, so avoid walking the live module dict while mutating it.
        for name, module in list(self.named_children()):
            if isinstance(module, SPAB):
                for child_name, child in list(module.named_children()):
                    if isinstance(child, Conv3XC):
                        setattr(module, child_name, _rep_to_qat(child))
            elif isinstance(module, Conv3XC):
                setattr(self, name, _rep_to_qat(module))


model = SPAN(1, 1, upscale=2, feature_channels=48)
inputs = torch.randn(1, 1, 16, 16)

# Re-parameterize the Conv3XC layers before fake-quant modules are inserted.
model.fuse_model()

backend = 'fbgemm'
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(model, inplace=True)

sr = model(inputs)

model.eval()
saved_model = torch.quantization.convert(model, inplace=False)

# NOTE(review): torch.onnx.export traces the model. Under tracing, shape
# values such as x.size() can appear as 0-dim tensors (e.g. tensor(16))
# rather than Python ints. Code in Conv3XC/SPAB (not shown) that uses them
# as ints may misbehave during export — confirm there.
onnx_filename = "imageEn-SPAN-test.onnx"
dynamic_axes = {'input': {2: 'height', 3: 'width'}, 'output': {2: 'height', 3: 'width'}}
torch.onnx.export(
    saved_model,
    inputs,
    onnx_filename,
    input_names=['input'],
    output_names=['output'],
    verbose=True,
    do_constant_folding=True,
    dynamic_axes=dynamic_axes,
    export_params=True,
    opset_version=13,
)
In the step `sr = model(inputs)`, `x.size()` is `(1, 1, 16, 16)`. After converting and exporting to ONNX, `x.size()` is a tuple of tensors: `(tensor(1), tensor(1), tensor(16), tensor(16))`.
`sr = model(inputs)`
I tried to fuse the model before QAT, and some errors occurred: `
# NOTE(review): truncated duplicate of the SPAN class header quoted earlier in
# this issue; its body is omitted here. ARCH_REGISTRY is not defined in this
# file — presumably the project's architecture registry; confirm.
@ARCH_REGISTRY.register()
class SPAN(nn.Module):
    """Swift Parameter-free Attention Network for Efficient Super-Resolution."""
In the step `sr = model(inputs)`, `x.size()` is `(1, 1, 16, 16)`. After converting and exporting to ONNX, `x.size()` is a tuple of tensors: `(tensor(1), tensor(1), tensor(16), tensor(16))`.