baiwenjia / ukbb_cardiac

Some code for processing and analysing UK Biobank cardiac MR images.
Apache License 2.0

Loading the graph model #11


ghost commented 2 years ago

Hi, whenever I try to load the model, I get the following error. Can you please help me resolve this? Thank you.

ValueError: Node 'gradients/UNet/conv0_up/batch_normalization_3/FusedBatchNorm_grad/FusedBatchNormGrad' has an _output_shapes attribute inconsistent with the GraphDef for output #3: Dimension 0 in both shapes must be equal, but are 0 and 16. Shapes are [0] and [16]
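The message appears to come from TensorFlow's graph importer, which compares the shape recorded in a node's `_output_shapes` attribute with the shape it infers for that output. Here output #3 of the `FusedBatchNormGrad` node disagrees ([0] vs [16]). A likely, but unconfirmed, cause is a `.meta` file exported with a different TensorFlow version than the one loading it. The sketch below is a simplified pure-Python model of that dimension-by-dimension check, not TensorFlow's actual code; `merge_shapes` is a hypothetical name, and `None` stands for an unknown dimension.

```python
def merge_shapes(recorded, inferred):
    """Simplified model of a shape-compatibility merge.

    Two shapes are compatible if they have the same rank and every pair of
    dimensions is equal or at least one of them is unknown (None).
    Returns the merged (more specific) shape, or raises ValueError.
    """
    if len(recorded) != len(inferred):
        raise ValueError(
            f"Shapes must be equal rank, but are {len(recorded)} and {len(inferred)}"
        )
    merged = []
    for i, (a, b) in enumerate(zip(recorded, inferred)):
        if a is None:
            merged.append(b)  # unknown dimension takes the other value
        elif b is None or a == b:
            merged.append(a)  # compatible: keep the known value
        else:
            raise ValueError(
                f"Dimension {i} in both shapes must be equal, but are {a} and {b}. "
                f"Shapes are {list(recorded)} and {list(inferred)}"
            )
    return merged


print(merge_shapes([None], [16]))  # [16]
try:
    merge_shapes([0], [16])
except ValueError as err:
    print(err)  # Dimension 0 in both shapes must be equal, but are 0 and 16. Shapes are [0] and [16]
```

Under this model, a recorded `_output_shapes` entry that disagrees with the inferred shape can never be merged, which is why removing the attribute (rather than changing the model weights) is enough to get past the error.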

tramanhphambme commented 2 years ago

Hi there. I recently tried the code and got the same error. Could you tell me whether you have fixed it yet? Thank you!

iimog commented 2 years ago

A similar problem was reported elsewhere, and the fix suggested there works here as well.

```diff
@@ -50,11 +50,17 @@ tf.app.flags.DEFINE_float('weight_r', 0.1,

 if __name__ == '__main__':
+    gd = tf.MetaGraphDef()
+    with open('{0}.meta'.format(FLAGS.model_path), "rb") as f:
+        gd.ParseFromString(f.read())
+    for node in gd.graph_def.node:
+        if '_output_shapes' in node.attr:
+            del node.attr['_output_shapes']
     with tf.Session() as sess:
         sess.run(tf.global_variables_initializer())

         # Import the computation graph and restore the variable values
-        saver = tf.train.import_meta_graph('{0}.meta'.format(FLAGS.model_path))
+        saver = tf.train.import_meta_graph(gd)
         saver.restore(sess, '{0}'.format(FLAGS.model_path))

         print('Start evaluating on the test set ...')
```

With these changes to deploy_network_ao.py and deploy_network.py, it works for me.
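Since the same stripping loop is needed in both scripts, one option is to factor it into a small shared helper. The sketch below is hypothetical (the name `strip_output_shapes` is not part of the repository) and is duck-typed: it only assumes an object with a `.node` sequence whose items have a dict-like `.attr`, which is what the patch above already relies on for the protobuf `GraphDef`. The demo uses a plain `SimpleNamespace` stand-in, not a real graph.

```python
from types import SimpleNamespace


def strip_output_shapes(graph_def):
    """Delete the '_output_shapes' attribute from every node of graph_def.

    Duck-typed: graph_def only needs a `.node` sequence whose items have a
    dict-like `.attr` mapping (membership test and `del` by key).
    Returns the number of nodes that were modified.
    """
    removed = 0
    for node in graph_def.node:
        if '_output_shapes' in node.attr:
            del node.attr['_output_shapes']
            removed += 1
    return removed


# Stand-in for a GraphDef (hypothetical data, for illustration only).
demo = SimpleNamespace(node=[
    SimpleNamespace(name='FusedBatchNormGrad', attr={'_output_shapes': '...', 'T': 'float'}),
    SimpleNamespace(name='Conv2D', attr={'T': 'float'}),
])
print(strip_output_shapes(demo))  # 1
print(demo.node[0].attr)          # {'T': 'float'}
```

In the patched scripts, this would be called as `strip_output_shapes(gd.graph_def)` right after `gd.ParseFromString(f.read())` and before `tf.train.import_meta_graph(gd)`, keeping the rest of the loading code unchanged.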