slds-lmu / ame

average marginal effects for machine learning
Other
1 star · 1 fork · source link

Add PDP (partial dependence plot) functions #4

Closed giuseppec closed 6 years ago

giuseppec commented 6 years ago
# Setup: fit a kernlab SVM regression learner ("regr.ksvm") on mlr's example
# task bh.task (Boston Housing, target column medv) and keep the task's data
# as a plain data.frame for the manual computations below.
library(mlr)

bh = getTaskData(bh.task)
lrn = makeLearner("regr.ksvm")
mod = train(learner = lrn, task = bh.task)

# 1) Plain (non-derivative) partial dependence evaluated at the observed zn
#    values instead of a uniform grid, computed two ways:
#    - manually, via predictModifiedData() at every observed value of zn;
#    - with mlr's generatePartialDependenceData() (ame branch), using
#      uniform = FALSE and one grid point per observation.
p = lapply(bh$zn, function(i) predictModifiedData(x = i, feature = "zn",
  data = bh, model = mod$learner.model,
  predict.fun = function(model, newdata) predict(model, newdata = newdata)))
pd.manual = unlist(p)
pd = generatePartialDependenceData(mod, bh.task, features = "zn",
  derivative = FALSE, uniform = FALSE, gridsize = getTaskSize(bh.task))
# Compare on the same scale as case 3: average both curves, so the two
# numbers should agree if the manual computation matches mlr.
c(manual = mean(pd.manual), mlr = mean(pd$data$medv))

# 2) Plain (non-derivative) partial dependence on a uniform grid of zn,
#    using mlr's ame branch. A manual counterpart of case 1 would have to
#    evaluate the grid points instead of the observed values bh$zn.
n.grid = getTaskSize(bh.task)
pd = generatePartialDependenceData(mod, bh.task, features = "zn",
  derivative = FALSE, uniform = TRUE, gridsize = n.grid)

# 3) Derivative partial dependence evaluated at the observed zn values
#    (uniform = FALSE, one grid point per observation). Manual counterpart:
#    the same loop as case 1, but calling derivative() instead of
#    predictModifiedData().
pd = generatePartialDependenceData(mod, bh.task, features = "zn",
  derivative = TRUE, uniform = FALSE, gridsize = getTaskSize(bh.task))
ame.pd = mean(pd$data$medv)
ame = computeAME(mod, data = bh, features = "zn")
pd.deriv = lapply(bh$zn, function(i) derivative(x = i, feature = "zn",
  data = bh, model = mod$learner.model,
  predict.fun = function(object, newdata) predict(object, newdata = newdata)))
# Show all three estimates side by side so they can be compared directly.
# NOTE(review): assumes ame$zn is a numeric vector of per-observation
# effects - TODO confirm against computeAME().
c(pd = ame.pd, manual = mean(unlist(pd.deriv)), ame = mean(ame$zn))

# 4) Derivative partial dependence on a uniform grid of zn. A manual
#    counterpart would call derivative() instead of predictModifiedData()
#    and evaluate it at the grid points instead of the observed values bh$zn.
pd = generatePartialDependenceData(mod, bh.task,
  features = "zn",
  derivative = TRUE,
  uniform = TRUE,
  gridsize = getTaskSize(bh.task))
mean(pd$data$medv)

# Cases 1-4: repeat each one with gridsize <= getTaskSize(bh.task).

# TODO: Which of cases 3 and 4 above, if either, is the AME equivalent to?

# Cross-check with mmpf: marginal predictions for zn computed from the fitted
# kernlab model, evaluated at the observed zn values (uniform = FALSE).
# NOTE(review): n = c(nrow(bh), nrow(bh)) presumably means (number of grid
# points, number of rows averaged over) - confirm in the mmpf docs.
library(mmpf)
fit = mod$learner.model
n.rows = nrow(bh)
pred.fun = function(object, newdata) predict(object, newdata = newdata)
agg.fun = function(x) c(mean = mean(x))
p = marginalPrediction(bh, "zn", rep(n.rows, 2), fit, uniform = FALSE,
  predict.fun = pred.fun, aggregate.fun = agg.fun)

# NOTE(review): assumes the aggregated column is named "V1.mean" - verify.
mean(p$V1.mean)

# Reload mlr from the local checkout (ame branch), then inspect case 4 again:
# the zn grid points in pd, the AME from computeAME(), and a plot of pd.
# load_all() is namespaced because devtools is never attached in this script.
# NOTE(review): computeAME() and generatePartialDependenceData(uniform = ...)
# are already called in cases 1-4 above; if they exist only in the local mlr
# checkout, this load_all() has to run before those calls.
devtools::load_all("../mlr")
pd$data$zn
ame = computeAME(mod, data = bh, features = "zn")
mean(ame$zn)

plotPartialDependence(pd)
BodoBurger commented 6 years ago

The `interval` branch contains the functions `computePD()` and `plotPD()`.