rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction
https://sirus.jl.huijzer.xyz/
MIT License

Remove rule sorting step #68

Closed rikhuijzer closed 9 months ago

rikhuijzer commented 9 months ago

It turns out that the whole rule sorting logic in _sort_by_frequency wasn't necessary, so let's get rid of it. This PR also moves some code around for clarity.

rikhuijzer commented 9 months ago

I'm bringing over the extract one more time because the code has now been cleaned up:


"Estimate the importance of a rule."
function _rule_importance(weight::Number, rule::Rule)
    # TODO: THIS SHOULD USE THE GAP SIZE FUNCTION.
    importance = 0.0
    thens = rule.then::Vector{Float64}
    otherwises = rule.otherwise::Vector{Float64}
    @assert length(thens) == length(otherwises)
    n_classes = length(thens)
    for (then, otherwise) in zip(thens, otherwises)
        importance += weight * abs(then - otherwise)
    end
    return importance / n_classes
end
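
As a quick arithmetic check with hypothetical numbers (no Rule object involved): a rule with then = [0.9, 0.1], otherwise = [0.7, 0.3], and weight 0.5 gets importance 0.5 * (|0.9 - 0.7| + |0.1 - 0.3|) / 2 = 0.1.

# Hypothetical values, purely to illustrate the formula in _rule_importance.
weight = 0.5
thens = [0.9, 0.1]
otherwises = [0.7, 0.3]
importance = sum(weight * abs(t - o) for (t, o) in zip(thens, otherwises)) / length(thens)
# importance == 0.1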

"""
    feature_importance(
        models::Union{StableRules, Vector{StableRules}},
        feature_name::AbstractString
    )

Estimate the importance of the given `feature_name`.
The aim is to satisfy the following property, so that the features can be
ordered by importance:

> Given two features A and B, if A has more effect on the outcome, then
> feature_importance(model, A) > feature_importance(model, B).

!!! note
    This function provides only an importance _estimate_ because the effect on
    the outcome depends on the data, and because it doesn't take into account
    that a feature can have a lower effect if it is in a clause together with
    another subclause.
"""
function feature_importance(
        model::StableRules,
        feature_name::String
    )
    importance = 0.0
    for (i, rule) in enumerate(model.rules)
        for subclause::SubClause in _subclauses(rule)
            if _feature_name(subclause)::String == feature_name
                weight = model.weights[i]
                importance += _rule_importance(weight, rule)
            end
        end
    end
    return importance
end

function feature_importance(model::StableRules, feature_name::AbstractString)
    return feature_importance(model, string(feature_name)::String)
end

function feature_importance(
        models::Vector{<:StableRules},
        feature_name::String
    )
    importance = 0.0
    for model in models
        importance += feature_importance(model, feature_name)
    end
    return importance / length(models)
end

function feature_importance(models::Vector{<:StableRules}, feature_name::AbstractString)
    return feature_importance(models, string(feature_name)::String)
end

"""
    feature_importances(
        models::Union{StableRules, Vector{StableRules}},
        feature_names
    )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}}

Return the feature names and importances, sorted by feature importance in descending order.
"""
function feature_importances(
        models::Union{StableRules, Vector{StableRules}},
        feature_names::Vector{String}
    )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}}
    @assert length(unique(feature_names)) == length(feature_names)
    importances = map(feature_names) do feature_name
        importance = feature_importance(models, feature_name)
        (; feature_name, importance)
    end
    alg = Helpers.STABLE_SORT_ALG
    return sort(importances; alg, by=last, rev=true)
end

function feature_importances(
        models::Union{StableRules, Vector{StableRules}},
        feature_names
    )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}}
    return feature_importances(models, string.(feature_names))
end

function _haberman_data()
    # Load Haberman's survival dataset: three numeric features and a binary survival outcome.
    df = haberman()
    X = MLJBase.table(MLJBase.matrix(df[:, Not(:survival)]))
    y = categorical(df.survival)
    (X, y)
end

X, y = _haberman_data()

mach = let
    classifier = StableRulesClassifier(; q=3, max_depth=1, max_rules=8, n_trees=1000, rng=_rng())
    mach = machine(classifier, X, y)
    fit!(mach; force=true)
end

# TODO MAKE THIS MODEL MANUALLY TO NOT DEPEND ON THE FITTED MODEL

model = mach.fitresult::StableRules
# StableRules model with 8 rules:
#  if X[i, :x3] < 8.0 then 0.084 else 0.03 +
#  if X[i, :x3] < 14.0 then 0.147 else 0.098 +
#  if X[i, :x3] < 2.0 then 0.073 else 0.047 +
#  if X[i, :x3] < 4.0 then 0.079 else 0.048 +
#  if X[i, :x3] < 1.0 then 0.076 else 0.06 +
#  if X[i, :x2] < 1959.0 then 0.006 else 0.008 +
#  if X[i, :x1] < 38.0 then 0.029 else 0.024 +
#  if X[i, :x1] < 42.0 then 0.052 else 0.043
# and 2 classes: [0, 1].
# Note: showing only the probability for class 1 since class 0 has probability 1 - p.

importance = feature_importance(model, "x1")
# Based on the numbers above.
expected = ((0.029 - 0.024) + (0.052 - 0.043))
@test importance ≈ expected atol=0.01

@test feature_importance([model, model], "x1") ≈ expected atol=0.01
@test only(feature_importances(model, ["x1"])).importance ≈ expected atol=0.01
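
As a follow-up sketch (not part of the test suite above), ranking all three Haberman features should put x3 first, since the gaps in the printed rules sum to roughly 0.18 for x3 versus 0.014 for x1 and 0.002 for x2:

importances = feature_importances(model, ["x1", "x2", "x3"])
# Expected ordering based on the printed model: x3, then x1, then x2.
@test first(importances).feature_name == "x3"
@test issorted([i.importance for i in importances]; rev=true)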