jiegzhan / multi-class-text-classification-cnn

Classify Kaggle Consumer Finance Complaints into 11 classes. Build the model with a CNN (Convolutional Neural Network) and word embeddings in TensorFlow.
Apache License 2.0
426 stars 198 forks

How to get class for given text input? #1

Open GraphGrailAi opened 7 years ago

GraphGrailAi commented 7 years ago

Could you provide some example code showing how to get the class output for a given text input?

I was able to get all the code working with ./data/small_samples.json, but the output is an accuracy percentage. I need the exact class name for every text.

jiegzhan commented 7 years ago
  1. When you run train.py, labels.json is saved. It contains a list of all the labels.

  2. When you run predict.py, take a look at line 63. If you print batch_predictions, you get a list of numbers, and each number is an index into labels.json.

For example, I printed batch_predictions:

[6 6 6 4 6 4 3 6 4 6 3 4 1 2 3 2 3 2 4 0 4 4 4 3 4 6 4 4 1 4 0 6 2 4 4 6 3 3 1 3 4 4 3 4 3 6 3 6 6 6]

The first number in batch_predictions is 6, so the corresponding label is labels.json[6], which is mortgage.

Hope this will help you find the corresponding labels.
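The lookup described above can be sketched as follows. This is only a minimal illustration: the label names are placeholders (the real ones come from the labels.json written by train.py), except that index 6 is set to the mortgage label mentioned in the example, and the real file would be read with json.load rather than json.loads.

```python
import json

# Stand-in for the text of labels.json. The names are hypothetical
# placeholders; only index 6 follows the example given in this thread.
placeholder_labels = ["label_{}".format(i) for i in range(11)]
placeholder_labels[6] = "Mortgage"
labels_json_text = json.dumps(placeholder_labels)

# With the real file this would be: labels = json.load(open(path_to_labels_json))
labels = json.loads(labels_json_text)

# First few values of the batch_predictions printed above.
batch_predictions = [6, 6, 6, 4, 6, 4, 3]

# Each prediction is an index into the label list.
predicted_labels = [labels[int(index)] for index in batch_predictions]

for index, name in zip(batch_predictions, predicted_labels):
    print("{} -> {}".format(index, name))
```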

GraphGrailAi commented 7 years ago

Thanks for the answer. I had guessed that myself, and I tested it: the list of predicted label indices is all_predictions, not batch_predictions. When I print batch_predictions, it returns an empty list [].

predict.py from line 63:

            for x_test_batch in batches:
                batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
                all_predictions = np.concatenate([all_predictions, batch_predictions])

    if y_test is not None:
        y_test = np.argmax(y_test, axis=1)
        correct_predictions = sum(all_predictions == y_test)
        logging.critical('The batch_predictions is: {}'.format(batch_predictions))
        logging.critical('The all_predictions is: {}'.format(all_predictions))
        logging.critical('The y_test is: {}'.format(y_test)) # y_test holds the true label indices into labels.json
        logging.critical('The correct_predictions is: {}'.format(correct_predictions))
        logging.critical('The accuracy is: {}'.format(correct_predictions / float(len(y_test))))

output:

d:\Django\multi-class-text-classification-cnn>python predict.py ./trained_model_1485334811/ ./data/small_samples_my.json
CRITICAL:root:Loaded the trained model: d:\Django\multi-class-text-classification-cnn\trained_model_1485334811\checkpoints\model-300
INFO:root:The number of x_test: 5
INFO:root:The number of y_test: 5
CRITICAL:root:The batch_predictions is: []
CRITICAL:root:The all_predictions is: [ 10.  10.  10.   8.  10.]
CRITICAL:root:The y_test is: [10  6 10  8  9]
CRITICAL:root:The correct_predictions is: 3
CRITICAL:root:The accuracy is: 0.6
jiegzhan commented 7 years ago

Actually, for each batch, there will be a batch_predictions list, which will be appended to all_predictions.

Eventually, if you have 100 test examples, all_predictions will have 100 numbers. Each number is the corresponding index in labels.json. You can get the actual label by referring to labels.json[index].
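One detail worth noting from the log above: all_predictions is printed as floats ([ 10.  10.  10.   8.  10.]), so each value needs to be converted to an int before it can be used as a list index. A minimal sketch using the values from that log, with placeholder label names standing in for the real contents of labels.json:

```python
# Hypothetical placeholder names; the real ones come from labels.json.
labels = ["label_{}".format(i) for i in range(11)]

# Values copied from the log output earlier in this thread.
all_predictions = [10.0, 10.0, 10.0, 8.0, 10.0]
y_test = [10, 6, 10, 8, 9]

# Floats cannot index a list, so cast each prediction to int first.
predicted = [labels[int(p)] for p in all_predictions]
actual = [labels[t] for t in y_test]

for p, a in zip(predicted, actual):
    print("predicted: {}  actual: {}".format(p, a))

# Same count that predict.py reports as correct_predictions.
correct = sum(p == a for p, a in zip(predicted, actual))
print("correct: {} of {}".format(correct, len(y_test)))
```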

GraphGrailAi commented 7 years ago

Thank you! I will create another issue for my other question.

akki2825 commented 7 years ago

Has anyone figured this out? I need the predicted score for each class it predicts. For example, if the text belongs to a single class, I need to know the probability of the text belonging to that class. Any help would keep me moving.

vijaysaimutyala commented 6 years ago

@akki2825 Were you able to find a solution for predicting the probability of the classified text?

@GraphGrailAi Did you mean that you got the accuracy for each class prediction, or the accuracy of the whole model?

Chinguun8 commented 5 years ago

Has anyone found a solution for printing the probability of each sentence's prediction?
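The thread does not give an answer to this. One common approach, sketched here under assumptions, is to take the raw per-class scores that the model computes before picking the highest one and pass them through a softmax. Whether predict.py exposes such a scores tensor, and under what name, is not confirmed anywhere in this thread; the scores below are made-up numbers for a three-class case.

```python
import math

def softmax(scores):
    """Turn raw per-class scores into probabilities that sum to 1."""
    highest = max(scores)
    # Subtracting the maximum keeps exp() from overflowing on large scores.
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores for one sentence (not produced by the real model).
example_scores = [0.5, 2.0, -1.0]

probabilities = softmax(example_scores)
predicted_index = probabilities.index(max(probabilities))

print("predicted class index:", predicted_index)
print("probability of that class:", probabilities[predicted_index])
```

The predicted index is the same one that an argmax over the raw scores would give, so it can still be looked up in labels.json as described earlier.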