HadoopIt / rnn-nlu

A TensorFlow implementation of Recurrent Neural Networks for Sequence Classification and Sequence Labeling

Restoring model from checkpoint #10

Open dianamurgulet opened 7 years ago

dianamurgulet commented 7 years ago

Hi @HadoopIt ,

Thank you for publishing the code for the paper. I am trying to use a stored pre-trained model to generate the intent and slots for a new sentence. However, based on the outputs it generates, it ends up using a new, untrained model.

saver = tf.train.import_meta_graph('/tmp/model.ckpt-1900.meta')
saver.restore(session, '/tmp/model.ckpt-1900')

model_train, model_test = create_model(session, 139, 36, 6)
step_outputs = model_test.joint_step(session, encoder_inputs, tags, tag_weights, labels, sequence_length, bucket_id, True)

Any suggestions on how to use a trained model from a stored file?

wqdong commented 7 years ago

The train and test models share parameters. Because you didn't load the train model, the test model is randomly initialized. You should also create the models before calling restore.

HadoopIt commented 7 years ago

Hi @dianamurgulet, the "create_model" function finds the model checkpoint in FLAGS.train_dir with ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir). If you have a pre-trained model, passing its directory to get_checkpoint_state should make it work.
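For readers unfamiliar with this lookup: tf.train.get_checkpoint_state reads a small text file named `checkpoint` in the given directory, which records the most recent checkpoint prefix. The sketch below imitates that lookup with the standard library only, so the returned value can be inspected without TensorFlow. The helper name `read_latest_checkpoint_prefix` is hypothetical and not part of rnn-nlu or TensorFlow, and the parsing is a simplification that assumes the usual one-entry-per-line layout of that file.

```python
import os
import re
import tempfile


def read_latest_checkpoint_prefix(train_dir):
    """Return the latest checkpoint prefix recorded in <train_dir>/checkpoint.

    Hypothetical stand-in for the lookup described above; returns None when
    the state file is missing or has no model_checkpoint_path entry.
    """
    state_file = os.path.join(train_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
            if match:
                path = match.group(1)
                # Relative entries are taken to be relative to train_dir.
                if not os.path.isabs(path):
                    path = os.path.join(train_dir, path)
                return path
    return None


# Demo: write a state file shaped like the one the saver leaves behind.
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "checkpoint"), "w") as f:
    f.write('model_checkpoint_path: "model.ckpt-1900"\n')
    f.write('all_model_checkpoint_paths: "model.ckpt-1800"\n')
    f.write('all_model_checkpoint_paths: "model.ckpt-1900"\n')

latest = read_latest_checkpoint_prefix(demo_dir)
empty_result = read_latest_checkpoint_prefix(tempfile.mkdtemp())
```

The value found this way is a path prefix, not necessarily the name of a file on disk.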

programmer9208 commented 7 years ago

Hi @HadoopIt, I tried passing the model directory to the "create_model" function as you suggested, but it still didn't work. Then I read your code for restoring a pre-trained model from disk in run_multi-task_rnn.py and made a change:

Original code in run_multi-task_rnn.py, line 219:
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):

I simply deleted the second condition:
if ckpt:

Then it works. I guess the reason is that ckpt.model_checkpoint_path = './model/model.ckpt-29900', and no file with that exact name exists. The files we need have the suffixes '.data-00000-of-00001', 'index', and 'meta', and they do exist in $train_dir.
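A less drastic alternative to dropping the condition is to test for one of the suffixed files instead of the bare prefix. The following is only a sketch using the standard library: the helper name `checkpoint_exists` is hypothetical (it is not part of rnn-nlu or TensorFlow), and the demo uses empty placeholder files in a temporary directory rather than a real checkpoint.

```python
import os
import tempfile


def checkpoint_exists(prefix):
    """Return True if a checkpoint with this path prefix seems to be on disk.

    Hypothetical helper. It accepts either a file named exactly `prefix`
    or a `<prefix>.index` file, which covers the case described above where
    the prefix itself is not a file.
    """
    return os.path.isfile(prefix) or os.path.isfile(prefix + ".index")


# Demo with empty placeholder files standing in for real checkpoint files.
demo_dir = tempfile.mkdtemp()
demo_prefix = os.path.join(demo_dir, "model.ckpt-29900")
for suffix in (".data-00000-of-00001", ".index", ".meta"):
    open(demo_prefix + suffix, "w").close()

prefix_is_file = os.path.exists(demo_prefix)  # bare prefix: no such file
found = checkpoint_exists(demo_prefix)        # the .index file is present
missing = checkpoint_exists(os.path.join(demo_dir, "model.ckpt-1"))
```

In the script, the same idea would mean checking ckpt.model_checkpoint_path + '.index' rather than ckpt.model_checkpoint_path, which keeps a guard against a stale path while still accepting checkpoints saved as multiple files.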

bringtree commented 6 years ago

@pachirayz tf.gfile.Exists(ckpt.model_checkpoint_path) is always false. Maybe you should edit it.

bringtree commented 6 years ago

ckpt.model_checkpoint_path
Out[2]: u'model_tmp/model.ckpt-15000'
sss = '/Users/huangpeisong/Desktop/py2 2/rnn-nlu/model_tmp/model.ckpt-15000'
tf.gfile.Exists(sss)
Out[4]: False
sss = '/Users/huangpeisong/Desktop/py2 2/rnn-nlu/model_tmp/model.ckpt-15000.index'
tf.gfile.Exists(sss)
Out[6]: True

HadoopIt commented 6 years ago

@bringtree Thanks for pointing this out. I have just pushed a fix.

programmer9208 commented 6 years ago

@bringtree Yeah, that's what I meant. The second condition (tf.gfile.Exists(ckpt.model_checkpoint_path)) is always false because ckpt.model_checkpoint_path does not contain a suffix such as 'index', as in your example. Thank you all! @HadoopIt @bringtree