bgreenwell / pdp

A general framework for constructing partial dependence (i.e., marginal effect) plots from various types of machine learning models in R.
http://bgreenwell.github.io/pdp

From Keras3 to package pdp in R #134

Status: Open. Oakwilde opened this issue 1 month ago.

Oakwilde commented 1 month ago

Might pdp be extended to work directly with output from Keras3? Also, variable importance plots would be a nice addition. There are a number of related ideas in Molnar's nice survey book “Interpretable Machine Learning…” (2022). Some are already in R for other ML procedures (e.g., for gbm).

brandongreenwell-8451 commented 1 month ago

Hi @Oakwilde, pdp can already work with ANY predictive model, provided you include a suitable prediction wrapper via the pred.fun argument. The same goes for the vip package, which already implements many of the ideas in Molnar's book. If you have a brief reprex, I'd be happy to help pull together working examples.
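A minimal sketch of what such a wrapper looks like, assuming only what is said above: `pred.fun` is called with the fitted object and a data frame `newdata`, and it should return numeric predictions. The model here is a hypothetical stand-in (`fit_matrix_logit()`, a logistic regression fitted with `glm.fit()`) that, like a Keras network, accepts only a numeric matrix; it is not part of pdp or keras3.

```r
# Hypothetical stand-in for a model that, like a Keras network, only
# accepts a numeric matrix. It fits a logistic regression with glm.fit().
fit_matrix_logit <- function(x, y) {
  fit <- glm.fit(cbind(1, x), y, family = binomial())
  structure(list(coef = fit$coefficients), class = "matrix_logit")
}

# Prediction for the stand-in: P(Y = 1) for each row of a numeric matrix
predict_matrix_logit <- function(object, x) {
  stopifnot(is.matrix(x), is.numeric(x))
  as.vector(plogis(cbind(1, x) %*% object$coef))
}

# Wrapper with the (object, newdata) signature used by pred.fun:
# data frame in, matrix to the model, plain numeric vector out.
pred_wrapper <- function(object, newdata) {
  predict_matrix_logit(object, as.matrix(newdata))
}

set.seed(1)
x <- matrix(rnorm(200 * 3), ncol = 3, dimnames = list(NULL, c("a", "b", "c")))
y <- rbinom(200, 1, plogis(x[, 1] - 0.5 * x[, 2]))
model <- fit_matrix_logit(x, y)

p <- pred_wrapper(model, as.data.frame(x))

# With pdp installed, the wrapper would be passed along the lines of:
# partial(model, pred.var = "a", pred.fun = pred_wrapper, train = as.data.frame(x))
```

The same shape of wrapper should carry over to a keras3 model: convert `newdata` with `as.matrix()`, call `predict()`, and flatten the result with `as.vector()`.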

brandongreenwell-8451 commented 1 month ago

An example is given in the Classification problems section of this vignette.

Oakwilde commented 1 month ago

Hi. Thanks for the response. I just learned about the wrapper option, and about the suggestion to use vip. I will give both a try. There can be substantial differences in ease of use and in results between built-in ML interpretation options, such as those provided directly by procedures like randomForest or gbm, and those handled by the pdp package; these differences also spill over into the plotting procedure used. For example, some classification procedures want Y to be numeric and some are quite happy with categorical classes. Some want the training data as a data frame and some want it as a matrix. Some carry forward most of the operators used in R and some replace many of them with new operators. These are not deep problems for users, but they require a lot of mental bookkeeping. Some issues, however, require serious digging, such as how predictors that are factors are handled. Also, of late, journals (e.g., Springer journals) are imposing more demanding requirements on how a plot should look, which sometimes calls for more control over plotting than is routinely available.
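For the data-frame-versus-matrix and factor bookkeeping, one option is to have the wrapper rebuild exactly the dummy-coded matrix the network was trained on. A sketch, assuming the network was trained on `model.matrix()` output; `make_design()` and the example columns are hypothetical. Fixing the factor levels matters because a partial dependence grid may contain only one level of a factor, and the network still needs every column in the same order.

```r
# Build a numeric design matrix from a data frame, one-hot encoding factors.
# Fixing the levels to those seen in training keeps the columns identical
# even when newdata contains only some of the levels.
make_design <- function(df, factor_levels) {
  for (v in names(factor_levels)) {
    df[[v]] <- factor(df[[v]], levels = factor_levels[[v]])
  }
  # "- 1" drops the intercept column, which a network does not need
  model.matrix(~ . - 1, data = df)
}

train_df <- data.frame(
  temp = c(290, 295, 300, 305),
  region = factor(c("north", "south", "north", "east"))
)
lev <- list(region = levels(train_df$region))
design_train <- make_design(train_df, lev)

# Grid rows containing a single region still yield the same columns
grid_df <- data.frame(temp = c(292, 298), region = c("south", "south"))
design_grid <- make_design(grid_df, lev)

# Inside a pred.fun wrapper this would become, e.g.:
# function(object, newdata) as.vector(predict(object, make_design(newdata, lev)))
```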

You and your colleagues are the experts in all this. I happily defer to you on what, if anything, should be or might be done. But I appreciate your offer of help with examples.

Thanks.

Richard

From: Brandon Greenwell | Date: Tuesday, October 8, 2024 at 12:11 PM | To: bgreenwell/pdp | Cc: Berk, Richard A; Mention | Subject: Re: [bgreenwell/pdp] From Keras3 to package pdp in R (Issue #134)


brandongreenwell-8451 commented 1 month ago

@Oakwilde I completely sympathize with your callouts! These are precisely why I wrote these packages in the first place. I've tried to strike a balance with explicit support for various packages, but ultimately users will have the best experience if they understand how the optional prediction wrapper in each package works, since that lets you work across different packages and data types (e.g., arrays, data frames, or whatever the specific modeling package requires). My only real regret is developing these packages (pdp, vip, and fastshap) so far apart that their naming conventions are not terribly consistent.

The plotting utilities in these libraries are also good, but as you said, users often have specific requirements, and it's helpful to be able to override the plotting. For instance, in pdp you can simply set plot = FALSE (can't remember the default) to return a data frame that can be plotted with whichever library is most convenient. Our book, which we plan to revise at some point, offers some additional examples (https://bradleyboehmke.github.io/HOML/iml.html#implementation-6) of supporting other models via custom prediction wrappers.

Each of these packages was published in the R Journal, so you can also find more extensive details and examples in those articles.

In any case, feel free to provide some examples when you get a chance and I'd be more than happy to help integrate them with these packages.

Oakwilde commented 1 month ago

Hi Brandon,

Here is some code I wrote today with extensive coaching from ChatGPT. There were two issues to be overcome. The “errors” may well be mine.

  1. My outcome variable is binary: there was a heat wave = 1, there was no heat wave = 0. I expected that, as with most classifiers, you would either get a choice between p(Y=1) and p(Y=0) or simply get p(Y=1). With activation = “sigmoid” you apparently get p(Y=0). I had never seen that before, and it took some detective work to determine what was going on. After that, the apparent fix was trivial.
  2. At least in the code I wrote, partial() kept returning ICE curves, ignoring the default and even an explicit instruction not to return them. Neither ChatGPT nor I could figure out how to override that directly, so we wrote a few lines of “post-processing” code. Good thing I know what an ICE plot is (the idea was hatched and the code written here at Penn). ChatGPT said I should pass both concerns along. Maybe they are nothing, or maybe they are not nothing.
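Issue 2 looks consistent with how pdp's `pred.fun` argument is documented to behave, as I understand it: when the wrapper returns one prediction per row of `newdata`, `partial()` returns ICE curves (with an extra `yhat.id` column), whereas a wrapper that returns a single averaged value yields the partial dependence curve directly. The base-R sketch below mimics that brute-force grid computation to show why averaging inside the wrapper and averaging the ICE output afterwards give the same curve. `toy_predict()` and `brute_force_partial()` are hypothetical names, not pdp functions.

```r
# Hypothetical stand-in model: a fixed logistic function of three predictors
toy_predict <- function(object, newdata) {
  as.vector(plogis(0.8 * newdata$temp8 - 0.3 * newdata$x2 + 0.1 * newdata$x3))
}

# Brute-force partial dependence: for each grid value, overwrite the
# predictor in every row of the training data and call pred_fun().
brute_force_partial <- function(object, pred_var, pred_fun, train, grid) {
  out <- lapply(grid, function(g) {
    newdata <- train
    newdata[[pred_var]] <- g
    yhat <- pred_fun(object, newdata)
    data.frame(value = g, yhat = yhat, yhat.id = seq_along(yhat))
  })
  res <- do.call(rbind, out)
  names(res)[1] <- pred_var
  res
}

set.seed(42)
train <- data.frame(temp8 = rnorm(100), x2 = rnorm(100), x3 = rnorm(100))
grid <- seq(-2, 2, length.out = 5)

# One prediction per row: ICE curves (5 grid values x 100 rows)
ice <- brute_force_partial(NULL, "temp8", toy_predict, train, grid)

# Averaging inside the wrapper: the partial dependence curve (5 rows)
pd_fun <- function(object, newdata) mean(toy_predict(object, newdata))
pd <- brute_force_partial(NULL, "temp8", pd_fun, train, grid)

# Averaging the ICE curves afterwards reproduces the same curve
pd_from_ice <- aggregate(yhat ~ temp8, data = ice, FUN = mean)

# Applied to the Keras code, the analogous wrapper would be, e.g.:
# predict_keras_pd <- function(object, newdata) mean(predict_keras(object, newdata))
```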

Still unsolved is how to get a ggplot2 rug plot for the target predictor in its original units (degrees K) rather than the standardized units (Z-scores) used in fitting. The math is trivial.
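One way to do the back-transformation: `scale()` stores the centering and scaling values it used as the attributes `"scaled:center"` and `"scaled:scale"`, so the grid can be mapped back to kelvin without recomputing them. A sketch, assuming the PDP data frame has a standardized `temp8` column and that `x_train` came from `scale()`; the simulated temperatures and the `yhat` values are placeholders, not the AIRS data.

```r
set.seed(7)
# Placeholder raw predictors in kelvin (not the AIRS data)
raw <- data.frame(temp8 = rnorm(50, mean = 290, sd = 5),
                  temp15 = rnorm(50, mean = 260, sd = 4))
x_train <- scale(as.matrix(raw))

# Mean and SD that scale() used for each column
centers <- attr(x_train, "scaled:center")
scales  <- attr(x_train, "scaled:scale")

# A partial dependence data frame on the standardized scale (placeholder yhat)
pd <- data.frame(temp8 = seq(-2, 2, length.out = 9),
                 yhat = seq(0.1, 0.9, length.out = 9))

# Invert the z-score: x = z * sd + mean
pd$temp8_K <- pd$temp8 * scales[["temp8"]] + centers[["temp8"]]

# The curve and the rug then share a kelvin axis, e.g. in ggplot2:
# ggplot(pd, aes(temp8_K, yhat)) + geom_line() +
#   geom_rug(data = raw, aes(x = temp8), inherit.aes = FALSE)
```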

Richard

############ Version Using pdp and iml ##################

# Get data
setwd("/Volumes/Data in Recent Use/NewJPL/ForJune21")
load("/Volumes/Data in Recent Use/NewJPL/ForJune21/AIRS.06/JuneJulyForecasting.rdata")
summary(working1)
work1 <- working1[, c(1, 6, 15, 20)]

# Scale the predictors, or back propagation runs into computational problems.
x_train <- scale(as.matrix(work1[, c(2, 3, 4)]))
y_train <- as.matrix(work1[, 1])

# This is fake test data, created by resampling the training data. OK for now.
# x_train is already standardized, so the resampled rows are used as-is rather
# than being re-scaled with the resample's own mean and SD.
index <- sample(260, 100, replace = TRUE)
x_test <- x_train[index, , drop = FALSE]
y_test <- as.matrix(y_train[index])
TestData <- cbind(y_test, x_test)

# Load library
library(keras3)

# Fit a Keras model (works well, except for the sigmoid activation problem)
model <- keras_model_sequential() %>%
  layer_dense(units = 128, activation = "relu", input_shape = c(ncol(x_train))) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")  # Seems to output p(Y=0), not p(Y=1)

model %>% compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_adam(),
  metrics = c("accuracy")
)

model %>% fit(
  x_train, y_train,
  epochs = 10,
  batch_size = 32,
  validation_split = 0.2
)

# Load libraries
library(pdp)
library(ggplot2)
library(iml)
library(dplyr)

# Define the custom prediction function for the Keras model
predict_keras <- function(object, newdata) {
  # Convert newdata (a data frame) to a matrix for Keras model predictions
  newdata_matrix <- as.matrix(newdata)

  # Ensure that the matrix has the correct number of columns (3 in this case)
  if (ncol(newdata_matrix) != 3) {
    stop("Input data does not have the required number of features (3)")
  }

  # Get predictions from the model using a predictor matrix
  predictions <- object %>% predict(newdata_matrix)

  # Return the predicted probabilities for class "1" by subtracting from 1.
  # Note: Keras normally treats a sigmoid output trained with
  # binary_crossentropy as P(Y = 1), so check that this flip is needed.
  return(as.vector(1 - predictions))
}

# Convert x_test to a data frame
feature_data <- as.data.frame(x_test)

# Generate the partial dependence data (partial() returns ICE curves here,
# since predict_keras() returns one prediction per row)
pdp_keras <- partial(
  object = model,             # Your Keras model
  pred.var = "temp8",         # The feature for which you want the PDP
  pred.fun = predict_keras,   # Custom prediction function for the Keras model
  train = feature_data,       # x_test as a data frame
  grid.resolution = 50,       # Resolution of the grid (optional, can be adjusted)
  plot = FALSE                # Return a data frame instead of a plot
)

# A fix: "manually" aggregate the ICE data by averaging yhat over temp8
pdp_avg <- aggregate(yhat ~ temp8, data = pdp_keras, FUN = mean)

# Plot the averaged PDP using ggplot2
ggplot(pdp_avg, aes(x = temp8, y = yhat)) +
  geom_line() +
  labs(
    title = "Partial Dependence Plot for temp8",
    x = "temp8",
    y = "Predicted Probability for Class 1"
  ) +
  theme_minimal()

From: Brandon Greenwell | Date: Tuesday, October 8, 2024 at 1:48 PM | To: bgreenwell/pdp | Cc: Berk, Richard A; Mention | Subject: Re: [bgreenwell/pdp] From Keras3 to package pdp in R (Issue #134)


Oakwilde commented 1 month ago

Here is the code that, with the help of ChatGPT, does everything I need. Getting there was not easy. I am no longer sure whether “sigmoid” is the problem or what predict() is doing; I now think the problem is in predict(). I have an easy solution for getting the rug plot and everything else in the original units of the target predictor.
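Keras's usual convention is that a single sigmoid output unit trained with `binary_crossentropy` estimates the probability of the label coded 1, so a `1 - predictions` flip would normally turn p(Y=1) into p(Y=0). Which way a given fitted model points can be checked empirically. A sketch, with a `glm()` fit standing in for the network; `check_orientation()` is a hypothetical helper, and with keras3 one would pass `as.vector(predict(model, x_train))` and `y_train` instead.

```r
# Compare the average predicted probability among observed 1s and 0s.
# If p estimates P(Y = 1), the first mean should clearly exceed the second,
# and the overall mean should be close to the share of 1s.
check_orientation <- function(p, y) {
  y <- as.vector(y)
  c(mean_p_when_y1 = mean(p[y == 1]),
    mean_p_when_y0 = mean(p[y == 0]),
    mean_p = mean(p),
    share_y1 = mean(y == 1))
}

set.seed(3)
x <- rnorm(500)
y <- rbinom(500, 1, plogis(1.5 * x))
fit <- glm(y ~ x, family = binomial())
res <- check_orientation(fitted(fit), y)
```

If the two conditional means come out the other way round for the network, the flip in the wrapper is doing real work; if not, the unflipped predictions are already p(Y=1).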

Richard

# Get data
setwd("/Volumes/Data in Recent Use/NewJPL/ForJune21")
load("/Volumes/Data in Recent Use/NewJPL/ForJune21/AIRS.06/JuneJulyForecasting.rdata")
summary(working1)
work1 <- working1[, c(1, 6, 15, 20)]

# Scale the predictors, or back propagation runs into computational problems.
x_train <- scale(as.matrix(work1[, c(2, 3, 4)]))
y_train <- as.matrix(work1[, 1])

# This is fake test data, created by resampling the training data. OK for now.
# x_train is already standardized, so the resampled rows are used as-is; this
# keeps the back-transformation with work1's mean and SD below exact.
index <- sample(260, 100, replace = TRUE)
x_test <- x_train[index, , drop = FALSE]
y_test <- as.matrix(y_train[index])
TestData <- cbind(y_test, x_test)

# Load library
library(keras3)

# Fit a Keras model
model <- keras_model_sequential() %>%
  layer_dense(units = 128, activation = "relu", input_shape = c(ncol(x_train))) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")  # Seems to output p(Y=0), not p(Y=1)

model %>% compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_adam(),
  metrics = c("accuracy")
)

model %>% fit(
  x_train, y_train,
  epochs = 10,
  batch_size = 32,
  validation_split = 0.2
)

# Load libraries
library(pdp)
library(ggplot2)
library(iml)
library(dplyr)

# Define the custom prediction function for the Keras model,
# getting the data in shape for predict()
predict_keras <- function(object, newdata) {
  # Convert newdata (a data frame) to a matrix for Keras model predictions
  newdata_matrix <- as.matrix(newdata)

  # Ensure that the matrix has the correct number of columns (3 in this case)
  if (ncol(newdata_matrix) != 3) {
    stop("Input data does not have the required number of features (3)")
  }

  # Get predictions from the model using a predictor matrix
  predictions <- object %>% predict(newdata_matrix)

  # Return the predicted probabilities for class "1" (unsure whether this
  # flip is needed: Keras normally treats a sigmoid output trained with
  # binary_crossentropy as P(Y = 1))
  return(as.vector(1 - predictions))
}

# Convert x_test to a data frame
feature_data <- as.data.frame(x_test)

# Generate the partial dependence data (partial() returns ICE curves here,
# since predict_keras() returns one prediction per row)
pdp_keras <- partial(
  object = model,             # My Keras model
  pred.var = "temp8",         # The feature for which you want the PDP
  pred.fun = predict_keras,   # Custom prediction function for the Keras model
  train = feature_data,       # x_test as a data frame
  grid.resolution = 50,       # Resolution of the grid (optional, can be adjusted)
  plot = FALSE                # Return a data frame instead of a plot
)

# A fix: "manually" aggregate the ICE data by averaging yhat over temp8
pdp_avg <- aggregate(yhat ~ temp8, data = pdp_keras, FUN = mean)

# Calculate the mean and standard deviation from the original data
mean_temp8 <- mean(work1$temp8)
sd_temp8 <- sd(work1$temp8)

# Add the original-scale temp8 values to pdp_avg for plotting
pdp_avg$original_temp8 <- pdp_avg$temp8 * sd_temp8 + mean_temp8

# Plot the averaged PDP using ggplot2 on the original scale
ggplot(pdp_avg, aes(x = original_temp8, y = yhat)) +
  # linewidth requires ggplot2 >= 3.4.0; older versions use size
  geom_smooth(color = "blue", linewidth = 1.5, se = FALSE,
              method = "loess", span = 1/4) +
  geom_rug(data = work1, mapping = aes(x = temp8), sides = "b",
           inherit.aes = FALSE) +                # Original data for rug plot
  ggtitle("Partial Dependence Plot for Temperature at Altitude 8") +
  xlab("Temperature (Kelvin)") +                 # Customize x-axis label
  ylab("Fitted Probability of a Heat Wave") +    # Customize y-axis label
  theme(
    plot.title = element_text(hjust = 0.5, size = 16),  # Center and enlarge title
    axis.title.x = element_text(size = 14),             # Larger x-axis label
    axis.title.y = element_text(size = 14),             # Larger y-axis label
    axis.text = element_text(size = 12)                 # Larger axis text
  ) +
  coord_cartesian(ylim = c(0, 1))                # Constrain y-axis between 0 and 1


brandongreenwell-8451 commented 1 month ago

Great, so you're able to produce the plot you need? I'll admit that understanding the underlying modeling package's predict function behavior can be outright frustrating, given all the inconsistencies from base R. Did you open a ticket on the keras3 repo to see if they can help explain the behavior?

Oakwilde commented 1 month ago

Yes, I am good to go. But moving into production mode will take some thinking, because I have hardcoded things in places. I have been in touch with the Keras3 site and its author/maintainer. Maybe more cooperation across packages would help. Easy for me to say.

From: Brandon Greenwell | Date: Wednesday, October 9, 2024 at 9:57 AM | To: bgreenwell/pdp | Cc: Berk, Richard A; Mention | Subject: Re: [bgreenwell/pdp] From Keras3 to package pdp in R (Issue #134)


brandongreenwell-8451 commented 1 month ago

Awesome! I'll keep this thread open for now in case there's an opportunity to make it easier for keras3 users.

Oakwilde commented 1 month ago

There is one thing I forgot to mention. The code I sent required another couple of lines because I could not override the ICE option. It would only provide ICE plots. They are great, but not what I needed.

From: Brandon Greenwell | Date: Wednesday, October 9, 2024 at 10:10 AM | To: bgreenwell/pdp | Cc: Berk, Richard A; Mention | Subject: Re: [bgreenwell/pdp] From Keras3 to package pdp in R (Issue #134)
