Closed: Fchaubard closed this issue 8 years ago.
Running the following script after downloading the model raises an error at `ParseFromString`:
from __future__ import print_function

import sys

import numpy as np
import skimage
import skimage.io
import skimage.transform
import tensorflow as tf

# One label per line; entry i is looked up with index i of the "prob" output.
# NOTE(review): presumably the 1000 ImageNet classes -- confirm against the
# model's output size.
with open('/home/ubuntu/tensorflow-vgg16/synset.txt') as synset_file:
    synset = [line.strip() for line in synset_file]

# NOTE(review): VGG_MEAN is not referenced anywhere in this script.
VGG_MEAN = [103.939, 116.779, 123.68]


def load_image(path):
    """Load an image, center-crop it to a square and resize it to 224x224.

    Pixel values are divided by 255. Grayscale images are replicated to three
    channels and an alpha channel, if present, is dropped, so the result
    always has shape (224, 224, 3) -- [height, width, depth].

    Raises:
        ValueError: if the scaled pixels fall outside [0, 1], e.g. because
            the file is not an 8-bit image.
    """
    img = skimage.io.imread(path) / 255.0
    # `assert` is stripped under `python -O`; validate explicitly instead.
    if not ((0 <= img).all() and (img <= 1.0).all()):
        raise ValueError("expected 8-bit pixel values in %s" % path)

    if img.ndim == 2:
        # Grayscale: replicate the single channel so the result is RGB-shaped.
        img = np.dstack((img, img, img))
    elif img.shape[2] == 4:
        # RGBA: drop the alpha channel.
        img = img[:, :, :3]

    # Crop the largest centered square, then resize it to 224x224.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
    return skimage.transform.resize(crop_img, (224, 224))


def print_prob(prob):
    """Print the top-1 and top-5 labels for one image's scores.

    `prob` is indexed like `synset` (the script passes `prob[0]`, i.e. the
    scores of the first image in the batch). Returns the top-1 label.
    """
    print("prob shape", prob.shape)
    # Indices ordered from highest to lowest score.
    pred = np.argsort(prob)[::-1]

    top1 = synset[pred[0]]
    print("Top1: ", top1)

    top5 = [synset[i] for i in pred[:5]]
    print("Top5: ", top5)
    return top1


with open("/home/ubuntu/vgg16-v4.tfmodel", mode='rb') as f:
    fileContent = f.read()

graph_def = tf.GraphDef()
# NOTE(review): this is the call that fails with
# "DecodeError: Truncated message." in the reported traceback.
graph_def.ParseFromString(fileContent)

images = tf.placeholder("float", [None, 224, 224, 3])
# Substitute our placeholder for the imported graph's "images" input.
tf.import_graph_def(graph_def, input_map={"images": images})
print("graph loaded from disk")

graph = tf.get_default_graph()

cat = load_image("cat.jpg")

with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)
    print("variables initialized")

    # Add a leading batch dimension of size 1.
    batch = cat.reshape((1, 224, 224, 3))
    feed_dict = {images: batch}

    # import_graph_def prefixes imported node names with "import/".
    prob_tensor = graph.get_tensor_by_name("import/prob:0")
    prob = sess.run(prob_tensor, feed_dict=feed_dict)
    print_prob(prob[0])
Error
---------------------------------------------------------------------------
DecodeError                               Traceback (most recent call last)
<ipython-input-1-c8f1d9f927de> in <module>()
     48
     49 graph_def = tf.GraphDef()
---> 50 graph_def.ParseFromString(fileContent)
     51
     52 images = tf.placeholder("float", [None, 224, 224, 3])

/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py in ParseFromString(self, serialized)
    183 """
    184 self.Clear()
--> 185 self.MergeFromString(serialized)
    186
    187 def SerializeToString(self):

/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py in MergeFromString(self, serialized)
   1006 length = len(serialized)
   1007 try:
-> 1008 if self._InternalParse(serialized, 0, length) != length:
   1009 # The only reason _InternalParse would return early is if it
   1010 # encountered an end-group tag.

/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py in InternalParse(self, buffer, pos, end)
   1042 pos = new_pos
   1043 else:
-> 1044 pos = field_decoder(buffer, new_pos, end, self, field_dict)
   1045 if field_desc:
   1046 self._UpdateOneofState(field_desc)

/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py in DecodeRepeatedField(buffer, pos, end, message, field_dict)
    626 new_pos = pos + size
    627 if new_pos > end:
--> 628 raise _DecodeError('Truncated message.')
    629 # Read sub-message.
    630 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:

DecodeError: Truncated message.
This is a bug in TF. Use TF 0.6 and it should work.
https://github.com/tensorflow/tensorflow/issues/582
@Fchaubard I'm getting the same error as you. Have you solved it? I tried changing 64<<20 to 256<<20 by following @ry's link, but it doesn't work for me.
Running the following script after downloading the model raises an error at `ParseFromString`:
Error