GraphGrailAi opened 7 years ago
When you run train.py, labels.json is saved. It is a list of all the labels.
When you run predict.py, look at line 63: if you print batch_predictions, you get a list of numbers, and each number is an index into labels.json.
For example, I printed batch_predictions: [6 6 6 4 6 4 3 6 4 6 3 4 1 2 3 2 3 2 4 0 4 4 4 3 4 6 4 4 1 4 0 6 2 4 4 6 3 3 1 3 4 4 3 4 3 6 3 6 6 6]. The first number in batch_predictions is 6, so its label is labels.json[6], which is mortgage.
Hope this helps you find the corresponding labels.
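A minimal sketch of that lookup in plain Python, assuming labels.json holds a JSON array of label strings in the order train.py wrote it. The helper names `load_labels` and `indices_to_labels` are hypothetical, not part of the repo, and the path should point to wherever train.py actually saved the file.

```python
import json


def load_labels(path="labels.json"):
    """Read the label list saved by train.py (assumed to be a JSON array of strings)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def indices_to_labels(indices, labels):
    """Map each predicted index to its label name.

    int() is applied because predictions accumulated with np.concatenate
    can come back as floats (e.g. 6.0), and a float cannot index a list.
    """
    return [labels[int(i)] for i in indices]
```

Usage would look like `indices_to_labels(batch_predictions, load_labels("labels.json"))`; with the example above, index 6 maps to mortgage.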
Thanks for the answer. I had guessed as much myself, and my testing shows that the list of predicted label indices is all_predictions, not batch_predictions. When I print batch_predictions, it returns an empty list: []
predict.py, from line 63:
for x_test_batch in batches:
    batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
    all_predictions = np.concatenate([all_predictions, batch_predictions])

if y_test is not None:
    y_test = np.argmax(y_test, axis=1)  # row-wise argmax -> true label indices
    correct_predictions = sum(all_predictions == y_test)
    logging.critical('The batch_predictions is: {}'.format(batch_predictions))
    logging.critical('The all_predictions is: {}'.format(all_predictions))
    logging.critical('The y_test is: {}'.format(y_test))  # true label indices into labels.json
    logging.critical('The correct_predictions is: {}'.format(correct_predictions))
    logging.critical('The accuracy is: {}'.format(correct_predictions / float(len(y_test))))
output:
d:\Django\multi-class-text-classification-cnn>python predict.py ./trained_model_1485334811/ ./data/small_samples_my.json
CRITICAL:root:Loaded the trained model: d:\Django\multi-class-text-classification-cnn\trained_model_1485334811\checkpoints\model-300
INFO:root:The number of x_test: 5
INFO:root:The number of y_test: 5
CRITICAL:root:The batch_predictions is: []
CRITICAL:root:The all_predictions is: [ 10. 10. 10. 8. 10.]
CRITICAL:root:The y_test is: [10 6 10 8 9]
CRITICAL:root:The correct_predictions is: 3
CRITICAL:root:The accuracy is: 0.6
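To see where 3 and 0.6 come from, the comparison can be replayed in NumPy with the values printed in the log above. Comparing the float predictions with the integer labels is harmless here, since `==` treats 10.0 and 10 as equal.

```python
import numpy as np

# Values copied from the log output above
all_predictions = np.array([10.0, 10.0, 10.0, 8.0, 10.0])  # float64, as printed
y_test = np.array([10, 6, 10, 8, 9])  # true label indices after np.argmax

# Elementwise comparison: [True, False, True, True, False]
matches = all_predictions == y_test
correct_predictions = int(matches.sum())  # 3
accuracy = correct_predictions / float(len(y_test))  # 3 / 5 = 0.6
```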
Actually, each batch produces its own batch_predictions, which is appended to all_predictions. Your logging call runs after the loop, so it only shows the last batch's predictions, which in your run were empty.
Eventually, if you have 100 test examples, all_predictions will contain 100 numbers. Each number is the corresponding index in labels.json, so you can get the actual label with labels.json[index].
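A self-contained sketch of that accumulation, with made-up batch outputs standing in for `sess.run(...)`. It also shows why the log prints `[ 10. 10. ...]` with decimal points: concatenating onto an empty list `[]` yields a float64 array, so the indices need an integer cast before they are used to look up labels. The batch values and label names here are hypothetical, and the empty starting list is an assumption that matches the float output in the log.

```python
import numpy as np

# Hypothetical per-batch outputs standing in for sess.run(predictions, ...)
batches_of_predictions = [
    np.array([6, 4, 3], dtype=np.int64),
    np.array([1, 6], dtype=np.int64),
]

all_predictions = []  # empty starting list (assumed to match predict.py)
for batch_predictions in batches_of_predictions:
    # Append this batch's indices; the accumulated array becomes float64
    all_predictions = np.concatenate([all_predictions, batch_predictions])

# Cast back to integers so they can index the label list
label_indices = all_predictions.astype(int)

# Hypothetical label list; the real one comes from labels.json
labels = ["label_0", "label_1", "label_2", "label_3", "label_4", "label_5", "mortgage"]
predicted_labels = [labels[i] for i in label_indices]
```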
Thank you! I will create another issue for my other question.
Has anyone figured this out? I need a score for each class the model predicts. For example, if a text is assigned to a single class, I need to know the probability that it belongs to that class. Any help would keep me moving.
@akki2825 Were you able to find a way to predict the probability for the classified text?
@GraphGrailAi, did you mean that you got the accuracy for each predicted class, or the accuracy of the whole model?
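If per-class numbers are what is wanted, one option is to group the comparison by true label. A minimal NumPy sketch using the five predictions from the log earlier in this thread; for each true class it computes the fraction of that class's examples predicted correctly (per-class recall). The function name `per_class_accuracy` is hypothetical.

```python
import numpy as np

# Values from the log output earlier in this thread
all_predictions = np.array([10.0, 10.0, 10.0, 8.0, 10.0]).astype(int)
y_test = np.array([10, 6, 10, 8, 9])


def per_class_accuracy(y_true, y_pred):
    """Return {class_index: fraction of that class's examples predicted correctly}."""
    result = {}
    for cls in np.unique(y_true):
        mask = y_true == cls  # select examples whose true label is cls
        result[int(cls)] = float(np.mean(y_pred[mask] == cls))
    return result


class_acc = per_class_accuracy(y_test, all_predictions)
# {6: 0.0, 8: 1.0, 9: 0.0, 10: 1.0}
```

Classes that never appear in `y_test` are simply absent from the result, since there is nothing to measure for them.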
Has anyone found a way to print the probability of each sentence's prediction?
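One common approach, sketched here under assumptions: if the model exposes the raw per-class scores (logits) that the argmax in `predictions` is taken over, a softmax turns each row into probabilities. Which tensor holds those scores in the saved graph (for example, something named like `output/scores`) is an assumption to verify against the model definition; the NumPy part below works on any `(num_examples, num_classes)` score matrix, and the logits in it are made up.

```python
import numpy as np


def softmax(scores):
    """Row-wise softmax over a (num_examples, num_classes) array of logits."""
    # Subtract each row's max for numerical stability; the result is unchanged
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def top_class_with_probability(scores):
    """Return (predicted class indices, probability of each predicted class)."""
    probs = softmax(scores)
    predicted = probs.argmax(axis=1)
    return predicted, probs[np.arange(len(predicted)), predicted]


# Hypothetical logits for 2 texts and 3 classes
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 0.5, 3.0]])
predicted, confidence = top_class_with_probability(logits)
```

Because softmax is monotonic, its argmax is the same as the argmax of the raw scores, so the predicted class does not change; the softmax only adds a probability. `softmax(logits)` gives the full probability vector for every class if you need more than the top one.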
Could you provide some example code showing how to get the class output for a given text input?
I was able to get all the code working with ./data/small_samples.json, but the output is an accuracy percentage; I need the exact class name for every text.
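A minimal sketch of producing one class name per input text, assuming you have the input texts in the same order as `x_test` and the label list loaded from labels.json. The function name `label_each_text` and the output fields `text` and `predicted_label` are hypothetical; adapt them to however you want to save the results.

```python
import json


def label_each_text(texts, predictions, labels):
    """Pair every input text with its predicted class name.

    texts:       input strings, in the same order they were fed to the model
    predictions: predicted label indices (ints, or floats such as 10.0)
    labels:      the list loaded from labels.json
    """
    if len(texts) != len(predictions):
        raise ValueError("texts and predictions must have the same length")
    return [
        {"text": text, "predicted_label": labels[int(idx)]}
        for text, idx in zip(texts, predictions)
    ]


# Hypothetical example data
rows = label_each_text(["first text", "second text"], [1.0, 0.0], ["class_a", "class_b"])
print(json.dumps(rows, indent=2))
```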