tensorflow / tpu

Reference models and tools for Cloud TPUs.
https://cloud.google.com/tpu/
Apache License 2.0

Cannot import meta graph to GPU after training on TPU [KeyError: 'InfeedEnqueueTuple'] #267

Open nilakshdas opened 5 years ago

nilakshdas commented 5 years ago

I am unable to import the model on my local GPU machine after training ResNet on a TPU and downloading the checkpoint directory.

Here's what I do to load the model:

import tensorflow as tf

checkpoint = tf.train.get_checkpoint_state(MODEL_DIR)  # MODEL_DIR: checkpoint directory downloaded from the TPU run

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('%s.meta' % checkpoint.model_checkpoint_path)
    saver.restore(sess, checkpoint.model_checkpoint_path)

This gives the following error:

Traceback (most recent call last):
  File "test.py", line 194, in <module>
    saver = tf.train.import_meta_graph('%s.meta' % checkpoint.model_checkpoint_path)
  File "xxx/python3.6/site-packages/tensorflow/python/training/saver.py", line 1666, in import_meta_graph
    meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
  File "xxx/python3.6/site-packages/tensorflow/python/training/saver.py", line 1688, in _import_meta_graph_with_return_elements
    **kwargs))
  File "xxx/python3.6/site-packages/tensorflow/python/framework/meta_graph.py", line 806, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "xxx/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "xxx/python3.6/site-packages/tensorflow/python/framework/importer.py", line 391, in import_graph_def
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
  File "xxx/python3.6/site-packages/tensorflow/python/framework/importer.py", line 158, in _RemoveDefaultAttrs
    op_def = op_dict[node.op]
KeyError: 'InfeedEnqueueTuple'

It seems that some of the ops are only defined for TPUs. How can I get around this?
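The traceback fails at `op_dict[node.op]`, i.e. the graph stored in the .meta file contains an op type (`InfeedEnqueueTuple`) that this TensorFlow installation does not have registered. A minimal diagnostic sketch, assuming only TensorFlow's bundled protobuf module: it parses the .meta file as a `MetaGraphDef` without importing it, then lists op types whose names contain "TPU", "Infeed" or "Outfeed". Those name markers are a heuristic assumption, not an official list of TPU-only ops, and `load_meta_graph_def` / `find_tpu_ops` are hypothetical helper names.

```python
# Sketch: inspect a .meta file for TPU-looking ops without calling import_meta_graph.
# The name markers are a heuristic assumption, not an official list of TPU-only ops.
from tensorflow.core.protobuf import meta_graph_pb2

TPU_OP_MARKERS = ("TPU", "Infeed", "Outfeed")


def load_meta_graph_def(meta_path):
    """Parse a binary .meta file (as written by tf.train.Saver) into a MetaGraphDef."""
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    with open(meta_path, "rb") as f:
        meta_graph_def.ParseFromString(f.read())
    return meta_graph_def


def find_tpu_ops(meta_graph_def):
    """Return the sorted op types whose name contains one of TPU_OP_MARKERS."""
    graph_def = meta_graph_def.graph_def
    nodes = list(graph_def.node)
    # Ops may also live inside function bodies stored in the graph's library.
    for function in graph_def.library.function:
        nodes.extend(function.node_def)
    return sorted({node.op for node in nodes
                   if any(marker in node.op for marker in TPU_OP_MARKERS)})
```

For the checkpoint in the question, `find_tpu_ops(load_meta_graph_def(checkpoint.model_checkpoint_path + '.meta'))` would be expected to include `InfeedEnqueueTuple`, the op named in the traceback. The variables themselves live in the checkpoint data files, so the problem is confined to the graph description in the .meta file.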

I am using tensorflow 1.11.0.

sayradley commented 5 years ago

I managed to work around this issue by rerunning Estimator training locally for one step, starting from the latest checkpoint (the one produced on the TPU). The newly produced meta graph then imported without problems.

saberkun commented 5 years ago

A TPU export adds a TPU meta graph so that the model can be served on TPU. When you export for GPU, you need to export with use_tpu=False and export_to_tpu=False. Alternatively, exporting the latest checkpoint with a normal Estimator should also work.

omrishsu commented 3 years ago

@saberkun, can you post an example of how to export with use_tpu=False + export_to_tpu=False?