nfmcclure / tensorflow_cookbook

Code for Tensorflow Machine Learning Cookbook
https://www.packtpub.com/big-data-and-business-intelligence/tensorflow-machine-learning-cookbook-second-edition
MIT License

How to use non-linear multi-class SVM to predict class for new data? #149

Open AlexHMJ opened 6 years ago

AlexHMJ commented 6 years ago

I have modified the "ch4 - Implementing Multiclass SVMs" code to train the classifier on my own data set. Training works well, and so do the test results. However, I run into a problem when I try to predict classes for new data that has no labels. The "ch4 - Implementing Multiclass SVMs" code uses the three lines below to estimate the training accuracy:

prediction_output = tf.matmul(tf.multiply(y_target, b), pred_kernel)
prediction = tf.argmax(prediction_output - tf.expand_dims(tf.reduce_mean(prediction_output, 1), 1), 0)
accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, tf.argmax(y_target, 0)), tf.float32))

  1. How do I use this trained SVM model to predict classes for new, unlabeled data?
  2. It seems that I need labels for the data in order to run the prediction. Why is y_target (the label) needed to calculate the prediction result? That seems strange to me.
  3. How do those three lines of code produce the correct prediction?

Hope someone can help me to figure out what's going on.

ArrowYL commented 6 years ago

I am really confused too, so I posted a question in issue #148. I am working on it. Maybe we can solve it from SVM theory.

nfmcclure commented 6 years ago

Hi @AlexHMJ and @ArrowYL ,

Thanks for bringing this up. I'm quite busy over the next month, but I can check this out and see if I can extend it to the MNIST data in a few weeks. Let me know if you make any progress in the meantime.

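klchang commented 5 years ago

# Predict one new sample.
# x_data and y_target are still fed with labeled training data (rand_x, rand_y);
# only prediction_grid receives the new, unlabeled point.
new_sample = np.array([6.5, 1.0]).reshape(-1, 2)
pred = sess.run(prediction, feed_dict={x_data: rand_x, y_target: rand_y, prediction_grid: new_sample})
print("predicted: {}".format(pred[0]))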

In my humble opinion, the name 'y_target' in the prediction part is a little confusing, because it means different things in 'prediction_output' and in 'accuracy': in the former it holds the labels of the training data, while in the latter it may hold the labels of either the training data or the test data.
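
A minimal NumPy sketch of what the prediction lines appear to compute may make the role of y_target clearer. It is not the cookbook's code: the names rbf_kernel and predict, the center flag, the toy clusters, and setting b to all ones are assumptions made for illustration (in the real script, b comes from training). The point is that a kernel SVM scores a new point by comparing it against the training points, so the training inputs and their labels are needed at prediction time, while the new points themselves need no labels.

```python
import numpy as np

def rbf_kernel(train_x, new_x, gamma=10.0):
    """Gaussian kernel between every training point and every new point."""
    sq_train = np.sum(train_x ** 2, axis=1).reshape(-1, 1)   # (n_train, 1)
    sq_new = np.sum(new_x ** 2, axis=1).reshape(1, -1)       # (1, n_new)
    sq_dists = sq_train - 2.0 * train_x @ new_x.T + sq_new   # squared distances
    return np.exp(-gamma * np.abs(sq_dists))                 # (n_train, n_new)

def predict(train_x, train_y, b, new_x, gamma=10.0, center=True):
    """Predict a class index for each row of new_x.

    train_y: (n_classes, n_train) one-vs-all labels in {+1, -1} (training labels)
    b:       (n_classes, n_train) dual coefficients learned during training
    """
    kernel = rbf_kernel(train_x, new_x, gamma)
    scores = (train_y * b) @ kernel                          # (n_classes, n_new)
    if center:
        # Mirrors the subtraction of the per-class mean over the new points.
        scores = scores - scores.mean(axis=1, keepdims=True)
    return np.argmax(scores, axis=0)

# Hypothetical toy data: three well-separated clusters, two points each.
train_x = np.array([[0.0, 0.0], [0.2, 0.0],
                    [5.0, 5.0], [5.2, 5.0],
                    [10.0, 0.0], [10.2, 0.0]])
class_ids = np.array([0, 0, 1, 1, 2, 2])
train_y = np.where(np.arange(3).reshape(-1, 1) == class_ids, 1.0, -1.0)
b = np.ones_like(train_y)   # assumption: stands in for trained coefficients

new_x = np.array([[0.1, 0.0], [5.1, 5.0], [10.1, 0.0]])     # no labels needed
batch_pred = predict(train_x, train_y, b, new_x)

# A single new point: compare centered and uncentered scores.
single_centered = predict(train_x, train_y, b, new_x[1:2], center=True)
single_raw = predict(train_x, train_y, b, new_x[1:2], center=False)
print(batch_pred, single_centered, single_raw)
```

One thing the sketch makes visible: when only one new point is passed, the mean over the new points equals the score itself, so the centered scores are all zero and argmax returns class 0 regardless of the input. If the TensorFlow graph behaves the same way, predicting a single sample may need several points in prediction_grid or the uncentered scores.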

anbo1024 commented 5 years ago

I have encountered the same problem, with a test accuracy of 100%. Have you solved it?