hzy46 / fast-neural-style-tensorflow

A tensorflow implementation for fast neural style!

Error after model quantisation #48

Open · kmonachopoulos opened this issue 6 years ago


Describe the problem

I am using the pre-trained style transfer models from Baidu Drive. I have converted them to .pb format using the export() function. When I run inference on a single image with these fp32 models, I can see the styled image at the output of the network and everything works fine. Next, I optimize the models for inference with the transform_graph tool, using these parameters:

--transforms='
  add_default_attributes
  strip_unused_nodes(type=float)
  remove_nodes(op=CheckNumerics)
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms
  quantize_weights
  quantize_nodes
  strip_unused_nodes
  sort_by_execution_order'
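For reference, the same transform list can be handed to the Graph Transform Tool as a single argv list. This is a minimal sketch, assuming the tool was built with Bazel at its usual location (`bazel-bin/tensorflow/tools/graph_transforms/transform_graph`). The file names and the output node name `output_image` are hypothetical, and the input name `input_image` is only inferred from the `_arg_input_image_0_0` node in the traceback.

```python
# Assumed path of the Bazel-built Graph Transform Tool binary.
TRANSFORM_GRAPH = "bazel-bin/tensorflow/tools/graph_transforms/transform_graph"

# The transform list from this issue, in the same order.
TRANSFORMS = [
    "add_default_attributes",
    "strip_unused_nodes(type=float)",
    "remove_nodes(op=CheckNumerics)",
    "fold_constants(ignore_errors=true)",
    "fold_batch_norms",
    "fold_old_batch_norms",
    "quantize_weights",
    "quantize_nodes",
    "strip_unused_nodes",
    "sort_by_execution_order",
]


def build_transform_command(in_graph, out_graph, inputs, outputs, transforms):
    """Return an argv list for transform_graph.

    Transforms are whitespace-separated inside the single --transforms
    argument, exactly as in the quoted shell form above.
    """
    return [
        TRANSFORM_GRAPH,
        "--in_graph=" + in_graph,
        "--out_graph=" + out_graph,
        "--inputs=" + ",".join(inputs),
        "--outputs=" + ",".join(outputs),
        "--transforms=" + " ".join(transforms),
    ]


if __name__ == "__main__":
    cmd = build_transform_command(
        "model_fp32.pb",       # hypothetical input file
        "model_quantized.pb",  # hypothetical output file
        ["input_image"],       # inferred from _arg_input_image_0_0
        ["output_image"],      # hypothetical output node name
        TRANSFORMS,
    )
    print(" ".join(cmd))
    # To run it: import subprocess; subprocess.check_call(cmd)
```

Passing the list to `subprocess.check_call` keeps the parentheses in entries such as `strip_unused_nodes(type=float)` away from shell quoting.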

The quantised models (.pb) are generated successfully, but now inference fails when I execute sess.run().

Error output:

Traceback (most recent call last):
  File "Inf_Image_pb.py", line 89, in <module>
    Session_out = sess.run(l_output, feed_dict={l_input: image})            
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 889, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1317, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [474,712,3] != values[2].shape = []
     [[Node: Reshape/shape = Pack[N=3, T=DT_INT32, axis=0, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_image_0_0, _arg_input_image_0_0, Reshape/shape/2)]]

Caused by op u'Reshape/shape', defined at:
  File "Inf_Image_pb.py", line 74, in <module>
    producer_op_list=None
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py", line 313, in import_graph_def
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Shapes of all inputs must match: values[0].shape = [474,712,3] != values[2].shape = []
     [[Node: Reshape/shape = Pack[N=3, T=DT_INT32, axis=0, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_image_0_0, _arg_input_image_0_0, Reshape/shape/2)]]
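On what the message refers to: `Reshape/shape` is a `Pack` node (the op behind `tf.stack`), and `Pack` requires every input to have exactly the same shape. In the transformed graph its first two inputs are both `_arg_input_image_0_0`, the fed image with shape [474,712,3], while the third input, `Reshape/shape/2`, is a scalar. The pure-Python sketch below reproduces only that shape rule (for `axis=0`, as in the failing node) so the failing case can be compared with a well-formed one; it is an illustration, not TensorFlow's implementation. It assumes, as the node name suggests, that in the working fp32 graph this node stacked three scalars into a rank-1 target shape for `Reshape`.

```python
def pack_output_shape(input_shapes):
    """Mimic the shape rule of TensorFlow's Pack op with axis=0.

    All inputs must share one shape; the result prepends the number of
    inputs to that shape. Raises ValueError with a message modelled on
    the InvalidArgumentError in this issue.
    """
    first = list(input_shapes[0])
    for i, shape in enumerate(input_shapes[1:], start=1):
        if list(shape) != first:
            raise ValueError(
                "Shapes of all inputs must match: values[0].shape = [%s] "
                "!= values[%d].shape = [%s]"
                % (",".join(map(str, first)), i, ",".join(map(str, shape))))
    return [len(input_shapes)] + first


if __name__ == "__main__":
    # Three scalars (e.g. height, width, 3) give a valid shape vector.
    print(pack_output_shape([(), (), ()]))
    # The transformed graph: image, image, scalar.
    try:
        pack_output_shape([(474, 712, 3), (474, 712, 3), ()])
    except ValueError as err:
        print(err)
```

Stacking three scalars yields shape [3], which `Reshape` can use as a target shape; stacking the image tensor with a scalar cannot succeed, which is why the failure surfaces at `values[2]`.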

Is this a model problem or a transform_graph problem? What does this error refer to? I can't find any useful information online.
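Since the fp32 graph runs and the full transform list does not, one way to separate a model problem from a transform problem is to rerun transform_graph with progressively longer prefixes of the list and check inference each time. The sketch below bisects over those prefixes. The `works` callback is hypothetical and has to be written for your setup (run the tool on the prefix, then `sess.run()` on the result). It also assumes the failure is monotonic: once a prefix breaks the graph, every longer prefix stays broken.

```python
def first_breaking_transform(transforms, works):
    """Return the transform whose addition first makes inference fail.

    `works(prefix)` is a user-supplied (hypothetical) callback that returns
    True if the graph produced with the transforms in `prefix` still runs.
    Assumes works([]) is True and works(transforms) is False.
    """
    # Invariant: transforms[:lo] works, transforms[:hi] fails.
    lo, hi = 0, len(transforms)
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if works(transforms[:mid]):
            lo = mid
        else:
            hi = mid
    return transforms[hi - 1]


if __name__ == "__main__":
    # Toy stand-in for a real `works`: pretend any prefix longer than two
    # transforms breaks the graph. Illustrative only.
    toy = ["t1", "t2", "t3", "t4", "t5"]
    print(first_breaking_transform(toy, lambda prefix: len(prefix) <= 2))
```

With the ten transforms listed above, this needs at most four tool-plus-inference runs instead of ten.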