jwcalder / GraphLearning

Python package for graph-based clustering and semi-supervised learning
MIT License
85 stars, 26 forks

`fit` method in `ssl` class #9

Closed: gmeng92 closed this issue 6 months ago

gmeng92 commented 6 months ago

Hi,

I am using the ssl model to predict labels for unlabeled data. The `fit_predict` method generally works well in my case, and I am trying to see if I can get a score for each class. I noticed that the `ssl.fit()` method is documented to return an $(n,k)$ numpy array of probabilities. However, I keep getting both positive and negative values in a binary classification problem, and the same happens in a multiclass problem. So perhaps these are not true probabilities? I am curious how to interpret the output u of the `fit` method.
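A quick way to test whether a returned score matrix `u` actually holds probabilities is to check that every entry is nonnegative and every row sums to 1. The sketch below uses only NumPy. The helper name `looks_like_probabilities` and the toy matrices are hypothetical and are not part of GraphLearning.

```python
import numpy as np

def looks_like_probabilities(u, atol=1e-6):
    """Return True if every row of u is nonnegative and sums to 1 (within atol)."""
    u = np.asarray(u, dtype=float)
    nonnegative = bool(np.all(u >= -atol))
    rows_sum_to_one = bool(np.allclose(u.sum(axis=1), 1.0, atol=atol))
    return nonnegative and rows_sum_to_one

# Made-up (n, k) score matrices for n=3 nodes and k=2 classes.
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])
signed = np.array([[0.4, -0.4], [-1.2, 1.2], [0.05, -0.05]])

print(looks_like_probabilities(probs))   # True
print(looks_like_probabilities(signed))  # False
```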

Thanks,

Gemeng

jwcalder commented 6 months ago

Hi Gemeng,

Which SSL method in the package are you using? What gets returned depends on the method.

I noticed the documentation is not quite correct here. The (n,k) array contains a length-k score vector for each node. For some methods, like Laplace learning, these scores are in fact probabilities, but for others, such as Poisson learning, they are not. Usually the predicted label is the argmax along each row of u, but for some methods it is the argmin: distance-based classifiers, like nearest neighbors, where the scores are the graph distances to each class.
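The argmax/argmin rule can be sketched with plain NumPy. This is a minimal illustration, assuming `u` is the (n,k) array returned by `fit`. The helper `labels_from_scores` and its `distance_based` flag are hypothetical and not part of the GraphLearning API, and all score values are made up.

```python
import numpy as np

def labels_from_scores(u, distance_based=False):
    """Turn an (n, k) score matrix into n predicted class indices (hypothetical helper).

    distance_based=False: higher score = more likely class, so take the row-wise argmax.
    distance_based=True: scores are graph distances to each class, so take the argmin.
    """
    u = np.asarray(u)
    return np.argmin(u, axis=1) if distance_based else np.argmax(u, axis=1)

# Signed, Poisson-style scores for 3 nodes and 3 classes (made-up values).
u = np.array([[0.2, 1.5, -0.3],
              [2.0, -1.0, 0.1],
              [-0.5, -0.2, -0.9]])
print(labels_from_scores(u))  # [1 0 1]

# Distance-style scores for 2 nodes and 3 classes (made-up values): smallest wins.
d = np.array([[3.0, 1.0, 2.5],
              [0.5, 2.0, 4.0]])
print(labels_from_scores(d, distance_based=True))  # [1 0]
```

Negative entries in `u` are not a problem for this rule: only the ordering within each row matters.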

I've updated the documentation to reflect this. Let me know if this makes sense or if you have any other questions.

PS: You can always run u through a row-wise softmax to convert the scores to probabilities.
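A row-wise softmax can be written in a few lines of NumPy. This is a sketch under stated assumptions: `row_softmax` is a hypothetical helper, not a GraphLearning function, and negating distance-based scores before the softmax (so that the nearest class gets the highest probability) is an assumption, not something stated in this thread.

```python
import numpy as np

def row_softmax(u, distance_based=False):
    """Row-wise softmax of an (n, k) score matrix (hypothetical helper).

    If distance_based=True, scores are negated first so that smaller distances
    map to larger probabilities (an assumption for distance-based methods).
    """
    z = np.asarray(u, dtype=float)
    if distance_based:
        z = -z
    z = z - z.max(axis=1, keepdims=True)  # subtract each row's max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Signed, Poisson-style scores for 2 nodes and 2 classes (made-up values).
u = np.array([[0.3, -0.3],
              [-1.0, 1.0]])
p = row_softmax(u)
print(p.argmax(axis=1))  # [0 1], the same labels as u.argmax(axis=1)

# Distance-style scores: the class at distance 1.0 gets the higher probability.
print(row_softmax(np.array([[3.0, 1.0]]), distance_based=True).argmax(axis=1))  # [1]
```

The rows of the result sum to 1 and the row-wise argmax is unchanged, but the values are not necessarily calibrated probabilities: multiplying `u` by a positive constant makes the softmax sharper or flatter without changing the predicted labels.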

Best, Jeff

gmeng92 commented 6 months ago

Thanks for the answer, Jeff! I am using Poisson learning, and the scores returned by the `fit` method are a mix of positive and negative values. Thanks again for clarifying how the scores differ across SSL methods. I will close the issue then.