dandls / counterfactuals

counterfactuals: An R package for Counterfactual Explanation Methods
https://dandls.github.io/counterfactuals/
GNU Lesser General Public License v3.0

Proposal: Different workflow for subsetting valid counterfactuals #24

Closed andreash0 closed 1 year ago

andreash0 commented 1 year ago

At the moment, subsetting to valid counterfactuals involves three methods/fields ($subset_to_valid(), $subsetted, and $fulldata):

library(counterfactuals)

rf = randomForest::randomForest(Species ~ ., data = iris)
predictor = iml::Predictor$new(rf, type = "prob")
moc_classif = MOCClassif$new(predictor, n_generations = 15L, quiet = TRUE)
cfactuals = moc_classif$find_counterfactuals(
  x_interest = iris[150L, ], desired_class = "versicolor", desired_prob = c(0.5, 1)
)

cfactuals$data
#>     Sepal.Length Sepal.Width Petal.Length Petal.Width
#>  1:     5.900000           3     5.100000    1.668422
#>  2:     5.900000           3     4.734607    1.800000
#>  3:     5.900000           3     5.100000    1.719278
#>  4:     5.900000           3     4.828192    1.800000
#>  5:     5.900000           3     4.803854    1.800000
#>  6:     5.900000           3     4.805899    1.800000
#>  7:     5.900000           3     4.831830    1.800000
#>  8:     5.900000           3     4.996840    1.800000
#>  9:     5.903968           3     5.100000    1.800000
#> 10:     5.900000           3     5.100000    1.797356
cfactuals$fulldata
#> `$data` was not subsetted yet
#> NULL
cfactuals$subset_to_valid()
cfactuals$subsetted
#> [1] TRUE
cfactuals$data
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width
#> 1:          5.9           3          5.1    1.668422
cfactuals$fulldata
#>     Sepal.Length Sepal.Width Petal.Length Petal.Width
#>  1:     5.900000           3     5.100000    1.668422
#>  2:     5.900000           3     4.734607    1.800000
#>  3:     5.900000           3     5.100000    1.719278
#>  4:     5.900000           3     4.828192    1.800000
#>  5:     5.900000           3     4.803854    1.800000
#>  6:     5.900000           3     4.805899    1.800000
#>  7:     5.900000           3     4.831830    1.800000
#>  8:     5.900000           3     4.996840    1.800000
#>  9:     5.903968           3     5.100000    1.800000
#> 10:     5.900000           3     5.100000    1.797356

I was wondering whether we should simplify this workflow as follows:

library(counterfactuals)

rf = randomForest::randomForest(Species ~ ., data = iris)
predictor = iml::Predictor$new(rf, type = "prob")
moc_classif = MOCClassif$new(predictor, n_generations = 15L, quiet = TRUE)
cfactuals = moc_classif$find_counterfactuals(
  x_interest = iris[150L, ], desired_class = "versicolor", desired_prob = c(0.5, 1)
)

cfactuals$data
#>     Sepal.Length Sepal.Width Petal.Length Petal.Width
#>  1:     5.900000           3     5.100000    1.668422
#>  2:     5.900000           3     4.734607    1.800000
#>  3:     5.900000           3     5.100000    1.719278
#>  4:     5.900000           3     4.828192    1.800000
#>  5:     5.900000           3     4.803854    1.800000
#>  6:     5.900000           3     4.805899    1.800000
#>  7:     5.900000           3     4.831830    1.800000
#>  8:     5.900000           3     4.996840    1.800000
#>  9:     5.903968           3     5.100000    1.800000
#> 10:     5.900000           3     5.100000    1.797356

# Setting the flag tells all methods to consider only valid counterfactuals
cfactuals$consider_valid_only <- TRUE

cfactuals$data
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width
#> 1:          5.9           3          5.1    1.668422

cfactuals$consider_valid_only <- FALSE

cfactuals$data
#>     Sepal.Length Sepal.Width Petal.Length Petal.Width
#>  1:     5.900000           3     5.100000    1.668422
#>  2:     5.900000           3     4.734607    1.800000
#>  3:     5.900000           3     5.100000    1.719278
#>  4:     5.900000           3     4.828192    1.800000
#>  5:     5.900000           3     4.803854    1.800000
#>  6:     5.900000           3     4.805899    1.800000
#>  7:     5.900000           3     4.831830    1.800000
#>  8:     5.900000           3     4.996840    1.800000
#>  9:     5.903968           3     5.100000    1.800000
#> 10:     5.900000           3     5.100000    1.797356

This way, users would only need to set a single flag to switch between all counterfactuals and only the valid ones.
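For illustration, here is a minimal sketch of how such a flag could work with an R6 active binding. It is not the package's implementation: ToyCounterfactuals, its private fields, and the is_valid vector are hypothetical stand-ins, and validity is simply passed in as a logical vector instead of being derived from the predictor.

```r
library(R6)

# Hypothetical toy class, NOT the package's Counterfactuals class.
ToyCounterfactuals = R6Class("ToyCounterfactuals",
  public = list(
    # The flag proposed above: TRUE restricts `$data` to valid rows.
    consider_valid_only = FALSE,
    initialize = function(data, is_valid) {
      stopifnot(is.data.frame(data), is.logical(is_valid), nrow(data) == length(is_valid))
      private$.data = data
      private$.is_valid = is_valid
    }
  ),
  active = list(
    # `$data` is computed on every access, so the full data is never overwritten.
    data = function(value) {
      if (!missing(value)) stop("`$data` is read-only")
      if (isTRUE(self$consider_valid_only)) {
        private$.data[private$.is_valid, , drop = FALSE]
      } else {
        private$.data
      }
    }
  ),
  private = list(.data = NULL, .is_valid = NULL)
)

# Assumption: only the first of three toy rows counts as valid.
toy = ToyCounterfactuals$new(
  data = data.frame(Petal.Length = c(5.1, 4.7, 5.1), Petal.Width = c(1.67, 1.80, 1.72)),
  is_valid = c(TRUE, FALSE, FALSE)
)
nrow(toy$data)  # all rows
toy$consider_valid_only = TRUE
nrow(toy$data)  # valid rows only
```

Because the subset is derived on access, setting the flag back to FALSE restores the full view without a separate copy of the data.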

@dandls: What do you think? :)

dandls commented 1 year ago

I agree that the current procedure involves many fields, which adds an unnecessary degree of complexity. I like that your idea lets the user switch between all counterfactuals and only the valid ones, but I am not sure I like that flipping the value of a field triggers a set of actions and overwrites other fields.
As a user, I would expect to call a method if I want to actively change the displayed results.
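To make the distinction concrete, here is a rough sketch of a method-based design, where the change to the displayed results is an explicit call. The class ToyCounterfactualsM and the method names keep_valid() and restore_all() are hypothetical and chosen only for this example; they are not the package's API, and the sketch is not necessarily one of the alternatives discussed in this thread.

```r
library(R6)

# Hypothetical toy class, NOT the package's Counterfactuals class.
ToyCounterfactualsM = R6Class("ToyCounterfactualsM",
  public = list(
    data = NULL,
    initialize = function(data, is_valid) {
      stopifnot(is.data.frame(data), is.logical(is_valid), nrow(data) == length(is_valid))
      self$data = data
      private$.fulldata = data
      private$.is_valid = is_valid
    },
    # Hypothetical method: explicitly restrict `$data` to the valid rows.
    keep_valid = function() {
      self$data = private$.fulldata[private$.is_valid, , drop = FALSE]
      invisible(self)
    },
    # Hypothetical method: explicitly undo the restriction.
    restore_all = function() {
      self$data = private$.fulldata
      invisible(self)
    }
  ),
  private = list(.fulldata = NULL, .is_valid = NULL)
)

# Assumption: only the first of three toy rows counts as valid.
toy_m = ToyCounterfactualsM$new(
  data = data.frame(Petal.Length = c(5.1, 4.7, 5.1), Petal.Width = c(1.67, 1.80, 1.72)),
  is_valid = c(TRUE, FALSE, FALSE)
)
toy_m$keep_valid()
nrow(toy_m$data)  # valid rows only
toy_m$restore_all()
nrow(toy_m$data)  # all rows again
```

Here the side effect (overwriting the data field) happens only inside a method call, so the user sees exactly when the displayed results change.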

Alternatives:

What do you think, @andreash0?

andreash0 commented 1 year ago

I think your proposed solution with the two methods is great, @dandls 👍

dandls commented 1 year ago

Great, I will work on it.