HealthCatalyst / healthcareai-py

Python tools for healthcare machine learning
http://healthcare.ai
MIT License

Imbalanced classes are not always represented in the training set #437

Open Aylr opened 7 years ago

Aylr commented 7 years ago

Background

The train/test split is a stochastic process, so an underrepresented class can be left out of the training set entirely and therefore never be modeled!
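To put a rough number on this, here is a minimal sketch that assumes a uniformly random, non-stratified 80/20 split (the 200-row confusion matrices below suggest a 200-row test set out of 1000 rows, but the exact splitting logic inside healthcareai is an assumption here). The helper `prob_class_absent` is hypothetical and simply evaluates the hypergeometric probability that none of the 6 rare rows land in a given subset.

```python
from math import comb


def prob_class_absent(n_rows, n_rare, n_subset):
    """Probability that a uniformly random subset of n_subset rows
    (drawn without replacement from n_rows) contains none of the
    n_rare rows belonging to the rare class."""
    return comb(n_rows - n_rare, n_subset) / comb(n_rows, n_rows - (n_rows - n_subset))


# Assumed setup: 1000 rows, 6 rare rows, 800 training rows, 200 test rows.
p_absent_from_train = prob_class_absent(1000, 6, 800)
p_absent_from_test = prob_class_absent(1000, 6, 200)

print(f"P(rare class missing from training set) = {p_absent_from_train:.6f}")
print(f"P(rare class missing from test set)     = {p_absent_from_test:.3f}")
```

Under these assumptions the rare class is far more likely to be missing from the smaller test set than from the training set, but either outcome is possible whenever the split is not stratified.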

Steps to reproduce

You may need to run this a few times to see a confusion matrix with only 2 classes represented.

import healthcareai

def main():
    """Template script for using healthcareai to train a classification lr."""
    # Load the included diabetes sample data
    dataframe = healthcareai.load_diabetes()

    # Relabel the target, then inject a rare third class into rows 0-5 (six rows)
    dataframe['ThirtyDayReadmitFLG'] = dataframe['ThirtyDayReadmitFLG'].replace(
        {'Y': 'SnoCones', 'N': 'Waffles'})
    dataframe.loc[0:5, 'ThirtyDayReadmitFLG'] = "Omelette"
    print(dataframe['ThirtyDayReadmitFLG'].value_counts())

    dataframe.drop(['PatientID'], axis=1, inplace=True)

    classification_trainer = healthcareai.SupervisedModelTrainer(
        dataframe=dataframe,
        predicted_column='ThirtyDayReadmitFLG',
        model_type='classification',
        grain_column='PatientEncounterID',
        impute=True,
        verbose=False)

    print(dataframe.head(5))

    lr = classification_trainer.logistic_regression()
    lr.print_confusion_matrix()

    knn = classification_trainer.knn()
    knn.print_confusion_matrix()

if __name__ == "__main__":
    main()

Bad Output

Waffles     840
SnoCones    154
Omelette      6
Name: ThirtyDayReadmitFLG, dtype: int64

Note: Numeric imputation will always occur when making predictions on new data - otherwise rows would be dropped, which would lead to missing predictions.

Imputed values for numeric columns:
╒═══════════════╤═══════════════════╕
│ Column        │   Percent Imputed │
╞═══════════════╪═══════════════════╡
│ SystolicBPNBR │             0.013 │
├───────────────┼───────────────────┤
│ LDLNBR        │             0.013 │
├───────────────┼───────────────────┤
│ A1CNBR        │             0.013 │
╘═══════════════╧═══════════════════╛

   PatientEncounterID  SystolicBPNBR  LDLNBR  A1CNBR GenderFLG  \
0                   1          167.0   195.0     4.2         M   
1                   2          153.0   214.0     5.0         M   
2                   3          170.0   191.0     4.0         M   
3                   4          187.0   135.0     4.4         M   
4                   5          188.0   125.0     4.3         M   

  ThirtyDayReadmitFLG  
0            Omelette  
1            Omelette  
2            Omelette  
3            Omelette  
4            Omelette  
Training: Logistic Regression , Type: classification
LogisticRegression Training Results:
- Training time:
    LogisticRegression seconds
- Best hyperparameters found were:
    N/A: No hyperparameter search was performed
- LogisticRegression selected performance metrics:
    accuracy: 0.85
    positive_label: Waffles
    roc_auc: 0.38
    pr_auc: 0.78

Confusion Matrix (Counts)
    - Predicted Classes are along the top
    - True Classes are along the left.

            SnoCones    Waffles
--------  ----------  ---------
SnoCones           0         30
Waffles            0        170
Training: Knn , Type: classification
KNN Grid: {'n_neighbors': [5, 8, 11, 14, 17, 20, 23], 'weights': ['uniform', 'distance']}
KNeighborsClassifier Training Results:
- Training time:
    KNeighborsClassifier seconds
- Best hyperparameters found were:
    {'weights': 'distance', 'n_neighbors': 23}
- KNeighborsClassifier selected performance metrics:
    accuracy: 0.85
    positive_label: Waffles
    roc_auc: 0.23
    pr_auc: 0.74

Confusion Matrix (Counts)
    - Predicted Classes are along the top
    - True Classes are along the left.

            SnoCones    Waffles
--------  ----------  ---------
SnoCones           1         29
Waffles            0        170

Process finished with exit code 0

Good Output

Waffles     840
SnoCones    154
Omelette      6
Name: ThirtyDayReadmitFLG, dtype: int64

Note: Numeric imputation will always occur when making predictions on new data - otherwise rows would be dropped, which would lead to missing predictions.

Imputed values for numeric columns:
╒═══════════════╤═══════════════════╕
│ Column        │   Percent Imputed │
╞═══════════════╪═══════════════════╡
│ SystolicBPNBR │             0.013 │
├───────────────┼───────────────────┤
│ LDLNBR        │             0.013 │
├───────────────┼───────────────────┤
│ A1CNBR        │             0.013 │
╘═══════════════╧═══════════════════╛

   PatientEncounterID  SystolicBPNBR  LDLNBR  A1CNBR GenderFLG  \
0                   1          167.0   195.0     4.2         M   
1                   2          153.0   214.0     5.0         M   
2                   3          170.0   191.0     4.0         M   
3                   4          187.0   135.0     4.4         M   
4                   5          188.0   125.0     4.3         M   

  ThirtyDayReadmitFLG  
0            Omelette  
1            Omelette  
2            Omelette  
3            Omelette  
4            Omelette  
Training: Logistic Regression , Type: classification
LogisticRegression Training Results:
- Training time:
    LogisticRegression seconds
- Best hyperparameters found were:
    N/A: No hyperparameter search was performed
- LogisticRegression selected performance metrics:
    accuracy: 0.84

Confusion Matrix (Counts)
    - Predicted Classes are along the top
    - True Classes are along the left.

            Omelette    SnoCones    Waffles
--------  ----------  ----------  ---------
Omelette           0           0          2
SnoCones           0           0         29
Waffles            0           0        169
Training: Knn , Type: classification
KNN Grid: {'n_neighbors': [5, 8, 11, 14, 17, 20, 23], 'weights': ['uniform', 'distance']}
KNeighborsClassifier Training Results:
- Training time:
    KNeighborsClassifier seconds
- Best hyperparameters found were:
    {'weights': 'distance', 'n_neighbors': 17}
- KNeighborsClassifier selected performance metrics:
    accuracy: 0.84

Confusion Matrix (Counts)
    - Predicted Classes are along the top
    - True Classes are along the left.

            Omelette    SnoCones    Waffles
--------  ----------  ----------  ---------
Omelette           0           0          2
SnoCones           0           2         27
Waffles            0           3        166

Process finished with exit code 0
Aylr commented 7 years ago

Maybe the MVP fix is to use scikit-learn's stratified train/test split?
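A minimal sketch of what a stratified split could look like, using scikit-learn's `train_test_split` with its `stratify` argument. This runs on synthetic labels with the same class counts as the repro above; it is not healthcareai's actual splitting code, and the 80/20 ratio and `random_state` are assumptions chosen for illustration.

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# Synthetic stand-in for the target column, same counts as the repro output.
labels = ["Waffles"] * 840 + ["SnoCones"] * 154 + ["Omelette"] * 6
# Placeholder single-column feature matrix (hypothetical, for illustration only).
features = [[i] for i in range(len(labels))]

# stratify=labels keeps class proportions roughly equal in both partitions.
X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.2, stratify=labels, random_state=0)

train_counts = Counter(y_train)
test_counts = Counter(y_test)
print("train:", dict(train_counts))
print("test: ", dict(test_counts))
```

With stratification every class should appear on both sides of the split. Note that scikit-learn raises a `ValueError` when a class has only a single member, so extremely rare classes would still need separate handling.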