JuliaTrustworthyAI / CounterfactualExplanations.jl

A package for Counterfactual Explanations and Algorithmic Recourse in Julia.
https://juliatrustworthyai.github.io/CounterfactualExplanations.jl/
MIT License

PoC of Customizability #323

Closed: VincentPikand closed this issue 4 months ago.

VincentPikand commented 11 months ago

In this example, I want to show how we could use multiple dispatch as the basis for the entire process of generating counterfactual explanations. The idea is that every step is customizable. Here, I use the simple example of the termination step to illustrate my point.

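# `Termination` must be defined before this struct (see the abstract type below).
struct CounterfactualExplanation
  # all other fields are omitted
  termination::Termination
end

# Keep iterating until the termination criterion says stop.
function generate_counterfactual(ce::CounterfactualExplanation)
  while !terminated(ce.termination)
    println("calculating...")
  end
  println("terminated")
end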

The next block is the crux: it is the contract we provide to the user. If they define a subtype of Termination that implements all of the required functions, we can work with it. In this case, the only required function is terminated.

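abstract type Termination end

# Fallback: subtypes that do not implement `terminated` end up here.
terminated(t::Termination) = error("terminated is not implemented for $(typeof(t))")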
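Julia has no formal interface declarations, so this contract is only checked when it is actually used. A minimal sketch (the `ForgetfulTermination` type is hypothetical, and the fallback message is slightly more descriptive than "not provided"): defining a subtype without a `terminated` method succeeds, and the error only appears at call time.

```julia
abstract type Termination end

# Fallback method: reached by any subtype that does not implement `terminated`.
terminated(t::Termination) = error("terminated is not implemented for $(typeof(t))")

# Hypothetical subtype that forgets to implement the contract.
struct ForgetfulTermination <: Termination end

# Defining the type works fine; the missing method only surfaces when it is called.
message = try
    terminated(ForgetfulTermination())
    "no error"
catch err
    sprint(showerror, err)
end
println(message)
```

This is why a descriptive fallback message matters: it is the only feedback a user gets when their type is incomplete.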

This is how a user would use our library: they define their own way of terminating the loop, just as they would define other components such as the generator.

mutable struct IterationCountTermination <: Termination
  c::Int        # iterations performed so far
  maxiter::Int  # iteration budget
end

function terminated(t::IterationCountTermination)
  t.c += 1
  return t.c >= t.maxiter
end

julia> iterbasedtermination = IterationCountTermination(0, 10);

julia> ce = CounterfactualExplanation(iterbasedtermination);

julia> generate_counterfactual(ce)
calculating...
calculating...
calculating...
calculating...
calculating...
calculating...
calculating...
calculating...
calculating...
terminated
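Because the loop only ever talks to `terminated`, criteria can also be combined without touching `generate_counterfactual`. A sketch under these assumptions: `TimeoutTermination` and `AnyTermination` are hypothetical types (the former mirrors the existing `timeout` keyword), and `count_iterations` is a stand-in for the loop in `generate_counterfactual`.

```julia
abstract type Termination end
terminated(t::Termination) = error("terminated is not implemented for $(typeof(t))")

# Iteration budget, as in the example above.
mutable struct IterationCountTermination <: Termination
    c::Int
    maxiter::Int
end

function terminated(t::IterationCountTermination)
    t.c += 1
    return t.c >= t.maxiter
end

# Hypothetical wall-clock criterion, mirroring the existing `timeout` keyword.
struct TimeoutTermination <: Termination
    start::Float64    # time() when the search started
    seconds::Float64  # allowed wall-clock time
end
TimeoutTermination(seconds::Real) = TimeoutTermination(time(), Float64(seconds))
terminated(t::TimeoutTermination) = time() - t.start >= t.seconds

# Hypothetical composite: stops as soon as any of its criteria fires.
struct AnyTermination <: Termination
    criteria::Vector{Termination}
end
AnyTermination(criteria::Termination...) = AnyTermination(collect(Termination, criteria))
# `any` evaluates the criteria in order and stops at the first one that fires.
terminated(t::AnyTermination) = any(terminated, t.criteria)

# Count how many loop iterations run before the criterion fires.
function count_iterations(term::Termination)
    n = 0
    while !terminated(term)
        n += 1
    end
    return n
end

n_budget = count_iterations(AnyTermination(IterationCountTermination(0, 5), TimeoutTermination(60)))
n_timeout = count_iterations(AnyTermination(IterationCountTermination(0, 5), TimeoutTermination(0)))
```

With a generous timeout, the iteration budget decides (four iterations, matching the pattern of the output above); with a zero timeout, the loop body never runs.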

Notice that the actual CounterfactualExplanation struct has a maxiter field, whereas in this example it is "pushed down" into the Termination subtype.

You might be thinking: "Iteration count is so basic that users shouldn't have to define it themselves; that would just be tedious." I completely agree. However, this is exactly how we would provide the built-in implementation: the same way an end user would. We could keep a Dict of a few of the most common implementations. The important part is that we preserve the ability to customize, while still letting users pick one of our implementations.
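A minimal sketch of such a catalogue, assuming hypothetical names `TERMINATION_CATALOGUE` and `resolve_termination`. Multiple dispatch lets users either name a built-in strategy with a `Symbol` or pass their own `Termination` instance. The Dict stores factories rather than instances, because mutable criteria like `IterationCountTermination` carry state that must not leak from one search into the next.

```julia
abstract type Termination end
terminated(t::Termination) = error("terminated is not implemented for $(typeof(t))")

mutable struct IterationCountTermination <: Termination
    c::Int
    maxiter::Int
end

function terminated(t::IterationCountTermination)
    t.c += 1
    return t.c >= t.maxiter
end

# Hypothetical catalogue of built-in strategies, keyed by name.
# Each entry builds a fresh instance so that no counter state is shared.
const TERMINATION_CATALOGUE = Dict{Symbol,Function}(
    :max_iter => () -> IterationCountTermination(0, 100),
)

# Users either name a built-in strategy...
function resolve_termination(name::Symbol)
    haskey(TERMINATION_CATALOGUE, name) ||
        throw(ArgumentError("unknown termination strategy :$name"))
    return TERMINATION_CATALOGUE[name]()
end

# ...or pass their own Termination, which is used as-is.
resolve_termination(t::Termination) = t

builtin_a = resolve_termination(:max_iter)
builtin_b = resolve_termination(:max_iter)
custom = IterationCountTermination(0, 10)
```

Further built-ins would be registered the same way, and a lookup like this could also back a `Symbol` option such as the existing `converge_when`.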

pat-alt commented 11 months ago

Thanks, this definitely looks interesting. A couple of quick thoughts/questions:

  1. I know you're oversimplifying, but here, in ce = CounterfactualExplanation(iterbasedtermination), you assume that the constructor takes a <: Termination as its first positional argument. Of course, there are more things we need.
  2. I do like the idea of handling things like convergence conditions through custom types, but we should also be careful not to overdo it (in this case, by adding too many custom types for our catalogue of methods). Still, for convergence conditions in particular, this may be worth doing.
  3. This still does not solve the issue you were originally concerned about, namely that generate_counterfactual currently takes way too many keyword arguments for the user to keep track of.
  4. Ultimately, we do need all the information currently provided through keyword args. But to reduce the burden on users, one thing we could do is break the various keyword args down into meaningful groups.

Regarding point 4, I'm thinking about something along these lines: keep the function largely as is (although the last three keyword args are specific to PROBE and should definitely not be here).

function generate_counterfactual(
    x::AbstractArray,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::AbstractGenerator;
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    generative_model_params::NamedTuple=(;),
    max_iter::Int=100,
    decision_threshold::AbstractFloat=0.5,
    gradient_tol::AbstractFloat=parameters[:τ],
    min_success_rate::AbstractFloat=parameters[:min_success_rate],
    converge_when::Symbol=:decision_threshold,
    timeout::Union{Nothing,Int}=nothing,
    # PROBE-specific; these should move out of this signature
    invalidation_rate::AbstractFloat=0.1,
    learning_rate::AbstractFloat=1.0,
    variance::AbstractFloat=0.01,
)
    # ... body omitted ...
end

Then, to break it down for users, add keyword containers for each meaningful group:

# Assumes `using Flux` and the package-internal `parameters` dict are in scope.
Base.@kwdef struct GeneratorParams
    opt::Flux.Optimiser
    decision_threshold::AbstractFloat=0.5
end

Base.@kwdef struct SearchParams
    gradient_tol::AbstractFloat=parameters[:τ]
    max_iter::Int=100
    min_success_rate::AbstractFloat=parameters[:min_success_rate]
    converge_when::Symbol=:decision_threshold
    timeout::Union{Nothing,Int}=nothing
end
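One caveat when wiring these containers in: Julia does not dispatch on keyword arguments, so a method that differs from an existing one only in its keywords replaces it rather than adding a new method. A sketch of one workaround, under stated assumptions: the container is passed positionally and its fields are forwarded as keywords. `SearchParams` is trimmed to the fields whose defaults are literal values above, `as_kwargs` is a hypothetical helper, and `generate_counterfactual` is a stub that just echoes its settings.

```julia
# Trimmed, hypothetical version of SearchParams: only fields whose defaults are
# literal values in the signature above are kept.
Base.@kwdef struct SearchParams
    max_iter::Int = 100
    converge_when::Symbol = :decision_threshold
    timeout::Union{Nothing,Int} = nothing
end

# Stub standing in for the existing flat keyword interface; it echoes its settings.
function generate_counterfactual(x::AbstractArray;
        max_iter::Int=100,
        converge_when::Symbol=:decision_threshold,
        timeout::Union{Nothing,Int}=nothing)
    return (; max_iter, converge_when, timeout)
end

# Turn any parameter container into `name => value` pairs usable as keyword arguments.
as_kwargs(p) = (f => getfield(p, f) for f in fieldnames(typeof(p)))

# The container is a positional argument, so this is a genuinely new method.
generate_counterfactual(x::AbstractArray, search::SearchParams) =
    generate_counterfactual(x; as_kwargs(search)...)

result = generate_counterfactual([1.0, 2.0], SearchParams(max_iter=50))
```

The same pattern would extend to `GeneratorParams` via a further positional argument, and the flat keyword method keeps working for existing callers.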

Users could then define these step by step. Additionally, we would need to overload the generate_counterfactual method with something like generate_counterfactual(args...; gen_params::GeneratorParams, search_params::SearchParams). Alternatively (and perhaps more in line with what you're looking for), we could handle everything through outer constructors of CounterfactualExplanation first and then add a method generate_counterfactual(ce::CounterfactualExplanation).
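A heavily simplified sketch of the second alternative. The `CounterfactualExplanation` here is a hypothetical stand-in holding only `x`, `target`, a `termination` criterion and a step counter; the real object also needs data, a model, a generator, and so on. The outer constructor gathers the configuration once, with defaults, so that `generate_counterfactual` needs a single argument.

```julia
abstract type Termination end
terminated(t::Termination) = error("terminated is not implemented for $(typeof(t))")

mutable struct IterationCountTermination <: Termination
    c::Int
    maxiter::Int
end

function terminated(t::IterationCountTermination)
    t.c += 1
    return t.c >= t.maxiter
end

# Hypothetical, heavily simplified stand-in; data, model, generator, etc. are omitted.
mutable struct CounterfactualExplanation
    x::Vector{Float64}
    target::Int
    termination::Termination
    num_steps::Int
end

# Outer constructor: collects the configuration once, with defaults.
# Default arguments are evaluated on every call, so each explanation gets a fresh counter.
function CounterfactualExplanation(x::AbstractVector{<:Real}, target::Integer;
        termination::Termination=IterationCountTermination(0, 100))
    return CounterfactualExplanation(Float64.(x), Int(target), termination, 0)
end

# All settings live on `ce`, so generation needs a single argument.
function generate_counterfactual(ce::CounterfactualExplanation)
    while !terminated(ce.termination)
        ce.num_steps += 1  # a real search step would update the counterfactual here
    end
    return ce
end

ce_custom = generate_counterfactual(
    CounterfactualExplanation([1, 2], 1; termination=IterationCountTermination(0, 10)))
ce_default = generate_counterfactual(CounterfactualExplanation([1, 2], 1))
```

This also addresses point 1 above: the constructor, not the user, is responsible for filling in everything that was not specified.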

To be honest, I'm not sure how else we could do this. If you have a better idea, feel free to try it out, of course. Perhaps it would be easier to first focus on pruning the package, as discussed, before we turn to core changes like this one?