tensorflow / skflow

Simplified interface for TensorFlow (mimicking Scikit Learn) for Deep Learning
Apache License 2.0

input must be a list error in skflow lstm #159

Closed vinayakumarr closed 8 years ago

vinayakumarr commented 8 years ago

```python
import random

import pandas as pd
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.utils import check_array
from sklearn.cross_validation import train_test_split
from scipy.cluster.vq import whiten

import tensorflow as tf

import skflow

print("program started to execute")

train_data = pd.read_csv("cla/train.csv")
class_values = pd.read_csv("cla/classnames.csv")
test_data = pd.read_csv("cla/test.csv")

print(len(test_data))

EMBEDDING_SIZE = 50

# Single direction GRU with a single layer
classifier = skflow.TensorFlowRNNClassifier(
    rnn_size=EMBEDDING_SIZE, n_classes=32, cell_type='gru',
    num_layers=1, bidirectional=False, sequence_length=None,
    steps=1000, optimizer='Adam', learning_rate=0.01,
    continue_training=True)

# Continuously train for 1000 steps & predict on test set.
while True:
    classifier.fit(train_data, class_values, logdir='logs/kdd')
    score = classifier.predict(test_data)
    print(score)
```

training dataset (sample row):
3 3 689 18 6 890 6 39 17 88 0 218

class names: 1 to 9

testing dataset (sample row):
10 10 137 18 6 90 6 27 0 31 0 9

I have only added a few rows of data here. The class names range from 1 to 9. This generates an error; please see the attached file (untitled).

nhilliard commented 8 years ago

You're passing in a pandas DataFrame, whereas the fit(...) function expects a Python list.

You can get a list of all of your rows in the CSV (a list of lists) with train_data.values.tolist(), class_values.values.tolist(), and test_data.values.tolist().
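A minimal sketch of that conversion, reusing the classifier settings from the post above (flattening the label column with ravel() is an assumption about how classnames.csv is laid out):

```python
import pandas as pd
import skflow

train_data = pd.read_csv("cla/train.csv")
class_values = pd.read_csv("cla/classnames.csv")
test_data = pd.read_csv("cla/test.csv")

# fit()/predict() want plain Python lists, not DataFrames.
X_train = train_data.values.tolist()   # list of lists, one inner list per row
X_test = test_data.values.tolist()
# If the labels come out as single-element lists, flatten them first.
y_train = class_values.values.ravel().tolist()

classifier = skflow.TensorFlowRNNClassifier(
    rnn_size=50, n_classes=32, cell_type='gru', num_layers=1,
    steps=1000, optimizer='Adam', learning_rate=0.01,
    continue_training=True)

classifier.fit(X_train, y_train, logdir='logs/kdd')
print(classifier.predict(X_test))
```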

ilblackdragon commented 8 years ago

You should add an input_fn to RNNClassifier that does split_squeeze on your inputs before handing them to the RNN. For more detail, see this example.
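A hedged sketch of what that might look like for numeric rows like these, modeled on the input_op_fn pattern from skflow's text-classification example; the parameter name input_op_fn, the skflow.ops.split_squeeze signature, and the N_FEATURES constant are assumptions here, so check them against your skflow version:

```python
import tensorflow as tf
import skflow

N_FEATURES = 12       # assumed number of feature columns; each becomes one RNN step
EMBEDDING_SIZE = 50

def input_op_fn(X):
    """Turn a [batch, N_FEATURES] float tensor into the list of per-step
    tensors the RNN cell expects."""
    X = tf.expand_dims(X, 2)  # [batch, N_FEATURES, 1]
    # Split along the step dimension and squeeze it away, giving a list of
    # N_FEATURES tensors of shape [batch, 1].
    return skflow.ops.split_squeeze(1, N_FEATURES, X)

classifier = skflow.TensorFlowRNNClassifier(
    rnn_size=EMBEDDING_SIZE, n_classes=32, cell_type='gru', num_layers=1,
    input_op_fn=input_op_fn, steps=1000, optimizer='Adam',
    learning_rate=0.01, continue_training=True)
```

The text-classification example does the same splitting for word data, except it first runs the integer inputs through skflow.ops.categorical_variable to produce embeddings before the split_squeeze call.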

ilblackdragon commented 8 years ago

Closing this. If there are more questions, please use StackOverflow or tensorflow/tensorflow issues.