EpistasisLab / Aliro

Aliro: AI-Driven Data Science
https://epistasislab.github.io/Aliro
GNU General Public License v3.0

Handling NaN Values in TPR Calculation for Highly Unbalanced Datasets #643

Open HyunjunA opened 7 months ago

HyunjunA commented 7 months ago

The machine learning backend breaks down when handling datasets with extreme class imbalance. I am currently addressing all issues related to this problem in the machine learning backend.

HyunjunA commented 7 months ago

The current ML backend defaults to 10 cross-validation folds (cv = 10), which can cause problems for certain datasets. Specifically, with the CSICU team's dataset, cv = 10, and any form of stratified cross-validation, this issue remains unresolved. As a temporary workaround, I've used 'nanmean' to calculate the mean accuracy. A permanent solution requires logic that sets cv appropriately for edge cases.
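To illustrate how a per-fold metric can come out as NaN and what the 'nanmean' workaround does, here is a minimal stdlib-only sketch. The helper names `fold_tpr` and `nan_mean` and the fold data are hypothetical and are not Aliro code; `nan_mean` only mimics what NumPy's `nanmean` does for a flat list of scores.

```python
import math

def fold_tpr(y_true, y_pred, positive=1):
    """True positive rate for one fold; NaN when the fold has no positives."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    if tp + fn == 0:
        return math.nan  # TPR = TP / (TP + FN) is undefined without positives
    return tp / (tp + fn)

def nan_mean(values):
    """Mean of the non-NaN entries (hypothetical stand-in for numpy.nanmean)."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else math.nan

# Hypothetical folds as (y_true, y_pred); the second fold has no positives.
folds = [
    ([1, 0, 0, 0], [1, 0, 0, 0]),
    ([0, 0, 0, 0], [0, 0, 0, 0]),
    ([1, 0, 0, 0], [0, 0, 0, 0]),
]
scores = [fold_tpr(y_true, y_pred) for y_true, y_pred in folds]

plain_mean = sum(scores) / len(scores)  # a single NaN makes the whole mean NaN
robust_mean = nan_mean(scores)          # averages only the folds with a defined score
```

Note that this kind of averaging silently drops the undefined folds, so the reported mean rests on fewer folds than cv suggests, which is one reason it is only a temporary fix.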

HyunjunA commented 6 months ago

The following code adjusts the number of splits based on the smallest class size in a given dataset. In the machine learning backend, stratified cross-validation can place at least one instance of every class in each fold only when the number of splits does not exceed the size of the smallest class, so the code caps the number of splits at that size (with a floor of 2). This keeps all classes represented across folds.

However, for the highly imbalanced and small dataset provided by Debbie, we have opted for a temporary solution. It uses the nanmean function so that NaN values from individual folds do not propagate into the averaged score.

We will discuss integrating the code below into Aliro, as well as better solutions, in the future.

def decision_rule_fold_cv_based_on_classes(each_class):
    """
    Adjusts the number of cross-validation folds based on the class distribution.

    Parameters
    ----------
    each_class : dict
        A dictionary where keys are the classes and the values are the number of samples per class.

    Returns
    -------
    n_folds : int
        The number of cross-validation folds, between 2 and 10. When the smallest class has at least 2 samples, each fold can include instances of every class.
    """
    # Find the minimum class count to ensure every fold can contain at least one instance of every class.
    min_class_count = min(each_class.values())

    # Cap the number of folds at the smallest class count so every class can be
    # represented in each fold, with a default maximum of 10 folds.
    n_folds = min(10, min_class_count)

    # Ensure at least 2 folds for meaningful cross-validation. Note that if the
    # smallest class has only one sample, 2 folds cannot both contain that class.
    n_folds = max(n_folds, 2)

    return n_folds
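As a rough usage sketch, the class counts can be built from a label vector with `collections.Counter` and passed to the rule above. The label data here is hypothetical, and the function is restated in compact form only so the snippet runs on its own; how Aliro would actually wire this in is still open.

```python
from collections import Counter

def decision_rule_fold_cv_based_on_classes(each_class):
    # Same rule as above: cap at 10 folds and at the smallest class count, floor at 2.
    return max(min(10, min(each_class.values())), 2)

# Hypothetical imbalanced labels: 40 negatives and 4 positives.
y = [0] * 40 + [1] * 4
each_class = dict(Counter(y))  # maps each class label to its sample count

n_folds = decision_rule_fold_cv_based_on_classes(each_class)

# n_folds could then be used as the number of splits of a stratified splitter,
# for example scikit-learn's StratifiedKFold(n_splits=n_folds).
```

With these labels the smallest class has 4 samples, so the rule picks 4 folds instead of the default 10. A class with a single sample still yields 2 folds, in which case one fold necessarily lacks that class.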