tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

Unable to freeze bidirectional LSTM #1616

Open stefan-falk opened 5 years ago

stefan-falk commented 5 years ago

I am using the following to define some input/output nodes for a t2t model:

def load_translation_model(ckpt_dir, config):
    d = get_model_hparams(ckpt_dir)
    hparams = d['hparams']
    problem = d['problem']
    model = d['model']

    # Modifying the graph
    decode_length = 100
    input_node_name = 'encoder_inputs'
    output_node_name = 'decoder_outputs'

    input_node = tf.placeholder(dtype=tf.int32, shape=[None, None, 1, 1], name=input_node_name)

    encoded_dict = problem.preprocess_example(
        {'inputs': input_node, 'targets': [0]}, tf.estimator.ModeKeys.PREDICT, hparams
    )

    outputs = model.infer(
        encoded_dict, beam_size=config['beam_size'], alpha=config['alpha'], decode_length=decode_length
    )
    output_node = tf.identity(outputs['outputs'], name=output_node_name)
    inputs = {input_node_name: input_node}
    outputs = {output_node_name: output_node}
    return model, inputs, outputs

Building the graph this way, saving the model, and then freezing it works (at least for the Transformer model). However, freezing fails for an lstm_seq2seq_attention_bidirectional_encoder model with the following call:

    freeze.freeze_graph(
        input_graph=input_graph,
        input_checkpoint=ckpt_path,
        output_graph=output_graph,
        output_node_names=','.join(output_node_names),
        input_binary=True,
        input_saver='',
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes=''
    )

The issue seems to go deeper. I found https://github.com/tensorflow/tensorflow/issues/24591, which appears to be related.

Does anybody else have this problem or know how to solve it?

Full trace of the error:

Traceback (most recent call last):
  File "/home/sfalk/tmp/pycharm_project_265/asr/model/persistence.py", line 392, in <module>
    main()
  File "/home/sfalk/tmp/pycharm_project_265/asr/model/persistence.py", line 377, in main
    dump_latest_checkpoint(output_dir, ckpt_path, config, graph_builder=graph_builder)
  File "/home/sfalk/tmp/pycharm_project_265/asr/model/persistence.py", line 290, in dump_latest_checkpoint
    freeze_graph(graph_file, frozen_graph_file, ckpt_path, output_nodes.keys())
  File "/home/sfalk/tmp/pycharm_project_265/asr/model/persistence.py", line 193, in freeze_graph
    initializer_nodes=''
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py", line 363, in freeze_graph
    checkpoint_version=checkpoint_version)
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py", line 190, in freeze_graph_with_def_protos
    var_list=var_list, write_version=checkpoint_version)
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1102, in __init__
    self.build()
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1114, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1151, in _build
    build_save=build_save, build_restore=build_restore)
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 773, in _build_internal
    saveables = self._ValidateAndSliceInputs(names_to_saveables)
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 680, in _ValidateAndSliceInputs
    for converted_saveable_object in self.SaveableObjectsForOp(op, name):
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 654, in SaveableObjectsForOp
    variable, "", name)
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 128, in __init__
    self.handle_op = var.op.inputs[0]
  File "/home/sfalk/miniconda3/envs/t2t/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2128, in __getitem__
    return self._inputs[i]
IndexError: list index out of range

jiarongqiu commented 5 years ago

I am facing the same issue when trying to freeze a simple RNN.

jiarongqiu commented 5 years ago

My problem is solved. It was caused by tf.nn.rnn_cell.DropoutWrapper. Removing it from the model definition made the freeze succeed.
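
If dropout is still needed for training, one option is to apply the wrapper only in training mode and leave the cell bare when building the graph that gets frozen. The sketch below is a minimal illustration of that idea, not code from this thread: `maybe_add_dropout` and `FakeWrapper` are hypothetical names, and `FakeWrapper` only stands in for tf.nn.rnn_cell.DropoutWrapper (assumed to accept `output_keep_prob`) so the example runs without TensorFlow.

```python
def maybe_add_dropout(cell, is_training, keep_prob, wrapper_cls):
    """Return `cell` wrapped for dropout only while training.

    `wrapper_cls` is assumed to be called like
    tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=...).
    """
    if is_training and keep_prob < 1.0:
        return wrapper_cls(cell, output_keep_prob=keep_prob)
    # Inference/export graph: return the cell unwrapped.
    return cell


class FakeWrapper:
    """Hypothetical stand-in for DropoutWrapper (no TensorFlow needed)."""

    def __init__(self, cell, output_keep_prob=1.0):
        self.cell = cell
        self.output_keep_prob = output_keep_prob


# The string "lstm_cell" is a placeholder for a real RNN cell object.
train_cell = maybe_add_dropout("lstm_cell", True, 0.8, FakeWrapper)
export_cell = maybe_add_dropout("lstm_cell", False, 0.8, FakeWrapper)
print(type(train_cell).__name__, export_cell)
```

In a real model, the actual wrapper class would be passed as `wrapper_cls`, and `is_training` would be False when building the graph for export.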

stefan-falk commented 5 years ago

@jiarongqiu Thanks for reporting back with a solution. However, another model I tried to freeze was this:

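# Imports assumed; the original post omitted them (tf.keras shown here).
from tensorflow.keras.layers import Bidirectional, Embedding, Input, LSTM

def get_model(input_dict_size, output_dict_size, max_input_length, max_output_length):
    # Defining the encoder
    # encoder_input was undefined in the original post; assumed to mirror decoder_input.
    encoder_input = Input(shape=(max_input_length,))
    encoder = Embedding(input_dict_size, 128, input_length=max_input_length, mask_zero=True)(encoder_input)
    encoder = Bidirectional(LSTM(128, return_sequences=True, unroll=True), merge_mode='concat')(encoder)
    # Last time step of the concatenated forward/backward outputs (2 * 128 = 256 units)
    encoder_last = encoder[:, -1, :]
    # Defining the decoder
    decoder_input = Input(shape=(max_output_length,))
    decoder = Embedding(output_dict_size, 256, input_length=max_output_length, mask_zero=True)(decoder_input)
    decoder = LSTM(256, return_sequences=True, unroll=True)(decoder, initial_state=[encoder_last, encoder_last])
    return decoder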

I am not calling tf.nn.rnn_cell.DropoutWrapper - unless it's being used internally?
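
The trace above ends at `self.handle_op = var.op.inputs[0]`, so the Saver was handed an op whose input list is empty. A rough diagnostic sketch for narrowing down which variable that is: list the graph nodes that share a name with a checkpoint variable and have no inputs. Everything here is illustrative. `inputless_checkpoint_nodes` is a hypothetical helper, the node names are made up, and skipping `Variable`/`VariableV2` rests on the assumption that classic reference variables are not the failing case. With TensorFlow 1.x, the nodes could come from `graph.as_graph_def().node` and the names from `tf.train.list_variables(ckpt_path)`.

```python
from types import SimpleNamespace


def inputless_checkpoint_nodes(nodes, checkpoint_names,
                               skip_op_types=("Variable", "VariableV2")):
    """Return (name, op) for nodes named like a checkpoint variable
    that have no inputs.

    `nodes` are objects with `name`, `op` and `input` attributes, as on
    a GraphDef NodeDef. `skip_op_types` is an assumption: reference
    variable ops are treated as uninteresting for this error.
    """
    names = set(checkpoint_names)
    return [
        (node.name, node.op)
        for node in nodes
        if node.name in names
        and len(node.input) == 0
        and node.op not in skip_op_types
    ]


# Stand-in nodes with hypothetical names, so the sketch runs without TensorFlow.
example_nodes = [
    SimpleNamespace(name="encoder/kernel", op="VarHandleOp", input=[]),
    SimpleNamespace(name="encoder/kernel/Read/ReadVariableOp",
                    op="ReadVariableOp", input=["encoder/kernel"]),
    SimpleNamespace(name="decoder/bias", op="VariableV2", input=[]),
]
example_checkpoint_names = ["encoder/kernel", "decoder/bias"]

suspects = inputless_checkpoint_nodes(example_nodes, example_checkpoint_names)
print(suspects)
```

Any node this reports is only a candidate for the op that triggered the IndexError. It does not by itself confirm the cause.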