rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction
https://sirus.jl.huijzer.xyz/
MIT License

Easily access condition and consequence (then/otherwise) of a Rule #44

Closed: Zapiano closed this issue 10 months ago

Zapiano commented 1 year ago

After I fit a model, e.g., StableRulesClassifier, I get a StableRules object. It would be useful to have an easy, more direct way to access the model's output, for example as an array: the conditions (if some_feature > some_value && some_other_feature < some_other_value) and the consequences (then probability_1 otherwise probability_2).

I looked at the documentation and couldn't find anything like this (maybe I am missing something?). This would be useful for plotting the rules and for using them in an automated workflow as input to other functions.

This could be done with a pair of functions, conditions(stable_rules::StableRules)::Vector{} and consequences(stable_rules::StableRules)::Vector{} (the names could be different).

If this makes sense, I have already implemented this functionality in another project (https://github.com/open-AIMS/ADRIA.jl/blob/225d0e82ca428c125493b14ffa3d0c8fef5046e8/src/analysis/rule_extraction.jl#L47) and I can open a PR implementing these here.

rikhuijzer commented 1 year ago

Hi there. Nice to see that you are using SIRUS.jl and I hope you find it useful. Thank you for providing feedback via this issue.

What I did for my plotting function is visible when clicking "Edit on GitHub" at the top of the binary classification example. That file uses the following function:

function _odds_plot(e::PerformanceEvaluation)
    w, h = (1000, 300)
    fig = Figure(; resolution=(w, h))
    grid = fig[1, 1:2] = GridLayout()

    fitresults = getproperty.(e.fitted_params_per_fold, :fitresult)
    feature_names = String[]
    for fitresult in fitresults
        for rule in fitresult.rules
            name = only(rule.path.splits).splitpoint.feature_name
            push!(feature_names, name)
        end
    end

    names = sort(unique(feature_names))
    subtitle = "Ratio"

    max_height = maximum(maximum.(getproperty.(fitresults, :weights)))

    importances = _sum_weights.(Ref(fitresults), names)

    matching_rules = DataFrame(; names, importance=importances)
    sort!(matching_rules, :importance; rev=true)
    names = matching_rules.names
    l = length(names)

    for (i, feature_name) in enumerate(names)
        yticks = (1:1, [feature_name])
        ax = i == l ? 
            Axis(grid[i, 1:3]; yticks, xlabel="Ratio") : 
            Axis(grid[i, 1:3]; yticks)
        vlines!(ax, [0]; color=:gray, linestyle=:dash)
        xlims!(ax, -1, 1)
        ylabel = feature_name

        name = feature_name

        rules_weights = map(fitresults) do fitresult
            index = _rule_index(fitresult, feature_name)
            isnothing(index) && return nothing
            rule = fitresult.rules[index]::SIRUS.Rule
            return (rule, fitresult.weights[index])
        end
        rw::Vector{Tuple{SIRUS.Rule,Float64}} = 
            filter(!isnothing, rules_weights)
        thresholds = _threshold.(first.(rw))
        t_mean = round(mean(thresholds); digits=1)
        t_std = round(std(thresholds); digits=1)

        for (rule, weight) in rw
            left = last(rule.then)::Float64
            right = last(rule.otherwise)::Float64
            t::Float64 = _threshold(rule)
            # Log ratio of the `otherwise` probability to the `then` probability.
            ratio = log(right / left)
            # area = πr²
            markersize = 50 * sqrt(weight / π)
            scatter!(ax, [ratio], [1]; color=:black, markersize)
        end
        hideydecorations!(ax; ticklabels=false)

        axr = i == l ?
            Axis(grid[i, 4:5]; xlabel="Location") :
            Axis(grid[i, 4:5])
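        # NOTE: `data` is not defined inside this function; it is presumably a
        # table from the enclosing script.
        D = data[:, feature_name]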
        hist!(axr, D; scale_to=1)
        vlines!(axr, thresholds; color=:black, linestyle=:dash)

        if i < l
            hidexdecorations!(ax)
        else
            hidexdecorations!(ax; ticks=false, ticklabels=false)
        end

        hideydecorations!(axr)
        hidexdecorations!(axr; ticks=false, ticklabels=false)
    end

    rowgap!(grid, 5)
    colgap!(grid, 50)
    return fig
end;

This function should work for any MLJ.PerformanceEvaluation, but I'll be the first to admit that neither the function nor the visualization is perfect. Also, it only works for binary classification. Maybe it's useful for you.
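To make the encoding in the inner loop concrete, here is a minimal sketch of the two per-rule quantities that the left panel draws. The probabilities and the weight are made-up values standing in for last(rule.then), last(rule.otherwise), and the rule's weight; they do not come from a fitted model.

```julia
# Hypothetical numbers for a single rule, for illustration only.
then_prob = 0.4       # stands in for last(rule.then)
otherwise_prob = 0.6  # stands in for last(rule.otherwise)
weight = 0.5          # stands in for the rule's weight

# Log ratio: zero when both branches give the same probability,
# positive when the `otherwise` branch has the higher probability.
ratio = log(otherwise_prob / then_prob)

# Marker size grows with sqrt(weight), so the marker *area* grows
# linearly with the weight.
markersize = 50 * sqrt(weight / π)

println("ratio = ", ratio, ", markersize = ", markersize)
```

A ratio of 0 means the rule does not shift the predicted probability. The square root is what makes the marker area, rather than its diameter, scale linearly with the weight.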

Anyway, you write that we should probably have a clearer interface to the internals, and I agree with you. However, adding functions is risky business in software because API decisions are hard to reverse. To reverse one, I would need to release a new breaking version, meaning that everyone would have to manually migrate their code to the new version.

So I propose that we keep this issue for a bit as a Request For Comments (RFC). Based on your code, we probably want something like this:

"""
    _condition(rules::T, index::Int64) where {T<:SIRUS.StableRules}

Vector containing condition clauses. Each condition clause is a vector with three
components: a feature_name::String, a direction::String (< or ≤) and a value::Float64

# Arguments
- `rules` : SIRUS.StableRules object containing all rules
- `index` : Index of the rule

# Returns
Vector of Rule condition clauses (each one being a vector itself).
"""

and

"""
    _consequent(rules::SIRUS.StableRules{Float64}, index::Int64)

Vector of Rule consequent with two components: the probability of the 'then' values and probability of the 'else' values.

# Arguments
- `rules` : SIRUS.StableRules object containing all rules
- `index` : Index of the rule

# Returns
Probability vectors, one for Rule condition == true, one for Rule condition == false.
"""
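As a rough illustration of what these two docstrings describe, here is a self-contained sketch. It does not use SIRUS.jl: every `Toy*` type and both `toy_*` functions are hypothetical stand-ins. The field names `rules`, `weights`, `path`, `splits`, `splitpoint`, `feature_name`, `then`, and `otherwise` follow the accessors in the plotting code above; the `value` and `direction` fields are assumptions. The sketch returns tuples rather than vectors and passes the direction through as stored instead of converting it to a String.

```julia
# Stand-in types, NOT the real SIRUS.jl structs.
struct ToySplitPoint
    feature_name::String
    value::Float64       # assumed field name
end

struct ToySplit
    splitpoint::ToySplitPoint
    direction::Symbol    # assumed field; :L or :R
end

struct ToyPath
    splits::Vector{ToySplit}
end

struct ToyRule
    path::ToyPath
    then::Vector{Float64}
    otherwise::Vector{Float64}
end

struct ToyStableRules
    rules::Vector{ToyRule}
    weights::Vector{Float64}
end

"Condition clauses of rule `index` as (feature_name, direction, value) tuples."
function toy_condition(model::ToyStableRules, index::Int)
    splits = model.rules[index].path.splits
    return [(s.splitpoint.feature_name, s.direction, s.splitpoint.value) for s in splits]
end

"Consequent of rule `index`: probabilities for condition true and condition false."
function toy_consequent(model::ToyStableRules, index::Int)
    rule = model.rules[index]
    return (then=rule.then, otherwise=rule.otherwise)
end

# Made-up model with one single-clause rule and one two-clause rule.
model = ToyStableRules(
    [
        ToyRule(ToyPath([ToySplit(ToySplitPoint("age", 42.0), :L)]), [0.7, 0.3], [0.4, 0.6]),
        ToyRule(
            ToyPath([
                ToySplit(ToySplitPoint("nodes", 8.0), :L),
                ToySplit(ToySplitPoint("year", 1960.0), :R),
            ]),
            [0.2, 0.8],
            [0.5, 0.5],
        ),
    ],
    [0.6, 0.4],
)

# Broadcasting over all rule indices gives the "array of rules" form.
all_conditions = toy_condition.(Ref(model), eachindex(model.rules))
all_consequents = toy_consequent.(Ref(model), eachindex(model.rules))
```

Keeping the accessors per-index, as in the docstrings, still allows the array form requested at the top of this issue through broadcasting, without adding a second pair of functions to the API.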

I propose to let this sit for at least a week and take further action after that. Apologies for being bureaucratic; I think this is the right approach given that wrong decisions here have such a high cost for clients.

Zapiano commented 1 year ago

About the risk of adding these functions: I completely understand and agree with you. Should I add (RFC) to the issue title?

About the docstrings you wrote, I have only two comments:

  1. I think these functions should be public; otherwise, people would have to use private/internal functions to extract the rules after running the algorithm.
  2. We could also return the direction, inside each clause, as a Symbol (:L or :R), as SIRUS already handles directions as Symbols right now.

What are your thoughts on that?
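To illustrate the second point, here is a small sketch of how downstream code might consume a clause whose direction is a Symbol. The clause layout and the function `holds` are hypothetical and not part of SIRUS.jl. The sketch also assumes that :L means `feature < value` and that :R is the other branch of the same split.

```julia
# A clause as a NamedTuple with a Symbol direction (hypothetical layout).
clause = (feature_name="age", direction=:L, value=42.0)

# ASSUMPTION: :L means `feature < value`; :R is the other branch of that split.
function holds(clause, row)
    x = row[Symbol(clause.feature_name)]
    below = x < clause.value
    return clause.direction === :L ? below : !below
end

# A made-up observation.
row = (age=35.0, nodes=3.0)

println(holds(clause, row))
```

With a Symbol, the branch is selected with a cheap identity comparison and there is no operator string to parse, which is convenient when the rules feed into other functions.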

rikhuijzer commented 10 months ago

Okay, done now in #74, @Zapiano. Sorry for taking so long! I was very busy. You have probably moved on, and that's okay. Just wanted to apologize for taking so long to respond.

Zapiano commented 10 months ago

Hey @rikhuijzer, I'm still using SIRUS in my project, so thanks for letting me know about that!

Cheers