suriyadeepan / practical_seq2seq

A simple, minimal wrapper for tensorflow's seq2seq module, for experimenting with datasets rapidly
http://suriyadeepan.github.io/2016-12-31-practical-seq2seq/
GNU General Public License v3.0

I'm unable to test the model #55

Open bandarikanth opened 6 years ago

bandarikanth commented 6 years ago

I used this code:

```python
import tensorflow as tf
import numpy as np

# preprocessed data
from datasets.twitter import data
import data_utils

# load data from pickle and npy files
metadata, idx_q, idx_a = data.load_data(PATH='/home/kusuma/Videos/practical_seq2seq-master/datasets/twitter')
(trainX, trainY), (testX, testY), (validX, validY) = data_utils.split_dataset(idx_q, idx_a)

# parameters
xseq_len = testX.shape[-1]
yseq_len = testY.shape[-1]
batch_size = 16
xvocab_size = len(metadata['idx2w'])
yvocab_size = xvocab_size
emb_dim = 1024

import seq2seq_wrapper

import importlib
importlib.reload(seq2seq_wrapper)

# note: ckpt_path points at a single checkpoint shard file here
model = seq2seq_wrapper.Seq2Seq(xseq_len=xseq_len,
                                yseq_len=yseq_len,
                                xvocab_size=xvocab_size,
                                yvocab_size=yvocab_size,
                                ckpt_path='/home/kusuma/Videos/practical_seq2seq-master/ckpt/twitterseq2seq_model.ckpt-11000.data-00000-of-00001',
                                emb_dim=emb_dim,
                                num_layers=3)

# random batch generators for validation, test and training
val_batch_gen = data_utils.rand_batch_gen(validX, validY, 256)
test_batch_gen = data_utils.rand_batch_gen(testX, testY, 256)
train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)

sess = model.test(test_batch_gen)

sess = model.restore_last_session()

# take one test batch and run the decoder on it
input_ = next(test_batch_gen)[0]
output = model.predict(sess, input_)
print(output.shape)

# decode question/answer pairs back to words,
# skipping replies that contain 'unk' and duplicate replies
replies = []
for ii, oi in zip(input_.T, output):
    q = data_utils.decode(sequence=ii, lookup=metadata['idx2w'], separator=' ')
    decoded = data_utils.decode(sequence=oi, lookup=metadata['idx2w'], separator=' ').split(' ')
    if decoded.count('unk') == 0:
        if decoded not in replies:
            print('q : [{0}]; a : [{1}]'.format(q, ' '.join(decoded)))
            replies.append(decoded)
```
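The `ckpt_path` above points at one shard file (`...data-00000-of-00001`) rather than a checkpoint prefix or directory. In TensorFlow 1.x a saved checkpoint is a group of files sharing one prefix (`model.ckpt-11000.index`, `.meta`, `.data-00000-of-00001`); `tf.train.Saver.restore` takes that prefix, while `tf.train.get_checkpoint_state` takes the directory that holds the `checkpoint` file. Assuming the wrapper resolves checkpoints in one of those two ways, the hypothetical helper `checkpoint_prefix` below is a minimal sketch of deriving both from the path used here. It is plain string handling, not a TensorFlow API.

```python
import os
import re

# TF 1.x checkpoint files share a prefix such as "model.ckpt-11000" and add
# one of these suffixes: ".index", ".meta" or ".data-00000-of-00001".
_SUFFIX = re.compile(r"\.(index|meta|data-\d+-of-\d+)$")


def checkpoint_prefix(path):
    """Hypothetical helper: strip a checkpoint-file suffix.

    Returns (directory, prefix). `prefix` is the form Saver.restore expects;
    `directory` is the form tf.train.get_checkpoint_state expects.
    A path that is already a prefix is returned unchanged.
    """
    prefix = _SUFFIX.sub("", path)
    return os.path.dirname(prefix), prefix


path = ("/home/kusuma/Videos/practical_seq2seq-master/ckpt/"
        "twitterseq2seq_model.ckpt-11000.data-00000-of-00001")
directory, prefix = checkpoint_prefix(path)
print(directory)
print(prefix)
```

If the wrapper expects a directory, passing `directory` (with its `checkpoint` file intact) as `ckpt_path` is the safer choice.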

I got this error:

```
usr/bin/python3.5 /home/kusuma/Videos/practical_seq2seq-master/test.py
Traceback (most recent call last):
  File "/home/kusuma/Videos/practical_seq2seq-master/test.py", line 42, in <module>
Building Graph
    output = model.predict(sess, input_)
  File "/home/kusuma/Videos/practical_seq2seq-master/seq2seq_wrapper.py", line 175, in predict
    dec_op_v = sess.run(self.decode_outputs_test, feed_dict)
AttributeError: 'NoneType' object has no attribute 'run'
```
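The last line is the key: `predict` calls `sess.run(...)`, and `sess` is `None`, so no live session ever reached the model. Below is a minimal TensorFlow-free sketch of that failure and of an early guard that reports it more clearly; `FakeSession` and this `predict` are hypothetical stand-ins, not the wrapper's actual code.

```python
class FakeSession:
    """Hypothetical stand-in for a TensorFlow session."""

    def run(self, fetches, feed_dict=None):
        return "ran {}".format(fetches)


def predict(sess, X):
    # Fail early with a clear message instead of the opaque
    # "'NoneType' object has no attribute 'run'".
    if sess is None:
        raise ValueError("sess is None: restore or train the model before predict()")
    return sess.run("decode_outputs_test", feed_dict={"X": X})


# Reproduce the original failure: calling .run on None.
sess = None
try:
    sess.run("decode_outputs_test", {})
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'run'

print(predict(FakeSession(), [1, 2, 3]))  # ran decode_outputs_test
```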
PedroPei commented 6 years ago

Oh, I see. The traceback shows that `sess` is `None` when `predict` calls `sess.run`, so you didn't load your model into a session before running it. You need to save your model with a saver after training, and restore it before you test it.
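The fix is about ordering: weights are written to disk after training and read back before `predict` runs. In TensorFlow 1.x that means `saver = tf.train.Saver()`, then `saver.save(sess, prefix, global_step=step)` after training and `saver.restore(sess, prefix)` before testing. The sketch below mirrors that order with `pickle` standing in for the Saver so it runs without TensorFlow; `save`, `restore_latest` and the `model.ckpt-<step>` file naming are illustrative assumptions. The design point is that `restore_latest` raises when no checkpoint exists instead of quietly carrying on, which is one way a `None` session can end up in `predict`.

```python
import os
import pickle
import tempfile


def save(state, ckpt_dir, step):
    """Write `state` after training (stand-in for saver.save(sess, prefix, global_step=step))."""
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, "model.ckpt-{}".format(step))
    with open(path, "wb") as f:
        pickle.dump(state, f)
    return path


def restore_latest(ckpt_dir):
    """Load the newest checkpoint; raise instead of returning None if there is none."""
    steps = []
    if os.path.isdir(ckpt_dir):
        for name in os.listdir(ckpt_dir):
            if name.startswith("model.ckpt-"):
                steps.append(int(name.rsplit("-", 1)[1]))
    if not steps:
        raise FileNotFoundError("no checkpoint in {!r}; train and save first".format(ckpt_dir))
    # Compare steps as integers so 11000 beats 1000.
    path = os.path.join(ckpt_dir, "model.ckpt-{}".format(max(steps)))
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "ckpt")
    # Testing before anything was saved fails loudly.
    try:
        restore_latest(ckpt_dir)
    except FileNotFoundError as err:
        print("restore failed:", err)
    # "Train" (fake weights), save, then restore for testing.
    save({"weights": [0.1, 0.2]}, ckpt_dir, step=1000)
    save({"weights": [0.3, 0.4]}, ckpt_dir, step=11000)
    print(restore_latest(ckpt_dir))  # {'weights': [0.3, 0.4]}
```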