sergioburdisso / pyss3

A Python package implementing a new interpretable machine learning model for text classification (with visualization tools for Explainable AI :octocat:)
https://pyss3.readthedocs.io
MIT License

Change of category name #8

Closed angrymeir closed 4 years ago

angrymeir commented 4 years ago

Description

The category names are changed during training, which causes a mismatch between the predicted category names and the true category names.

Example

from pyss3 import SS3

text = ["Document 1", "Document 2"]
groundtruth = ["Label 1", "Label 2"]

clf = SS3()
clf.fit(text, groundtruth)

y_pred = clf.predict(text)
print(y_pred)  # ["label 1", "label 2"]

Explanation

During training, the category names are modified by .lower() here. When .predict() is called, the modified labels are returned here.

Why is this a problem

When calling .predict() with parameter labels=True (the default setting), the predicted category names have to be postprocessed for a direct comparison to the true category names.
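Until this is fixed, the post-processing mentioned above could look roughly like the sketch below. It is only a workaround idea, not part of PySS3: restore_label_case is a hypothetical helper name, and the sketch assumes that no two true labels differ only by case (otherwise the original spelling cannot be recovered).

```python
def restore_label_case(predicted, known_labels):
    """Map lowercased predicted labels back to their original spelling.

    Hypothetical helper, not part of PySS3.
    """
    # Build a lookup from the lowercased form to the original label.
    lookup = {}
    for label in known_labels:
        key = label.lower()
        if key in lookup and lookup[key] != label:
            # Two distinct labels collapse to the same lowercased name.
            raise ValueError(f"labels {lookup[key]!r} and {label!r} differ only by case")
        lookup[key] = label
    # Predictions with no known original form are passed through unchanged.
    return [lookup.get(p, p) for p in predicted]


groundtruth = ["Label 1", "Label 2"]
y_pred = ["label 1", "label 2"]  # what clf.predict() returned in the example

print(restore_label_case(y_pred, groundtruth))  # ['Label 1', 'Label 2']
```

The collision check also points at a second side effect of lowercasing: categories such as "US" and "us" would presumably end up sharing one name after training.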

Fix

Remove .lower() :) However, I'm not entirely sure about the consequences for the rest of the project.

sergioburdisso commented 4 years ago

Hi @angrymeir

Yes, you're totally right. Actually, I don't know why this "lower()ing" thing was added in the first place. I think it was added when this project was a prototype and was mostly used through the PySS3 Command Line Tool, to make things easier for the user when typing category names.

But now it makes no sense to automatically convert category names to lower case; that should be the user's decision, not pyss3's. Once I finish working on Issue #5, I'll remove the lower() as you suggest and make sure that it does not negatively affect other parts of the library before releasing the new version (0.6.0), which will fully support multilabel classification. Speaking of which, I've just finished adding multilabel support to Evaluation.test() (0a897dd). It now supports the Hamming loss metric along with all the previous ones, and it also plots a binary confusion matrix for each possible label.
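For readers unfamiliar with the Hamming loss in the multilabel setting, here is a minimal sketch of the standard definition: the fraction of (document, label) slots that are predicted incorrectly. It does not show how Evaluation.test() computes it internally; the hamming_loss function and the example categories are made up for illustration, and every label is assumed to appear in categories.

```python
def hamming_loss(y_true, y_pred, categories):
    """Fraction of (document, label) slots that are predicted incorrectly."""
    total_slots = len(y_true) * len(categories)
    wrong = 0
    for true_labels, pred_labels in zip(y_true, y_pred):
        # A slot is wrong when a label is in exactly one of the two sets.
        wrong += len(set(true_labels) ^ set(pred_labels))
    return wrong / total_slots


categories = ["sports", "politics", "science"]
y_true = [["sports"], ["politics", "science"]]
y_pred = [["sports", "science"], ["politics"]]

# One extra label in the first document and one missing label in the
# second: 2 wrong slots out of 2 documents * 3 categories = 6 slots.
print(hamming_loss(y_true, y_pred, categories))
```

Lower is better, and 0.0 means every label set was predicted exactly.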