ry / tensorflow-vgg16

Conversion of the Caffe VGG16 model to TensorFlow
672 stars 285 forks source link

Downloaded model vgg16-v4.tfmodel doesn't parse #4

Closed Fchaubard closed 8 years ago

Fchaubard commented 8 years ago

Running the following after the download fails with an error at ParseFromString:


import tensorflow as tf
import sys

import skimage
import skimage.io
import skimage.transform
import numpy as np

# Class labels, one per line of synset.txt; index i names output class i
# (this is how print_prob indexes it). The `with` block closes the file
# instead of leaking the handle as a bare open(...).readlines() does.
with open('/home/ubuntu/tensorflow-vgg16/synset.txt') as synset_file:
  synset = [line.strip() for line in synset_file]
# Per-channel mean values; presumably the Caffe VGG (B, G, R) means.
# NOTE(review): not referenced anywhere in this script — confirm whether
# load_image was meant to subtract them.
VGG_MEAN = [103.939, 116.779, 123.68]

def load_image(path):
  """Load an image, center-crop it to a square and resize it to 224x224.

  Args:
    path: image file to read with skimage.io.imread.

  Returns:
    float array of shape [224, 224, 3] ([height, width, depth]) with
    values scaled into [0, 1].

  Raises:
    ValueError: if the pixel values are not in [0, 255] (e.g. a 16-bit
      image), or if the image is not grayscale, RGB or RGBA.
  """
  img = skimage.io.imread(path)
  img = img / 255.0
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if not ((0 <= img).all() and (img <= 1.0).all()):
    raise ValueError("pixel values of %s are outside [0, 255]" % path)
  # Normalise the channel count so the result really is [h, w, 3]:
  # grayscale -> replicate the single channel, RGBA -> drop alpha.
  if img.ndim == 2:
    img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
  elif img.ndim == 3 and img.shape[2] == 4:
    img = img[:, :, :3]
  if img.ndim != 3 or img.shape[2] != 3:
    raise ValueError("expected a grayscale, RGB or RGBA image, got shape %s"
                     % (img.shape,))
  # Crop the largest centered square.
  short_edge = min(img.shape[:2])
  yy = (img.shape[0] - short_edge) // 2
  xx = (img.shape[1] - short_edge) // 2
  crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
  return skimage.transform.resize(crop_img, (224, 224))

def print_prob(prob, labels=None):
  """Print the top-1 and top-5 predictions and return the top-1 label.

  Args:
    prob: 1-D sequence/array of class scores, one per entry of `labels`.
    labels: sequence of class names indexed like `prob`; defaults to the
      module-level `synset` list.

  Returns:
    The label with the highest score.
  """
  if labels is None:
    labels = synset
  prob = np.asarray(prob)
  # Single-argument print(...) behaves the same on Python 2 and 3.
  print("prob shape {}".format(prob.shape))
  # argsort is ascending; reverse it so the highest scores come first.
  pred = np.argsort(prob)[::-1]

  top1 = labels[pred[0]]
  print("Top1: {}".format(top1))
  # Slicing (rather than range(5)) stays safe with fewer than 5 classes.
  top5 = [labels[i] for i in pred[:5]]
  print("Top5: {}".format(top5))
  return top1

MODEL_PATH = "/home/ubuntu/vgg16-v4.tfmodel"

# Read the serialized GraphDef as raw bytes.
with open(MODEL_PATH, mode='rb') as f:
  fileContent = f.read()

graph_def = tf.GraphDef()
try:
  graph_def.ParseFromString(fileContent)
except Exception:
  # On its own, "DecodeError: Truncated message" says nothing about the
  # cause. Report the file and its size so a partial download or a
  # protobuf message-size limit can be spotted, then re-raise the original.
  print("Failed to parse %s (%d bytes). The file may be truncated, or the "
        "installed protobuf may be unable to read a message this large."
        % (MODEL_PATH, len(fileContent)))
  raise

# Feed our own placeholder in place of the imported graph's "images" input.
images = tf.placeholder("float", [None, 224, 224, 3])

tf.import_graph_def(graph_def, input_map={"images": images})
print("graph loaded from disk")

graph = tf.get_default_graph()

cat = load_image("cat.jpg")

with tf.Session() as sess:
  init = tf.initialize_all_variables()
  sess.run(init)
  print("variables initialized")

  # Add a leading batch dimension of 1.
  batch = cat.reshape((1, 224, 224, 3))
  assert batch.shape == (1, 224, 224, 3)

  feed_dict = {images: batch}

  # Imported nodes live under the default "import/" name scope.
  prob_tensor = graph.get_tensor_by_name("import/prob:0")
  prob = sess.run(prob_tensor, feed_dict=feed_dict)

print_prob(prob[0])

Error

---------------------------------------------------------------------------
DecodeError                               Traceback (most recent call last)
<ipython-input-1-c8f1d9f927de> in <module>()
     48 
     49 graph_def = tf.GraphDef()
---> 50 graph_def.ParseFromString(fileContent)
     51 
     52 images = tf.placeholder("float", [None, 224, 224, 3])

/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py in ParseFromString(self, serialized)
    183     """
    184     self.Clear()
--> 185     self.MergeFromString(serialized)
    186 
    187   def SerializeToString(self):

/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py in MergeFromString(self, serialized)
   1006     length = len(serialized)
   1007     try:
-> 1008       if self._InternalParse(serialized, 0, length) != length:
   1009         # The only reason _InternalParse would return early is if it
   1010         # encountered an end-group tag.

/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py in InternalParse(self, buffer, pos, end)
   1042         pos = new_pos
   1043       else:
-> 1044         pos = field_decoder(buffer, new_pos, end, self, field_dict)
   1045         if field_desc:
   1046           self._UpdateOneofState(field_desc)

/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py in DecodeRepeatedField(buffer, pos, end, message, field_dict)
    626         new_pos = pos + size
    627         if new_pos > end:
--> 628           raise _DecodeError('Truncated message.')
    629         # Read sub-message.
    630         if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:

DecodeError: Truncated message.
ry commented 8 years ago

This is a bug in TF. Use TF 0.6 and it should work.

https://github.com/tensorflow/tensorflow/issues/582

zxzhijia commented 8 years ago

@Fchaubard I'm getting the same error as you. Have you solved it? Following @ry's link, I tried changing 64<<20 to 256<<20, but it doesn't work for me.