gregversteeg / corex_topic

Hierarchical unsupervised and semi-supervised topic models for sparse count data with CorEx
Apache License 2.0
model_ct.predict_proba() explanation #44

Closed sachindevsharma closed 4 years ago

sachindevsharma commented 4 years ago

Hi Greg,

I was trying corextopic for supervised topic modeling (more precisely, classification) and was using model.predict_proba(). It gives me output similar to (array([[0.999, 0.0022]]), array([[0.198, -0.205]])). Could you please explain what these values are? That would be a great help.

Thanks in advance.

zafercavdar commented 4 years ago

Hi Sachin,

According to the documentation, the predict_proba function calls the transform function with details=True. When details are enabled, transform returns a tuple of p_y_given_x (probabilities) and log_z (point-wise estimates of total correlation). Your output tells us that your input document is strongly related to the first topic (index 0).
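To make the tuple concrete, here is a minimal sketch that unpacks the two arrays using the values quoted in the question. It only needs NumPy. The array literals stand in for a real `model.predict_proba(X)` call, and the name `top_topic` is my own, so treat this as an illustration rather than the library's code.

```python
import numpy as np

# Stand-in for `model.predict_proba(X)`: the two arrays quoted in the question
# (one document, two topics). No fitted model is needed for this sketch.
output = (np.array([[0.999, 0.0022]]), np.array([[0.198, -0.205]]))

# First element: one row per document, one column per topic.
# Second element: the log_z values, with the same shape.
p_y_given_x, log_z = output

# Index of the topic with the highest value for each document.
top_topic = p_y_given_x.argmax(axis=1)

print(p_y_given_x.shape)  # (1, 2)
print(top_topic)          # [0]
```

With a fitted model you would replace the `output = ...` line with the actual `predict_proba` call; the unpacking stays the same.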

ryanjgallagher commented 4 years ago

Thank you for the explanation @zafercavdar! Yes, it returns a matrix where the rows are documents and the columns are probabilities of topics for those documents. What p_y_given_x refers to is the probability of a topic given the words in a document (Y = topic, X = document's words).

I'll add that the code itself is a little unclear; I believe it actually returns log_p_y_given_x, which is why there are negative values. If you did np.exp(p_y_given_x), then you should get the raw probabilities, though those may underflow depending on the size of your data.

aditya-malte commented 3 years ago

Hi, so at the moment is it returning the raw probabilities or a logarithm? I ask because all the values I'm getting are positive and at most 1.

ryanjgallagher commented 3 years ago

Ah, I looked too quickly at the code. It returns the raw probabilities P(Y | X), not the log. The negative values in the original question come from the second output, log_z, the pointwise total correlations. I read those too quickly as well and mistook them for log probabilities.

Thanks for catching this, and apologies for the confusion!
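If you want to check which of the two outputs you are looking at, a quick range test helps. This is a rough sketch using the numbers from the original question; the helper `in_unit_interval` is a hypothetical name and not part of corextopic. Passing the check does not prove an array holds probabilities, but failing it rules that out.

```python
import numpy as np

def in_unit_interval(values):
    """Return True if every entry lies in [0, 1], as a probability must."""
    values = np.asarray(values)
    return bool(np.all((values >= 0.0) & (values <= 1.0)))

# Values quoted in the original question.
p_y_given_x = np.array([[0.999, 0.0022]])
log_z = np.array([[0.198, -0.205]])

print(in_unit_interval(p_y_given_x))  # True
print(in_unit_interval(log_z))        # False
```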