tensorflow / skflow

Simplified interface for TensorFlow (mimicking Scikit Learn) for Deep Learning
Apache License 2.0

negative indices are currently unsupported in rnn example #161

Closed: vinayakumarr closed this issue 8 years ago

vinayakumarr commented 8 years ago

[Attached image: untitled]

```python
import random
import numpy as np
import tensorflow as tf
import skflow


class RNNTest(tf.test.TestCase):
    def testRNN(self):
        random.seed(42)
        data = np.array([[2, 1, 2, 2, 3],
                         [2, 2, 3, 4, 5],
                         [3, 3, 1, 2, 1],
                         [2, 4, 5, 4, 1]], dtype=np.float32)
        # labels for classification
        labels = np.array([1, 0, 1, 0], dtype=np.float32)
        # targets for regression
        targets = np.array([10, 16, 10, 16], dtype=np.float32)
        test_data = np.array([[1, 3, 3, 2, 1], [2, 3, 4, 5, 6]])

        def input_fn(X):
            # TF 0.x order (split_dim, num_split, value): 5 time steps of [batch, 1]
            return tf.split(1, 5, X)

        # Classification
        classifier = skflow.TensorFlowRNNClassifier(
            rnn_size=2, cell_type='lstm', n_classes=2, input_op_fn=input_fn)
        classifier.fit(data, labels)
        classifier.weights_
        classifier.bias_
        predictions = classifier.predict(test_data)
        self.assertAllClose(predictions, np.array([1, 0]))

        classifier = skflow.TensorFlowRNNClassifier(
            rnn_size=2, cell_type='rnn', n_classes=2,
            input_op_fn=input_fn, num_layers=2)
        classifier.fit(data, labels)
        # An unknown cell type should be rejected when fitting
        classifier = skflow.TensorFlowRNNClassifier(
            rnn_size=2, cell_type='invalid_cell_type', n_classes=2,
            input_op_fn=input_fn, num_layers=2)
        with self.assertRaises(ValueError):
            classifier.fit(data, labels)

        # Regression
        regressor = skflow.TensorFlowRNNRegressor(
            rnn_size=2, cell_type='gru', input_op_fn=input_fn)
        regressor.fit(data, targets)
        regressor.weights_
        regressor.bias_
        predictions = regressor.predict(test_data)


if __name__ == '__main__':
    tf.test.main()
```
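In this test, `input_op_fn` turns each 5-column row into a 5-step sequence. With the TensorFlow 0.x argument order `tf.split(split_dim, num_split, value)`, the call `tf.split(1, 5, X)` cuts a `[batch, 5]` tensor into five `[batch, 1]` tensors, one per time step. The pure-Python sketch below mimics that reshaping on plain lists, so you can inspect the shapes without TensorFlow installed. `split_time_steps` is a hypothetical helper for illustration only and is not part of skflow or TensorFlow.

```python
def split_time_steps(batch, num_steps):
    """Split a batch of equal-length rows into `num_steps` column blocks.

    Hypothetical helper: it mirrors only the *shape* behaviour of the
    TF 0.x call tf.split(1, num_steps, X), turning a [batch, width] input
    into num_steps pieces of shape [batch, width // num_steps].
    """
    width = len(batch[0])
    if width % num_steps != 0:
        raise ValueError("row width must be divisible by num_steps")
    step = width // num_steps
    # Piece i holds columns [i*step, (i+1)*step) of every row in the batch
    return [
        [row[i * step:(i + 1) * step] for row in batch]
        for i in range(num_steps)
    ]


data = [[2, 1, 2, 2, 3],
        [2, 2, 3, 4, 5],
        [3, 3, 1, 2, 1],
        [2, 4, 5, 4, 1]]

steps = split_time_steps(data, 5)
print(len(steps))  # 5 time steps
print(steps[0])    # first column: one single-value entry per example
```

Each element of `steps` plays the role of one RNN input at one time step. With 4 examples and 5 columns, every piece has 4 rows of width 1.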

ilblackdragon commented 8 years ago

This is a problem with an older version of TensorFlow. Please try updating to TensorFlow 0.7+. Also note that skflow will be bundled in the next release of TensorFlow (as tf.contrib.skflow).
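The suggested fix amounts to a minimum-version requirement (0.7). One way to confirm which version the running interpreter actually imports is to compare `tf.__version__` against that minimum. The sketch below uses a hypothetical helper, `version_at_least`, which is not part of TensorFlow or skflow. It compares only the leading numeric components of the version string and ignores suffixes such as `rc0`.

```python
def version_at_least(version, minimum=(0, 7)):
    """Return True if the dotted `version` string is >= `minimum`.

    Hypothetical helper: only the first len(minimum) components are
    compared, and any non-digit suffix (e.g. "0rc0") is dropped.
    """
    parts = []
    for piece in version.split(".")[:len(minimum)]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad short versions such as "1" so the tuple comparison is aligned
    while len(parts) < len(minimum):
        parts.append(0)
    return tuple(parts) >= tuple(minimum)


print(version_at_least("0.6.0"))  # False
print(version_at_least("0.7.1"))  # True
```

With TensorFlow installed, run `import tensorflow as tf` and then `print(tf.__version__, version_at_least(tf.__version__))` in the same environment as the failing script.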

vinayakumarr commented 8 years ago

I am using TensorFlow 0.7.1 and skflow 0.1.0, and it still shows the same error. Please have a look at the attached image (untitled).