NVlabs / SPADE

Semantic Image Synthesis with SPADE
https://nvlabs.github.io/SPADE/

Converting to ONNX #154

Open cjenkins5614 opened 3 years ago

cjenkins5614 commented 3 years ago

Hello,

Thanks for the great work. I'm trying to convert this model to ONNX, but have run into a few issues.

The mv and dot operators used by PyTorch's spectral_norm were one of them. Following https://github.com/onnx/onnx/issues/3006#issuecomment-690303884, I converted them to matmul in my own implementation of spectral_norm and that issue went away.
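For anyone hitting the same thing, here is a minimal sketch of that substitution (the function name and structure are mine, not PyTorch's; the actual change goes into a copy of the power iteration inside torch.nn.utils.spectral_norm):

    import torch
    from torch.nn.functional import normalize

    def power_iteration_matmul(weight_mat, u, n_power_iterations=1, eps=1e-12):
        # Same math as spectral_norm's power iteration, but mv/dot are
        # rewritten as matmul so the ONNX exporter never sees those ops.
        for _ in range(n_power_iterations):
            # torch.mv(W.t(), u) -> matmul against an explicit column vector
            v = normalize(torch.matmul(weight_mat.t(), u.unsqueeze(1)).squeeze(1),
                          dim=0, eps=eps)
            u = normalize(torch.matmul(weight_mat, v.unsqueeze(1)).squeeze(1),
                          dim=0, eps=eps)
        # torch.dot(u, W @ v) -> 1x1 matmul, then squeeze to a scalar
        sigma = torch.matmul(u.unsqueeze(0),
                             torch.matmul(weight_mat, v.unsqueeze(1))).squeeze()
        return u, v, sigma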

Now it's complaining:

Traceback (most recent call last):
    out = torch.onnx.export(model, input_dict["image"], "model.onnx", verbose=False, opset_version=11,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 271, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 694, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 463, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 206, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 309, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 997, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 148, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 1285, in batch_norm
    if weight is None or sym_help._is_none(weight):
RuntimeError: Unsupported: ONNX export of batch_norm for unknown channel size.

The code to convert this is:

    import torch
    from easydict import EasyDict
    # Pix2PixModel lives under the Face_Enhancement module this config
    # points at; the exact import path depends on your checkout.
    from models.pix2pix_model import Pix2PixModel

    opt = EasyDict(aspect_ratio=1.0,
                checkpoints_dir='Face_Enhancement/checkpoints',
                contain_dontcare_label=False,
                crop_size=256,
                gpu_ids=[0],
                init_type='xavier',
                init_variance=0.02,
                injection_layer='all',
                isTrain=False,
                label_nc=18,
                load_size=256,
                model='pix2pix',
                name='Setting_9_epoch_100',
                nef=16,
                netG='spade',
                ngf=64,
                no_flip=True,
                no_instance=True,
                no_parsing_map=True,
                norm_D='spectralinstance',
                norm_E='spectralinstance',
                # norm_G='spectralspadebatch3x3',
                norm_G='spectralspadesyncbatch3x3',
                num_upsampling_layers='normal',
                output_nc=3,
                preprocess_mode='resize',
                semantic_nc=18,
                use_vae=False,
                which_epoch='latest',
                z_dim=256)

    model = Pix2PixModel(opt)
    model.eval()

    input_dict = {
        "label": torch.zeros((1, 18, 256, 256)),
        "image": torch.randn(1, 3, 256, 256),
        "path": None,
    }

    # from torchsummary import summary
    # summary(model, (3, 256, 256))
    out = torch.onnx.export(model, input_dict, "model.onnx", verbose=False, opset_version=11,
                      input_names = ['input'],
                      output_names = ['output'])

I printed out the graph g from inside https://github.com/pytorch/pytorch/blob/e56d3b023818f54553f2dc5d30b6b7aaf6b6a325/torch/onnx/symbolic_opset9.py#L1337:

...
  %450 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 1  1 [ CPULongType{2} ]]()
  %451 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %452 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 0  0 [ CPULongType{2} ]]()
  %453 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %454 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %455 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %456 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %457 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %458 : Float(*, 1024, *, *, strides=[65536, 64, 8, 1], requires_grad=0, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%436, %447, %netG.head_0.conv_1.bias) # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:395:0
  %459 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %460 : Float(*, 1024, *, *, strides=[65536, 64, 8, 1], requires_grad=0, device=cuda:0) = onnx::Add(%266, %458) # /workdir/Face_Enhancement/models/networks/architecture.py:56:0
  %461 : None = prim::Constant()
  %462 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 2  2 [ CPUFloatType{2} ]]()
  %463 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 1  1 [ CPUFloatType{2} ]]()
  %464 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 2  2 [ CPUFloatType{2} ]]()
  %465 : Float(4, strides=[1], device=cpu) = onnx::Concat[axis=0](%463, %464)
  %466 : Float(0, strides=[1], device=cpu) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %467 : Float(*, *, *, *, strides=[262144, 256, 16, 1], requires_grad=0, device=cuda:0) = onnx::Resize[coordinate_transformation_mode="asymmetric", cubic_coeff_a=-0.75, mode="nearest", nearest_mode="floor"](%460, %466, %465) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3535:0
  %468 : None = prim::Constant()
  %469 : None = prim::Constant()
  %470 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %471 : Double(requires_grad=0, device=cpu) = onnx::Constant[value={0.1}]()
  %472 : Double(requires_grad=0, device=cpu) = onnx::Constant[value={1e-05}]()
  %473 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  return ()

ipdb> input
467 defined in (%467 : Float(*, *, *, *, strides=[262144, 256, 16, 1], requires_grad=0, device=cuda:0) = onnx::Resize[coordinate_transformation_mode="asymmetric", cubic_coeff_a=-0.75, mode="nearest", nearest_mode="floor"](%460, %466, %465) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3535:0
)
ipdb> weight
468 defined in (%468 : None = prim::Constant()
)
ipdb> bias
469 defined in (%469 : None = prim::Constant()
)

Float(*, *, *, *) stood out to me, but I'm not sure how to interpret it.
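As far as I can tell, each * is a dimension the exporter treats as dynamic: after the onnx::Resize above, even the channel dimension is unknown, which is exactly what symbolic_opset9.batch_norm trips over. A standalone sketch that should reproduce the same error on this torch version (my own reduction, not code from this repo):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Repro(nn.Module):
        # Mirrors the failing pattern in the dump: nearest upsample
        # (onnx::Resize) feeding a parameter-free batch norm, so weight
        # and bias are None like %468/%469 above.
        def __init__(self, c=1024):
            super().__init__()
            self.bn = nn.BatchNorm2d(c, affine=False)

        def forward(self, x):
            return self.bn(F.interpolate(x, scale_factor=2, mode="nearest"))

    # Resize makes every output dimension dynamic, so batch_norm cannot
    # read the channel count and export fails with the same RuntimeError.
    torch.onnx.export(Repro().eval(), torch.randn(1, 1024, 16, 16),
                      "repro.onnx", opset_version=11)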

ymzlygw commented 2 years ago

Hi, did you find any solution yet?

eaidova commented 2 years ago

I'm not sure this solution is right, but I downgraded torch to 1.7 and the model then converted to ONNX. It looks like a bug on the torch-to-ONNX conversion side: in newer versions, upsample produces dynamic shapes, which leads to the error in batch_norm.
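If downgrading is not an option, another idea (untested against this model, so only a sketch): give the generator's upsampling an explicit output size instead of a scale factor, so the traced onnx::Resize gets constant sizes and batch_norm sees a known channel count.

    import torch.nn as nn
    import torch.nn.functional as F

    class StaticUpsample(nn.Module):
        # Hypothetical drop-in for the nn.Upsample(scale_factor=2) used in
        # the generator. Under tracing, x.shape[2] / x.shape[3] are plain
        # Python ints, so the output size is baked in as a constant and
        # Resize no longer produces fully dynamic shapes.
        def forward(self, x):
            h, w = x.shape[2], x.shape[3]
            return F.interpolate(x, size=(2 * h, 2 * w), mode="nearest")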