microsoft / Bringing-Old-Photos-Back-to-Life

Bringing Old Photos Back to Life (CVPR 2020 oral)
https://arxiv.org/abs/2004.09484
MIT License

Face_Enhancement model converting to onnx failure #204

Open cjenkins5614 opened 3 years ago

cjenkins5614 commented 3 years ago

Hello,

Thanks for the great work. I'm trying to convert this model to ONNX, but have run into a few issues.

The mv and dot operators used by PyTorch's spectral_norm were one of them. Following https://github.com/onnx/onnx/issues/3006#issuecomment-690303884, I converted them to matmul in my own implementation of spectral_norm, and that issue went away.
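Roughly, the substitution looked like this (a simplified sketch rather than my exact code; power_iteration_matmul is just an illustrative name for the power-iteration step inside the reimplemented spectral_norm):

    import torch

    def power_iteration_matmul(weight_mat, u, eps=1e-12):
        # One power-iteration step of spectral normalization, written with
        # matmul instead of the mv/dot ops the ONNX exporter cannot handle.
        # weight_mat: (h, w), u: (h,)
        v = torch.matmul(weight_mat.t(), u.unsqueeze(-1)).squeeze(-1)
        v = v / (v.norm() + eps)
        u = torch.matmul(weight_mat, v.unsqueeze(-1)).squeeze(-1)
        u = u / (u.norm() + eps)
        # sigma = u^T W v, again via matmul rather than dot
        sigma = torch.matmul(u.unsqueeze(0),
                             torch.matmul(weight_mat, v.unsqueeze(-1))).squeeze()
        return weight_mat / sigma, u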

Now it's complaining:

Traceback (most recent call last):
    out = torch.onnx.export(model, input_dict["image"], "model.onnx", verbose=False, opset_version=11,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 271, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 694, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 463, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 206, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 309, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 997, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 148, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 1285, in batch_norm
    if weight is None or sym_help._is_none(weight):
RuntimeError: Unsupported: ONNX export of batch_norm for unknown channel size.

The code to convert this is:

    import torch
    from easydict import EasyDict  # third-party "easydict" package
    # Pix2PixModel lives under Face_Enhancement/models/ in this repo
    from models.pix2pix_model import Pix2PixModel

    opt = EasyDict(aspect_ratio=1.0,
                checkpoints_dir='Face_Enhancement/checkpoints',
                contain_dontcare_label=False,
                crop_size=256,
                gpu_ids=[0],
                init_type='xavier',
                init_variance=0.02,
                injection_layer='all',
                isTrain=False,
                label_nc=18,
                load_size=256,
                model='pix2pix',
                name='Setting_9_epoch_100',
                nef=16,
                netG='spade',
                ngf=64,
                no_flip=True,
                no_instance=True,
                no_parsing_map=True,
                norm_D='spectralinstance',
                norm_E='spectralinstance',
                # norm_G='spectralspadebatch3x3',
                norm_G='spectralspadesyncbatch3x3',
                num_upsampling_layers='normal',
                output_nc=3,
                preprocess_mode='resize',
                semantic_nc=18,
                use_vae=False,
                which_epoch='latest',
                z_dim=256)

    model = Pix2PixModel(opt)
    model.eval()

    input_dict = {
        "label": torch.zeros((1, 18, 256, 256)),
        "image": torch.randn(1, 3, 256, 256),
        "path": None,
    }

    # from torchsummary import summary
    # summary(model, (3, 256, 256))
    # the traceback above was produced with input_dict["image"] as the input
    out = torch.onnx.export(model, input_dict, "model.onnx", verbose=False, opset_version=11,
                            input_names=['input'],
                            output_names=['output'])

I printed out the graph g from https://github.com/pytorch/pytorch/blob/e56d3b023818f54553f2dc5d30b6b7aaf6b6a325/torch/onnx/symbolic_opset9.py#L1337

...
  %450 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 1  1 [ CPULongType{2} ]]()
  %451 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %452 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 0  0 [ CPULongType{2} ]]()
  %453 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %454 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %455 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %456 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %457 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %458 : Float(*, 1024, *, *, strides=[65536, 64, 8, 1], requires_grad=0, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%436, %447, %netG.head_0.conv_1.bias) # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:395:0
  %459 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %460 : Float(*, 1024, *, *, strides=[65536, 64, 8, 1], requires_grad=0, device=cuda:0) = onnx::Add(%266, %458) # /workdir/Face_Enhancement/models/networks/architecture.py:56:0
  %461 : None = prim::Constant()
  %462 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 2  2 [ CPUFloatType{2} ]]()
  %463 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 1  1 [ CPUFloatType{2} ]]()
  %464 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 2  2 [ CPUFloatType{2} ]]()
  %465 : Float(4, strides=[1], device=cpu) = onnx::Concat[axis=0](%463, %464)
  %466 : Float(0, strides=[1], device=cpu) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %467 : Float(*, *, *, *, strides=[262144, 256, 16, 1], requires_grad=0, device=cuda:0) = onnx::Resize[coordinate_transformation_mode="asymmetric", cubic_coeff_a=-0.75, mode="nearest", nearest_mode="floor"](%460, %466, %465) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3535:0
  %468 : None = prim::Constant()
  %469 : None = prim::Constant()
  %470 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %471 : Double(requires_grad=0, device=cpu) = onnx::Constant[value={0.1}]()
  %472 : Double(requires_grad=0, device=cpu) = onnx::Constant[value={1e-05}]()
  %473 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  return ()

ipdb> input
467 defined in (%467 : Float(*, *, *, *, strides=[262144, 256, 16, 1], requires_grad=0, device=cuda:0) = onnx::Resize[coordinate_transformation_mode="asymmetric", cubic_coeff_a=-0.75, mode="nearest", nearest_mode="floor"](%460, %466, %465) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3535:0
)
ipdb> weight
468 defined in (%468 : None = prim::Constant()
)
ipdb> bias
469 defined in (%469 : None = prim::Constant()
)

Float(*, *, *, *) stood out to me, but I'm not sure how to interpret it. I removed all spectral_norm calls and also tried changing synced batch norm to plain batch norm, but the issue still persists.
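If it helps anyone, the * entries seem to be dimensions that shape inference could not determine, which matches the "unknown channel size" in the error. I believe the failure can be isolated outside this repo with something like the following sketch (Repro is just an illustrative name; this assumes the same PyTorch 1.8-era exporter as in the traceback above):

    import torch
    import torch.nn as nn

    class Repro(nn.Module):
        # Upsample feeding directly into a norm layer: the exporter's shape
        # inference loses the channel dimension through onnx::Resize, so the
        # batch_norm symbolic sees an unknown channel size.
        def __init__(self):
            super().__init__()
            self.up = nn.Upsample(scale_factor=2, mode="nearest")
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.up(x))

    torch.onnx.export(Repro().eval(), torch.randn(1, 8, 16, 16),
                      "repro.onnx", opset_version=11)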

RostyslavBryiovskyi commented 2 years ago

Hello @cjenkins5614. I have the same problem. Did you fix it? Update: fixed it.

cjenkins5614 commented 2 years ago

How did it work out in your case, @RostyslavBryiovskyi?

I haven't figured it out yet.

zhangmozhe commented 2 years ago

This is a bug in the ONNX exporter: an Upsample layer cannot be placed directly in front of a BN/IN layer (see https://github.com/pytorch/pytorch/issues/69346). Try inserting a conv2d layer with an identity kernel as a workaround.
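A rough sketch of what I mean, assuming a fixed channel count (IdentityConv2d is just an illustrative name, not code from this repo):

    import torch
    import torch.nn as nn

    class IdentityConv2d(nn.Module):
        # 1x1 convolution whose kernel is the per-channel identity: the output
        # equals the input numerically, but the Conv node carries a concrete
        # channel count that shape inference can propagate to the BN/IN.
        def __init__(self, channels):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
            with torch.no_grad():
                self.conv.weight.copy_(
                    torch.eye(channels).view(channels, channels, 1, 1))
            self.conv.weight.requires_grad_(False)

        def forward(self, x):
            return self.conv(x)

Insert it between the upsampling op and the norm layer. Note that a plain nn.Identity will not work, since tracing removes it entirely; the point is to put a real Conv op with a known channel count into the graph.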

DonggeunYu commented 2 years ago

@zhangmozhe I'm not sure how to apply the identity kernel. I tried using torch.nn.Identity and torch.nn.Conv2d, but ONNX ignores those layers; as before, the channel size is *.

yuananf commented 2 years ago

I've got the same problem here. Can anyone show a clear example of how to export the model to ONNX successfully?

July250229 commented 2 years ago

I ran into the same problem. Can anyone post a solution?

tongchangD commented 2 years ago

me too