sirius-ai / MobileFaceNet_TF

Tensorflow implementation for MobileFaceNet
Apache License 2.0
434 stars 170 forks source link

Trying to read ckpt but get ValueError: #78

Closed neesetifa closed 3 years ago

neesetifa commented 3 years ago

I'm just trying to run a simple forward inference. All the code is copied and pasted from test_nets.py; nothing has changed except that I added compat.v1 because I'm using TF 2.x.

def get_model_filenames(model_dir):
    """Locate the metagraph and checkpoint file names inside ``model_dir``.

    The directory must contain exactly one ``*.meta`` file. The checkpoint
    name is taken from the checkpoint state reported by
    ``tf.compat.v1.train.get_checkpoint_state`` when available; otherwise the
    directory is scanned for files named ``model-<name>.ckpt-<step>`` and the
    one with the highest step is chosen.

    Args:
        model_dir: Path of the directory holding the model files.

    Returns:
        A ``(meta_file, ckpt_file)`` tuple of file names relative to
        ``model_dir``.

    Raises:
        ValueError: If the directory contains no meta file, more than one
            meta file, or no checkpoint file can be determined.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if not meta_files:
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    if len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]

    # Preferred source: the checkpoint state recorded for the directory.
    ckpt = tf.compat.v1.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        return meta_file, os.path.basename(ckpt.model_checkpoint_path)

    # Fallback: pick the "model-*.ckpt-<step>" file with the largest step.
    ckpt_file = None
    max_step = -1
    for f in files:
        # The dot before "ckpt" is escaped so it only matches a literal ".".
        step_str = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
        if step_str is not None and len(step_str.groups()) >= 2:
            step = int(step_str.groups()[1])
            if step > max_step:
                max_step = step
                ckpt_file = step_str.groups()[0]

    # Without this guard a directory lacking any matching checkpoint would
    # surface as an UnboundLocalError on the return statement.
    if ckpt_file is None:
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file

def load_model(model):
    """Load a model into the current default graph (and session).

    ``model`` is expanded with ``os.path.expanduser``. If it names a file,
    the file is parsed as a serialized ``GraphDef`` and imported into the
    default graph. Otherwise it is treated as a directory containing a
    metagraph and a checkpoint: the metagraph is imported and the checkpoint
    is restored into the current default session.

    Args:
        model: Path to a serialized GraphDef file, or to a directory with a
            ``.meta`` file and a checkpoint.

    Raises:
        RuntimeError: If a checkpoint directory is given but no default
            TensorFlow session is active to restore the variables into.
        ValueError: Propagated from ``get_model_filenames`` when the
            directory does not contain the expected files.
    """
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        # GFile is used instead of the deprecated FastGFile alias.
        with tf.compat.v1.gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.compat.v1.import_graph_def(graph_def, name='')
    else:
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        # get_default_session() returns None outside a session context;
        # fail early with a clear message rather than passing None to restore.
        sess = tf.compat.v1.get_default_session()
        if sess is None:
            raise RuntimeError(
                'load_model() must be called inside an active tf.compat.v1.Session '
                'context when loading from a checkpoint directory (%s)' % model_exp)

        saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file))
        saver.restore(sess, os.path.join(model_exp, ckpt_file))

def main():
    """Run one forward pass of the pretrained model on a single image.

    The model is loaded inside the graph and session that are later used for
    inference, so the restored variables and the ``input:0`` /
    ``embeddings:0`` tensors all belong to the same graph.

    Returns:
        The value produced by evaluating the ``embeddings:0`` tensor for the
        loaded image.
    """
    image = load_image('xxx.jpg')

    with tf.compat.v1.Graph().as_default():
        with tf.compat.v1.Session() as sess:
            # Load inside the session: load_model restores the checkpoint into
            # the default session and imports nodes into the default graph.
            # All the pretrained files from arch/pretrained_model/ are put in
            # the same directory.
            load_model('./')

            graph = tf.compat.v1.get_default_graph()
            inputs_placeholder = graph.get_tensor_by_name("input:0")
            embeddings = graph.get_tensor_by_name("embeddings:0")
            feed_dict = {inputs_placeholder: image}
            output = sess.run(embeddings, feed_dict=feed_dict)
            return output

But I get the following error:

Model directory: ./
Metagraph file: MobileFaceNet_TF.ckpt.meta
Checkpoint file: MobileFaceNet_TF.ckpt
Traceback (most recent call last):
  File "aa.py", line 79, in <module>
    main()
  File "aa.py", line 67, in main
    load_model('./')
  File "aa.py", line 61, in load_model
    saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file))
  File "/usr/lib/python3.8/site-packages/tensorflow/python/training/saver.py", line 1460, in import_meta_graph
    return _import_meta_graph_with_return_elements(meta_graph_or_file,
  File "/usr/lib/python3.8/site-packages/tensorflow/python/training/saver.py", line 1481, in _import_meta_graph_with_return_elements
    meta_graph.import_scoped_meta_graph_with_return_elements(
  File "/usr/lib/python3.8/site-packages/tensorflow/python/framework/meta_graph.py", line 794, in import_scoped_meta_graph_with_return_elements
    imported_return_elements = importer.import_graph_def(
  File "/usr/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/lib/python3.8/site-packages/tensorflow/python/framework/importer.py", line 400, in import_graph_def
    return _import_graph_def_internal(
  File "/usr/lib/python3.8/site-packages/tensorflow/python/framework/importer.py", line 501, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: Node 'gradients/MobileFaceNet/Logits/LinearConv1x1/BatchNorm/cond/FusedBatchNorm_1_grad/FusedBatchNormGrad' has an _output_shapes attribute inconsistent with the GraphDef for output #3: Dimension 0 in both shapes must be equal, but are 0 and 192. Shapes are [0] and [192].

Any help or hints?

wei8171023 commented 2 years ago

把 --embedding_size 的默认值设为 192。(Set the default value of --embedding_size to 192.)