suriyadeepan / practical_seq2seq

A simple, minimal wrapper for tensorflow's seq2seq module, for experimenting with datasets rapidly
http://suriyadeepan.github.io/2016-12-31-practical-seq2seq/
GNU General Public License v3.0
570 stars 270 forks source link

Unable to predict answers with the pre-trained model #77

Open lcarnevale opened 5 years ago

lcarnevale commented 5 years ago

Hello, I am having trouble predicting answers with the pre-trained model.

I am using this code: `import tensorflow as tf import numpy as np

from datasets.twitter import data import data_utils

metadata, idx_q, idx_a = data.load_data(PATH='datasets/twitter/') (trainX, trainY), (testX, testY), (validX, validY) = data_utils.split_dataset(idx_q, idx_a)

xseq_len = trainX.shape[-1] yseq_len = trainY.shape[-1] batch_size = 32 xvocab_size = len(metadata['idx2w']) yvocab_size = xvocab_size emb_dim = 1024

import seq2seq_wrapper

model = seq2seq_wrapper.Seq2Seq(xseq_len=xseq_len, yseq_len=yseq_len, xvocab_size=xvocab_size, yvocab_size=yvocab_size, ckpt_path='ckpt/twitter/', emb_dim=emb_dim, num_layers=3 )

val_batch_gen = data_utils.rand_batch_gen(validX, validY, 256) test_batch_gen = data_utils.rand_batch_gen(testX, testY, 256) train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)

sess = model.restore_last_session()

sess = model.train(train_batch_gen, val_batch_gen)

print ("trained")

input_ = test_batchgen.next()[0] print input output = model.predict(sess, input) print(output.shape)

replies = [] for ii, oi in zip(input_.T, output): q = data_utils.decode(sequence=ii, lookup=metadata['idx2w'], separator=' ') decoded = data_utils.decode(sequence=oi, lookup=metadata['idx2w'], separator=' ').split(' ') if decoded.count('unk') == 0: if decoded not in replies: print('q : [{0}]; a : [{1}]'.format(q, ' '.join(decoded))) replies.append(decoded)`

and the following files are in the ckpt/twitter folder:

`Traceback (most recent call last): File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1022, in _do_call return fn(*args) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1004, in _run_fn status, run_metadata) File "/usr/lib/python3.6/contextlib.py", line 88, in __exit__ next(self.gen) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 469, in raise_exception_on_not_ok_status pywrap_tensorflow.TF_GetCode(status)) tensorflow.python.framework.errors_impl.NotFoundError: Key decoder/embedding_rnn_seq2seq/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/weights/Adam_1 not found in checkpoint [[Node: save/RestoreV2_49 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_49/tensor_names, save/RestoreV2_49/shape_and_slices)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last): File "03-Twitter-chatbot.py", line 44, in <module> sess = model.restore_last_session() File "/home/lcarnevale/git/practical_seq2seq/src/seq2seq_wrapper.py", line 171, in restore_last_session saver.restore(sess, ckpt.model_checkpoint_path) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1439, in restore {self.saver_def.filename_tensor_name: save_path}) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 767, in run run_metadata_ptr) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 965, in _run feed_dict_string, options, run_metadata) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run target_list, options, run_metadata) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call raise type(e)(node_def, op, message) tensorflow.python.framework.errors_impl.NotFoundError: Key decoder/embedding_rnn_seq2seq/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/weights/Adam_1 not found in checkpoint [[Node: save/RestoreV2_49 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_49/tensor_names, save/RestoreV2_49/shape_and_slices)]]

Caused by op 'save/RestoreV2_49', defined at: File "03-Twitter-chatbot.py", line 44, in <module> sess = model.restore_last_session() File "/home/lcarnevale/git/practical_seq2seq/src/seq2seq_wrapper.py", line 164, in restore_last_session saver = tf.train.Saver() File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1051, in __init__ self.build() File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1081, in build restore_sequentially=self._restore_sequentially) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 675, in build restore_sequentially, reshape) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 402, in _AddRestoreOps tensors = self.restore_op(filename_tensor, saveable, preferred_shard) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 242, in restore_op [spec.tensor.dtype])[0]) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 668, in restore_v2 dtypes=dtypes, name=name) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op op_def=op_def) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op original_op=self._default_original_op, op_def=op_def) File "/home/lcarnevale/git/practical_seq2seq/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__ self._traceback = _extract_stack()

NotFoundError (see above for traceback): Key decoder/embedding_rnn_seq2seq/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/weights/Adam_1 not found in checkpoint [[Node: save/RestoreV2_49 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_49/tensor_names, save/RestoreV2_49/shape_and_slices)]]`

If I train the model myself instead of restoring the pre-trained checkpoint, the prediction process completes successfully.