rpautrat / SuperPoint

Efficient neural feature detector and descriptor
MIT License

Does the exported .pb file support dynamic batch prediction? #240

Closed tangchen2 closed 2 years ago

tangchen2 commented 2 years ago

Hi, first of all thanks for the awesome work. After fine-tuning with your code, I used export_model.py to convert the checkpoint to a .pb file, but the exported model does not seem to support batch prediction. The error is "Invalid argument: Input shape axis 0 must equal 1, got shape [2, 240, 320, 1]". It appears to be related to pred_batch_size in the config, but I would like to export a .pb file with a dynamic batch size. Could you give me some tips? Thanks a lot!

Below is part of my prediction code:

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], str(tf_path))
    print("SuperPoint loaded from {}".format(tf_path))

    input_img_tensor = graph.get_tensor_by_name('superpoint/image:0')
    output_prob_nms_tensor = graph.get_tensor_by_name('superpoint/prob_nms:0')
    output_desc_tensors = graph.get_tensor_by_name('superpoint/descriptors:0')

    tf_out1 = sess.run([output_prob_nms_tensor, output_desc_tensors],
                       feed_dict={input_img_tensor: batch_img})

rpautrat commented 2 years ago

Hi, sorry, I don't know how to fix that for now, and I don't have much time to look into it. But this seems like a general TensorFlow question, so maybe you can get some advice on online forums?

ChLee98 commented 2 years ago

(quoting tangchen2's original post and prediction code above)

Hi, I also ran into this problem ("Input shape axis 0 must equal 1, got shape ..."). Did you manage to resolve it in the end?

ChLee98 commented 2 years ago

(quoting tangchen2's original post and prediction code above)

I don't know if you've solved the problem yet, but have a look at the function _pred_graph() in base_model.py. Its first line is:

pred_out = self._gpu_tower(data, Mode.PRED, self.config['pred_batch_size'])

The default value of 'pred_batch_size' in base_model.py is 1. You can either add a 'pred_batch_size' entry to your config file and set it to the batch size you need, or change the default value directly in base_model.py.
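
For reference, here is a minimal sketch of batched prediction once the model has been re-exported with a matching pred_batch_size. The tensor names are taken from the snippet above; the export directory path and the dummy input images are placeholders, and the assumption that re-exporting with pred_batch_size set to 2 makes the saved graph accept a batch of 2 follows from the explanation above rather than from a verified run.

import numpy as np
import tensorflow as tf  # TF1-style API, as in the snippets above

# Assumption: the checkpoint was re-exported with export_model.py after setting
# pred_batch_size to 2 (in the config or as the new default in base_model.py),
# so the saved graph accepts a batch dimension of 2.
tf_path = 'path/to/exported/saved_model'  # hypothetical export directory

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], str(tf_path))

    # Input and output tensors of the exported SuperPoint graph
    input_img_tensor = graph.get_tensor_by_name('superpoint/image:0')
    output_prob_nms_tensor = graph.get_tensor_by_name('superpoint/prob_nms:0')
    output_desc_tensors = graph.get_tensor_by_name('superpoint/descriptors:0')

    # Two grayscale images of the same size, stacked into shape [2, 240, 320, 1]
    img0 = np.random.rand(240, 320, 1).astype(np.float32)
    img1 = np.random.rand(240, 320, 1).astype(np.float32)
    batch_img = np.stack([img0, img1], axis=0)

    prob_nms, descriptors = sess.run(
        [output_prob_nms_tensor, output_desc_tensors],
        feed_dict={input_img_tensor: batch_img})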