HealthCatalyst / healthcareai-r

R tools for healthcare machine learning
https://docs.healthcare.ai

prep_data factor-reference-level handling #1232

Closed glenrs closed 6 years ago

glenrs commented 6 years ago

`step_dummy` didn't have the functionality we were looking for. To add reference levels, I needed to convert all character variables to factors before calling `step_dummy`. The other option I see is to create another class that inherits from `step_dummy`; I can look into this if you would like, but I think the functions I have created are the simplest implementation.

The only caveat is that we will need to call `set_refs` before using the recipe if dummies are created. If you want to see how I did this, refer to `ready_with_prep`.
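For context, here is a minimal sketch of that character-to-factor conversion in base R (illustrative only; the actual `set_refs` implementation in this PR may differ):

# Convert a character column to a factor and put the chosen reference
# level first, so that step_dummy() omits it when building indicators.
to_factor_with_ref <- function(x, ref) {
  x <- factor(x)
  stats::relevel(x, ref = ref)  # move `ref` to the first level position
}

d <- pima_diabetes
d$weight_class <- to_factor_with_ref(d$weight_class, ref = "normal")
levels(d$weight_class)[1]
#> [1] "normal"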

I also added two features to `interpret`: (1) the plot now includes the reference variable if one exists, and (2) a `print.interpret` method that prints the associated reference levels.
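As a rough sketch, a `print.interpret` S3 method can look like the following (the `ref_levels` attribute name here is an assumption, not necessarily what this PR uses):

print.interpret <- function(x, ...) {
  refs <- attr(x, "ref_levels")  # assumed attribute holding named references
  if (!is.null(refs)) {
    cat("Reference Levels:\n")
    cat(paste0("All `", names(refs), "` are relative to `", refs, "`\n"), sep = "")
  }
  NextMethod()   # fall through to the underlying tibble's print method
  invisible(x)
}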

I will add some screenshots before requesting any reviews.

glenrs commented 6 years ago
library(healthcareai)

# Train a quick GLM on pima_diabetes, then interpret it
m <- machine_learn(pima_diabetes, outcome = diabetes, models = "glm", tune = FALSE)
inter <- interpret(m)

inter        # the new print method includes the reference levels
plot(inter)  # the plot now shows the reference variable when present

[screenshot: plot(inter) output]

[screenshot: printed reference levels]

glenrs commented 6 years ago
library(healthcareai)
#> healthcareai version 2.1.1
#> Please visit https://docs.healthcare.ai for full documentation and vignettes. Join the community at https://healthcare-ai.slack.com
(prepped_data <- prep_data(pima_diabetes, outcome = diabetes, ref_levels = c(weight_class = "normal")))
#> Training new data prep recipe...
#> healthcareai-prepped data. Recipe used to prepare data:
#> Data Recipe
#> 
#> Inputs:
#> 
#>       role #variables
#>    outcome          1
#>  predictor          9
#> 
#> Training data contained 768 data points and 376 incomplete rows. 
#> 
#> Operations:
#> 
#> Sparse, unbalanced variable filter removed no terms [trained]
#> Mean Imputation for patient_id, pregnancies, ... [trained]
#> Filling NA with missing for weight_class [trained]
#> Adding levels to: other, missing [trained]
#> Collapsing factor levels for weight_class [trained]
#> Adding levels to: other, missing [trained]
#> Dummy variables from weight_class [trained]
#> Current data:
#> # A tibble: 768 x 14
#>    patient_id pregnancies plasma_glucose diastolic_bp skinfold insulin
#>         <int>       <int>          <dbl>        <dbl>    <dbl>   <dbl>
#>  1          1           6            148         72       35      156.
#>  2          2           1             85         66       29      156.
#>  3          3           8            183         64       29.2    156.
#>  4          4           1             89         66       23       94 
#>  5          5           0            137         40       35      168 
#>  6          6           5            116         74       29.2    156.
#>  7          7           3             78         50       32       88 
#>  8          8          10            115         72.4     29.2    156.
#>  9          9           2            197         70       45      543 
#> 10         10           8            125         96       29.2    156.
#> # ... with 758 more rows, and 8 more variables: pedigree <dbl>, age <int>,
#> #   diabetes <fct>, weight_class_morbidly.obese <dbl>,
#> #   weight_class_obese <dbl>, weight_class_overweight <dbl>,
#> #   weight_class_other <dbl>, weight_class_missing <dbl>
(models <- flash_models(prepped_data, outcome = diabetes))
#> 
#> diabetes looks categorical, so training classification algorithms.
#> 
#> After data processing, models are being trained on 13 features with 768 observations.
#> Based on n_folds = 5 and hyperparameter settings, the following number of models will be trained: 5 rf's, 5 xgb's, and 50 glm's
#> Training at fixed values: Random Forest
#> Training at fixed values: eXtreme Gradient Boosting
#> Training at fixed values: glmnet
#> 
#> *** Models successfully trained. The model object contains the training data minus ignored ID columns. ***
#> *** If there was PHI in training data, normal PHI protocols apply to the model object. ***
#> Algorithms Trained: Random Forest, eXtreme Gradient Boosting, and glmnet
#> Model Name: diabetes
#> Target: diabetes
#> Class: Classification
#> Performance Metric: AUROC
#> Number of Observations: 768
#> Number of Features: 13
#> Models Trained: 2018-08-23 17:23:22 
#> 
#> Models have not been tuned. Performance estimated via 5-fold cross validation at fixed hyperparameter values.
#> Best model: Random Forest
#> AUPR = 0.7, AUROC = 0.84
#> User-selected hyperparameter values:
#>   mtry = 3
#>   splitrule = extratrees
#>   min.node.size = 1
interpret(models)
#> Warning in interpret(models): Interpreting glmnet model, but Random Forest
#> performed best in cross-validation and will be used to make predictions. To
#> use the glmnet model for predictions, extract it with x['glmnet'].
#> Reference Levels:
#> All `weight_class` are relative to `normal`
#> All `diabetes` are relative to `N`
#> 
#> 
#> The coefficients and reference levels of each variable:
#> # A tibble: 13 x 3
#>    variable                    coefficient reference_level
#>  * <chr>                             <dbl> <chr>          
#>  1 (Intercept)                  -7.66      <NA>           
#>  2 weight_class_morbidly.obese   1.84      normal         
#>  3 weight_class_obese            1.64      normal         
#>  4 pedigree                      0.821     <NA>           
#>  5 weight_class_overweight       0.745     normal         
#>  6 weight_class_other            0.532     normal         
#>  7 pregnancies                   0.108     <NA>           
#>  8 plasma_glucose                0.0355    <NA>           
#>  9 age                           0.0109    <NA>           
#> 10 skinfold                      0.00886   <NA>           
#> 11 insulin                      -0.000366  <NA>           
#> 12 patient_id                   -0.000326  <NA>           
#> 13 diastolic_bp                 -0.0000365 <NA>
interpret(models) %>% plot()
#> Warning in interpret(models): Interpreting glmnet model, but Random Forest
#> performed best in cross-validation and will be used to make predictions. To
#> use the glmnet model for predictions, extract it with x['glmnet'].

Created on 2018-08-23 by the reprex package (v0.2.0).
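Note the warning above: predictions default to the best cross-validated model (Random Forest here), so to predict with the model being interpreted, extract it with the indexing syntax the warning suggests:

# Pull out the glmnet model so predictions match the interpreted model
glmnet_model <- models["glmnet"]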

glenrs commented 6 years ago

@michaellevy, thank you for taking the time to review this for me. I know you have a lot going on before tomorrow. I made those major edits; I'm sure some small things still need fixing, but this is the best I can do for right now. I do think it is a major improvement, and I am proud to say it is my work.

I probably won't be able to make any more edits tomorrow. Hopefully this is close!

library(healthcareai)
#> healthcareai version 2.1.1
#> Please visit https://docs.healthcare.ai for full documentation and vignettes. Join the community at https://healthcare-ai.slack.com
(prepped_data <- prep_data(pima_diabetes, outcome = diabetes, ref_levels = c(weight_class = "normal")))
#> Training new data prep recipe...
#> healthcareai-prepped data. Recipe used to prepare data:
#> Data Recipe
#> 
#> Inputs:
#> 
#>       role #variables
#>    outcome          1
#>  predictor          9
#> 
#> Training data contained 768 data points and 376 incomplete rows. 
#> 
#> Operations:
#> 
#> Sparse, unbalanced variable filter removed no terms [trained]
#> Mean Imputation for patient_id, pregnancies, ... [trained]
#> Filling NA with missing for weight_class [trained]
#> Adding levels to: other, missing [trained]
#> Collapsing factor levels for weight_class [trained]
#> Adding levels to: other, missing [trained]
#> Dummy variables from weight_class [trained]
#> Current data:
#> # A tibble: 768 x 14
#>    patient_id pregnancies plasma_glucose diastolic_bp skinfold insulin
#>         <int>       <int>          <dbl>        <dbl>    <dbl>   <dbl>
#>  1          1           6            148         72       35      156.
#>  2          2           1             85         66       29      156.
#>  3          3           8            183         64       29.2    156.
#>  4          4           1             89         66       23       94 
#>  5          5           0            137         40       35      168 
#>  6          6           5            116         74       29.2    156.
#>  7          7           3             78         50       32       88 
#>  8          8          10            115         72.4     29.2    156.
#>  9          9           2            197         70       45      543 
#> 10         10           8            125         96       29.2    156.
#> # ... with 758 more rows, and 8 more variables: pedigree <dbl>, age <int>,
#> #   diabetes <fct>, weight_class_morbidly.obese <dbl>,
#> #   weight_class_obese <dbl>, weight_class_overweight <dbl>,
#> #   weight_class_other <dbl>, weight_class_missing <dbl>
(models <- flash_models(prepped_data, outcome = diabetes))
#> 
#> diabetes looks categorical, so training classification algorithms.
#> 
#> After data processing, models are being trained on 13 features with 768 observations.
#> Based on n_folds = 5 and hyperparameter settings, the following number of models will be trained: 5 rf's, 5 xgb's, and 50 glm's
#> Training at fixed values: Random Forest
#> Training at fixed values: eXtreme Gradient Boosting
#> Training at fixed values: glmnet
#> 
#> *** Models successfully trained. The model object contains the training data minus ignored ID columns. ***
#> *** If there was PHI in training data, normal PHI protocols apply to the model object. ***
#> Algorithms Trained: Random Forest, eXtreme Gradient Boosting, and glmnet
#> Model Name: diabetes
#> Target: diabetes
#> Class: Classification
#> Performance Metric: AUROC
#> Number of Observations: 768
#> Number of Features: 13
#> Models Trained: 2018-08-24 04:34:30 
#> 
#> Models have not been tuned. Performance estimated via 5-fold cross validation at fixed hyperparameter values.
#> Best model: glmnet
#> AUPR = 0.71, AUROC = 0.83
#> User-selected hyperparameter values:
#>   alpha = 1
#>   lambda = 0.0027
interpret(models)
#> Reference Levels:
#> All `weight_class` estimates are relative to `normal`
#> 
#> # A tibble: 13 x 2
#>    variable                                 coefficient
#>  * <chr>                                          <dbl>
#>  1 (Intercept)                               -7.66     
#>  2 weight_class_morbidly.obese (vs. normal)   1.84     
#>  3 weight_class_obese (vs. normal)            1.64     
#>  4 pedigree                                   0.821    
#>  5 weight_class_overweight (vs. normal)       0.745    
#>  6 weight_class_other (vs. normal)            0.532    
#>  7 pregnancies                                0.108    
#>  8 plasma_glucose                             0.0355   
#>  9 age                                        0.0109   
#> 10 skinfold                                   0.00886  
#> 11 insulin                                   -0.000366 
#> 12 patient_id                                -0.000326 
#> 13 diastolic_bp                              -0.0000365
interpret(models) %>% plot()

Created on 2018-08-24 by the reprex package (v0.2.0).
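The "(vs. normal)" labels in the table above could be produced with something like the sketch below (`label_with_ref` is a hypothetical helper, not the PR's actual code):

# Append "(vs. <reference>)" to variables that have a reference level
label_with_ref <- function(variable, reference_level) {
  ifelse(is.na(reference_level), variable,
         paste0(variable, " (vs. ", reference_level, ")"))
}
label_with_ref("weight_class_obese", "normal")
#> [1] "weight_class_obese (vs. normal)"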

codecov[bot] commented 6 years ago

Codecov Report

Merging #1232 into master will decrease coverage by <.1%. The diff coverage is 98.8%.

@@           Coverage Diff            @@
##           master   #1232     +/-   ##
========================================
- Coverage    94.5%   94.4%   -0.1%     
========================================
  Files          39      37      -2     
  Lines        2797    2664    -133     
========================================
- Hits         2645    2517    -128     
+ Misses        152     147      -5
codecov[bot] commented 6 years ago

Codecov Report

Merging #1232 into master will decrease coverage by <.1%. The diff coverage is 98.8%.

@@           Coverage Diff            @@
##           master   #1232     +/-   ##
========================================
- Coverage    94.5%   94.4%   -0.1%     
========================================
  Files          39      37      -2     
  Lines        2802    2664    -138     
========================================
- Hits         2650    2517    -133     
+ Misses        152     147      -5