jiegzhan / multi-class-text-classification-cnn

Classify Kaggle Consumer Finance Complaints into 11 classes. Build the model with CNN (Convolutional Neural Network) and Word Embeddings in TensorFlow.
Apache License 2.0

I wish to get the raw probability distribution for the predicted classes; facing problems using softmax #26

Open ctsaurabhs opened 6 years ago

ctsaurabhs commented 6 years ago

Please help me: how can I use softmax in place of argmax() to get the raw probability distribution for the predicted classes?
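For reference, a short worked statement of how the two relate, assuming the model's output layer produces unnormalized class scores (logits) $s_1, \dots, s_K$ (here $K = 11$) and that argmax() is applied to those scores. Softmax normalizes them into probabilities, and because the exponential is strictly increasing, the most probable class is the same one argmax() picks:

$$
p_k = \frac{e^{s_k}}{\sum_{j=1}^{K} e^{s_j}}, \qquad \sum_{k=1}^{K} p_k = 1, \qquad \arg\max_k \, p_k = \arg\max_k \, s_k
$$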

changukshin commented 6 years ago

In predict.py, the predictions operation is assigned at line 58 and passed to the TensorFlow session to be run at line 63. If you want the probability distribution over the classes, pass the additional operation you want evaluated to the same session call. Please look up the operation in text_cnn.py that produces the class scores and fetch it together with predictions.

For example:

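import numpy as np

# graph, sess, input_x and dropout_keep_prob come from predict.py
predictions = graph.get_operation_by_name("output/predictions").outputs[0]
scores = graph.get_operation_by_name("output/scores").outputs[0]

batches = ...
all_predictions = []
all_scores = []
for x_test_batch in batches:
    # Fetch both tensors in a single sess.run call
    batch_predictions, batch_scores = sess.run([predictions, scores], {input_x: x_test_batch, dropout_keep_prob: 1.0})
    all_predictions.append(batch_predictions)
    all_scores.append(batch_scores)

# Stack per-batch results: predictions are (num_examples,), scores are (num_examples, num_classes)
all_predictions = np.concatenate(all_predictions)
all_scores = np.concatenate(all_scores)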

Note that the above code is not tested. Please take a look at lines 58 to 64 of predict.py.
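To turn the fetched scores into probabilities, one option is a row-wise softmax in NumPy. This is a minimal sketch assuming that output/scores holds unnormalized logits of shape (num_examples, num_classes) and that output/predictions is the argmax over them; the `softmax` helper and the small `all_scores` array below are hypothetical stand-ins, not part of the repository.

```python
import numpy as np


def softmax(scores):
    """Row-wise softmax: map unnormalized class scores (logits) to probabilities.

    Subtracting each row's maximum before exponentiating avoids overflow and
    leaves the result unchanged, because softmax is invariant to shifting a row.
    """
    scores = np.asarray(scores, dtype=np.float64)
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


# Hypothetical stand-in for all_scores collected from the session:
# 2 examples, 3 classes (the real model has 11 classes).
all_scores = np.array([[2.0, 1.0, 0.1],
                       [-1.0, 3.0, 0.5]])

all_probabilities = softmax(all_scores)

# Each row is a probability distribution over the classes.
print(all_probabilities.sum(axis=1))
# The most probable class per row equals argmax over the raw scores,
# i.e. the same label that output/predictions returns.
print(np.argmax(all_probabilities, axis=1))
```

Doing the softmax in NumPy after the session run leaves the saved graph untouched. Alternatively, a `tf.nn.softmax(scores)` op could be added to the restored graph and fetched in the same `sess.run` call; that variant is not shown here.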