uoy-research / CellPhePy

Python port of the CellPhe R package for extracting features from time-lapse cell images
MIT License

Add Multi-class classification #93

Open stulacy opened 1 week ago

stulacy commented 1 week ago

Currently classify_cells only works for a binary classification problem. It would be useful to make this more generic to handle an arbitrary multi-class problem. Following the discussion in #90, the idea would be to predict probabilities for each class rather than the current hard-label method. This would have two advantages: firstly, it makes combining the ensemble members' predictions easier; secondly, it allows for more granularity, i.e. a predicted class is only returned if it reaches a certain probability threshold.

stulacy commented 1 day ago

I've just realised that what I said about classify_cells only working for binary data is wrong. classify_cells does indeed handle a multi-class dataset.

It currently uses the hard voting method (see the sklearn docs). This means that each of our 3 models predicts a hard label, with the simple majority vote winning.
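
For reference, a minimal sketch of what hard voting looks like in sklearn. The three estimators here (random forest, SVM, logistic regression) are placeholders for illustration, not necessarily CellPhePy's actual ensemble members:

```python
# Sketch of hard voting, not CellPhePy's actual code; the three
# ensemble members below are placeholders.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

X, y = make_classification(n_samples=200, n_classes=3, n_informative=5, random_state=0)

ensemble = VotingClassifier(
    estimators=[
        ("rf", RandomForestClassifier(random_state=0)),
        ("svm", SVC(random_state=0)),
        ("lr", LogisticRegression(max_iter=1000)),
    ],
    voting="hard",  # each model casts one vote; the majority label wins
)
ensemble.fit(X, y)
labels = ensemble.predict(X[:5])  # hard labels only, no probabilities
```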

If we were to use a soft voting method then by default the predicted class is the one with the highest average probability. We could also extract the average predicted probabilities for each class and do the thresholding we mentioned in #90. However, this thresholding could produce unexpected results for the user: if they ask for predictions on 5 cells and get back 3 labels and 2 NaNs, that's not exactly useful.

Instead, if we're going with a more granular, probability-based approach, I think it would be more useful to return the full predicted probability matrix of shape N_rows x N_classes. Then the user can do the thresholding themselves, or just take the highest row-wise probability.
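
To make this concrete, a hedged sketch of both behaviours using the same placeholder ensemble as above (again, not CellPhePy's actual code):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

X, y = make_classification(n_samples=200, n_classes=3, n_informative=5, random_state=0)

ensemble = VotingClassifier(
    estimators=[
        ("rf", RandomForestClassifier(random_state=0)),
        ("svm", SVC(probability=True, random_state=0)),  # needed for predict_proba
        ("lr", LogisticRegression(max_iter=1000)),
    ],
    voting="soft",  # average the members' predicted probabilities
)
ensemble.fit(X, y)

proba = ensemble.predict_proba(X[:5])  # full matrix, shape (5, n_classes)

# Highest row-wise average probability (the default soft-voting label)
labels = ensemble.classes_[proba.argmax(axis=1)]

# Or return `proba` itself and let the user threshold it,
# e.g. flagging anything below 0.8 with a -1 sentinel
confident = np.where(proba.max(axis=1) >= 0.8, labels, -1)
```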

The only other consideration is that to use soft voting methods you ideally want well-calibrated models. Given that random forests and SVMs aren't inherently probabilistic models, it's not clear where their predicted probabilities come from, and thus how much faith we can have in them.
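
For what it's worth, sklearn derives random forest probabilities from the averaged per-tree class fractions, and SVC probabilities from Platt scaling (fitted internally when probability=True); neither is guaranteed to be well calibrated. If this becomes a concern, one possible mitigation (purely a sketch, not something CellPhePy currently does) is to wrap a member in CalibratedClassifierCV:

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.svm import SVC

X, y = make_classification(n_samples=200, n_classes=3, n_informative=5, random_state=0)

# Fit a sigmoid (Platt) mapping on held-out folds so that predict_proba
# is better calibrated; method="isotonic" is an alternative with more data.
calibrated_svm = CalibratedClassifierCV(SVC(), method="sigmoid", cv=5)
calibrated_svm.fit(X, y)
proba = calibrated_svm.predict_proba(X[:5])
```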

In my mind, then, the 3 options are:

  1. Use the majority vote from hard voting as we're currently doing
  2. Use the highest average probability from soft voting
  3. Return the average predicted probabilities for all classes

llwiggins commented 1 day ago

Hi @stulacy, thank you for taking the time to consider our options!

I think it would be best to go with option 2 and use the highest average probability from soft voting. That way the user will still receive a classification as expected, but can filter these out based on their own requirements. We'll have to add this to the documentation so that users are fully aware of the output and understand that in some cases, although they receive a final classification, it may not be trustworthy because its probability is still very low!
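
To illustrate the kind of user-side filtering this implies, a hypothetical sketch, assuming classify_cells were extended to return each cell's label alongside its winning probability (the column names here are made up):

```python
import pandas as pd

# Hypothetical output format: one row per cell, with the soft-voting
# label and its winning (highest average) probability.
predictions = pd.DataFrame({
    "CellID": [1, 2, 3, 4, 5],
    "Label": ["A", "B", "A", "C", "B"],
    "Probability": [0.91, 0.55, 0.88, 0.40, 0.72],
})

# Users then apply their own trust threshold after the fact
trusted = predictions[predictions["Probability"] >= 0.7]
```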