TencentARC / GFPGAN

GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration.

Converting GFPGANv1.3.pth to ONNX format #480

Open · prashant-saxena opened this issue 8 months ago

prashant-saxena commented 8 months ago

Environment: Windows 10, Python 3.10.0, torch 2.1.2, gfpgan 1.3.8, onnx 1.14.0, onnxruntime 1.15.0

import torch
import torch.onnx

import onnx
import onnxruntime

from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

onnx_path = "models/GFPGANv1.3.onnx"
gfpgan_path = "models/GFPGANv1.3.pth"

torch_model = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=2,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)

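# Load pretrained model weights (map_location lets this also run without a GPU)
loadnet = torch.load(gfpgan_path, map_location='cpu')

# GFPGAN checkpoints keep the weights under 'params_ema' (or 'params'), the same
# lookup GFPGANer performs; strict=True makes a key mismatch fail instead of being skipped
keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
torch_model.load_state_dict(loadnet[keyname], strict=True)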

torch_model.to(device)

# Set the model to evaluation mode
torch_model.eval()

# Input to the model
inputs = torch.ones((1, 3, 512, 512)).to(device)

# Export model in onnx format
torch.onnx.export(torch_model,               # model being run
                  inputs,                    # model input (or a tuple for multiple inputs)
                  onnx_path,                 # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  verbose=False,
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  opset_version=12           # the ONNX version to export the model to
                  )

# Verify your ONNX model
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

Errors:

D:\roop\.venv\lib\site-packages\torchvision\transforms\functional_tensor.py:5:
UserWarning: The torchvision.transforms.functional_tensor module is deprecated
in 0.15 and will be **removed in 0.17**. Please don't rely on it. You probably
just need to use APIs in torchvision.transforms.functional or in
torchvision.transforms.v2.functional. warnings.warn( 

D:\roop\.venv\lib\site-packages\gfpgan\archs\gfpganv1_clean_arch.py:314:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to
be incorrect. We can't record the data flow of Python values, so this value
will be treated as a constant in the future. This means that the trace might
not generalize to other inputs! if return_rgb:

D:\roop\.venv\lib\site-packages\gfpgan\archs\gfpganv1_clean_arch.py:62:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to
be incorrect. We can't record the data flow of Python values, so this value
will be treated as a constant in the future. This means that the trace might
not generalize to other inputs! if randomize_noise:

D:\roop\.venv\lib\site-packages\gfpgan\archs\gfpganv1_clean_arch.py:102:
TracerWarning: Converting a tensor to a Python integer might cause the trace to
be incorrect. We can't record the data flow of Python values, so this value
will be treated as a constant in the future. This means that the trace might
not generalize to other inputs! out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)

D:\roop\.venv\lib\site-packages\gfpgan\archs\gfpganv1_clean_arch.py:114:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to
be incorrect. We can't record the data flow of Python values, so this value
will be treated as a constant in the future. This means that the trace might
not generalize to other inputs! if return_latents:

D:\roop\.venv\lib\site-packages\torch\onnx\utils.py:1686: UserWarning: The
exported ONNX model failed ONNX shape inference. The model will not be
executable by the ONNX Runtime. If this is unintended and you believe there is
a bug, please report an issue at https://github.com/pytorch/pytorch/issues.
Error reported by strict ONNX shape inference: [ShapeInferenceError]
(op_type:Conv, node name: /stylegan_decoder/to_rgb1/modulated_conv/Conv): W has
inconsistent type tensor(float) (Triggered internally
at ..\torch\csrc\jit\serialization\export.cpp:1421.)
  _C._check_onnx_proto(proto)

The GFPGANv1.3.onnx file does get written, but I'm not sure what these errors mean. I only recently started diving into model writing and training. I suspect the initial warnings about data types come from gfpgan 1.3.8: it was last released on Sep 16, 2022, so there may be some inconsistencies with torch 2.1.2 (released Dec 15, 2023).

When I try to create an ONNX Runtime inference session using:

# compute the output using ONNX Runtime’s Python APIs
ort_session = onnxruntime.InferenceSession("models/GFPGANv1.3.onnx", providers=["CPUExecutionProvider"])

I again get an error, the same problem the export step had already reported:

Traceback (most recent call last):
  File "D:\roop\convert.py", line 56, in <module>
    ort_session = onnxruntime.InferenceSession("models/GFPGANv1.3.onnx", providers=["CPUExecutionProvider"])
  File "D:\roop\.venv\lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 383, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "D:\roop\.venv\lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 424, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from models/GFPGANv1.3.onnx failed:Type Error: Type parameter (T) of Optype (Conv) bound to different types (tensor(double) and tensor(float) in node (/stylegan_decoder/to_rgb1/modulated_conv/Conv).

I would appreciate it if someone could point out a solution for these errors and, more importantly, shed some light on this conversion process. I was under the impression that this conversion would be straightforward and could be done with some external script, but I believe I was wrong.
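In the ONNX operator schema for Conv, the inputs X, W and the optional bias B all share one type parameter, T, so they must carry the same element type. Both the strict shape-inference warning and the ONNX Runtime failure report that this rule is broken in /stylegan_decoder/to_rgb1/modulated_conv/Conv: one input is tensor(double) while another is tensor(float). The plain-Python toy below mimics that check to show what is being enforced. It is an illustration, not ONNX Runtime's implementation, and the per-input types in `reported` are an assumption read off the messages: since shape inference names W as the inconsistent float input, it seems likely (but is not confirmed here) that the double arrives on X, i.e. from the ops that feed that Conv.

```python
from typing import Dict, List, Optional

# Toy model of an ONNX type constraint: all inputs that share a type parameter
# must have the same element type. For Conv, "T" covers X, W and optional B.
CONV_CONSTRAINTS: Dict[str, List[str]] = {"T": ["X", "W", "B"]}


def find_type_conflicts(
    input_types: Dict[str, Optional[str]],
    constraints: Dict[str, List[str]],
) -> Dict[str, List[str]]:
    """Return {type_param: sorted distinct types} for each parameter bound to more than one type.

    Inputs that are missing or set to None (an absent optional input) are skipped.
    """
    conflicts: Dict[str, List[str]] = {}
    for param, input_names in constraints.items():
        bound = {input_types[n] for n in input_names if input_types.get(n) is not None}
        if len(bound) > 1:
            conflicts[param] = sorted(bound)
    return conflicts


# Assumed input types for /stylegan_decoder/to_rgb1/modulated_conv/Conv,
# inferred from the error messages (optional B left out).
reported = {"X": "tensor(double)", "W": "tensor(float)"}
print(find_type_conflicts(reported, CONV_CONSTRAINTS))
# prints {'T': ['tensor(double)', 'tensor(float)']}
```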

ystoneman commented 8 months ago

If you're using a GPU, could you please also include the version of CUDA you're using? print(torch.version.cuda) will give that to you.
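A small sketch for gathering the requested details in one go: it reads installed package versions with the standard-library `importlib.metadata` and, if PyTorch is importable, reports `torch.version.cuda` (which is `None` on CPU-only builds). The package list mirrors the environment line at the top of the issue plus torchvision from the log; `onnxruntime-gpu` is added as an assumption, since the GPU build of ONNX Runtime is published under that separate distribution name. The helper names are hypothetical.

```python
import importlib.metadata
import platform
import sys
from typing import Dict, Iterable, Optional

# Distributions from the issue's environment line and log, plus the separately
# named GPU build of ONNX Runtime (an assumption about what might be installed).
PACKAGES = ("torch", "torchvision", "gfpgan", "onnx", "onnxruntime", "onnxruntime-gpu")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def torch_cuda_version() -> Optional[str]:
    """CUDA version PyTorch was built against; None for CPU-only builds or if torch is missing."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda


if __name__ == "__main__":
    print(platform.platform(), "| Python", sys.version.split()[0])
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
    print("torch.version.cuda:", torch_cuda_version())
```

Pasting that output into the issue answers both the environment and CUDA questions at once.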

Also, have you posted your question on any ONNX or PyTorch forums as well? If so, could you include links to those posts here too please?