PINTO0309 / PINTO_model_zoo

A repository for storing models that have been inter-converted between various frameworks. Supported frameworks are TensorFlow, PyTorch, ONNX, OpenVINO, TFJS, TFTRT, TensorFlowLite (Float32/16/INT8), EdgeTPU, CoreML.
https://qiita.com/PINTO
MIT License

HiFill inpainting conversion to CoreML #305

Closed: DanielZanchi closed this issue 1 year ago

DanielZanchi commented 2 years ago

Issue Type

Support

OS

Mac OS

OS architecture

aarch64

Programming Language

Python

Framework

CoreML

Model name and Weights/Checkpoints URL

https://github.com/PINTO0309/PINTO_model_zoo/tree/main/100_HiFill

Description

I am trying to convert the TensorFlow model to Core ML so that I can use it on iOS devices.

Is this code correct for the conversion?

import tensorflow as tf
import coremltools as ct

path = "hifill.pb"

def wrap_frozen_graph(graph_def, inputs, outputs):
    # Import the frozen GraphDef into a TF1-style wrapped function, then prune
    # it to the requested input/output tensors (raises if a name is wrong).
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")
    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

def load_frozen_graph(model_path):
    # Read the serialized frozen graph (.pb) into a GraphDef.
    graph_def = tf.compat.v1.GraphDef()
    with open(model_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    # Sanity check: the input/output tensor names must exist in the graph.
    wrap_frozen_graph(
        graph_def, inputs=["img:0", "mask:0"],
        outputs=["inpainted:0", "attention:0", "mask_processed:0"])
    return graph_def

graph_def = load_frozen_graph(path)
mlmodel = ct.convert(graph_def,
                     source="tensorflow",
                     inputs=[ct.ImageType(name="img"), ct.ImageType(name="mask")])

mlmodel.save("hifill.mlmodel")
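For reference, Core ML transforms each channel of an `ImageType` input as `scale * pixel + bias` before it reaches the graph, controlled by the `scale`, `bias` and `color_layout` arguments of `ct.ImageType`. The defaults are `scale=1.0`, no bias, and RGB channel order, so the code above feeds raw 0-255 RGB pixels to both inputs. The NumPy sketch below reproduces that arithmetic to show how the defaults could distort what the graph receives. It assumes, hypothetically, that `img` expects BGR values in 0-255 and that `mask` expects a single channel in 0-1. Check these ranges and the channel order against the graph; this issue does not state them.

```python
import numpy as np

def coreml_image_preprocess(pixels, scale=1.0, bias=0.0):
    """Mimic Core ML ImageType preprocessing: y = scale * x + bias.

    pixels: uint8 array of shape (H, W, C) as delivered from the image.
    bias may be a scalar or a per-channel sequence of length C.
    """
    x = pixels.astype(np.float32)
    return scale * x + np.asarray(bias, dtype=np.float32)

# Hypothetical 2x2 grayscale mask image (255 = region to inpaint).
mask_pixels = np.array([[[0], [255]], [[255], [0]]], dtype=np.uint8)

# With the default scale, the graph would see 0..255 values.
default_mask = coreml_image_preprocess(mask_pixels)
# With scale=1/255 the values fall in 0..1 (the assumed range for "mask").
scaled_mask = coreml_image_preprocess(mask_pixels, scale=1 / 255.0)
print(default_mask.max(), scaled_mask.max())

# Channel order: if "img" is assumed to be BGR, an RGB-ordered pixel
# arrives with its first and last channels swapped.
rgb = np.array([[[200, 100, 50]]], dtype=np.uint8)
bgr = rgb[..., ::-1]
print(bgr[0, 0].tolist())
```

Under the same assumptions, the conversion inputs would look something like `ct.ImageType(name="mask", scale=1 / 255.0, color_layout=ct.colorlayout.GRAYSCALE)` and `ct.ImageType(name="img", color_layout=ct.colorlayout.BGR)`.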

The conversion succeeds and I get an .mlmodel, but the inpainted output looks wrong :(
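One way to locate the problem is to feed identical inputs to the original TensorFlow graph and to the converted model, then compare the two outputs numerically. If they match, the issue is on the iOS side, such as image preprocessing or postprocessing. If they differ, the issue is in the conversion. Below is a minimal comparison helper. It assumes both outputs are already NumPy arrays of the same shape, for example the TF result for `inpainted:0` and the matching output from `mlmodel.predict(...)` (macOS only). The name `compare_outputs` is hypothetical.

```python
import numpy as np

def compare_outputs(reference, candidate, data_range=255.0):
    """Compare two model outputs of identical shape.

    Returns (max_abs_error, psnr_db). data_range is the assumed value range
    of the outputs (255 for 8-bit-style images).
    """
    ref = np.asarray(reference, dtype=np.float64)
    cand = np.asarray(candidate, dtype=np.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    diff = ref - cand
    max_abs = float(np.max(np.abs(diff)))
    mse = float(np.mean(diff ** 2))
    # Peak signal-to-noise ratio; infinite when the outputs are identical.
    psnr = float("inf") if mse == 0 else float(10.0 * np.log10(data_range ** 2 / mse))
    return max_abs, psnr
```

A large maximum error combined with a low PSNR usually points to a preprocessing mismatch (value range or channel order) rather than a small numerical drift.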

If it helps, I can share the Xcode project I am testing with :)

Can anyone help with this?

[Screenshot attached to the original issue]

Relevant Log Output

No response

URL or source code for simple inference testing code

No response