rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction
https://sirus.jl.huijzer.xyz/
MIT License
31 stars 2 forks source link

[Breaking] Rename confusing terms such as `Split` #67

Closed rikhuijzer closed 11 months ago

rikhuijzer commented 11 months ago

This PR renames Split to SubClause and TreePath to Clause. In turn, this renaming makes things much more intuitive since it sticks to the end-result terms instead of internal representation (decision tree) terms. Also, the Split data structure is finally fixed to directly contain the split value and feature instead of having those in a nested data structure called SplitPoint.

rikhuijzer commented 11 months ago

Will take the new API logic out and save it for a future PR.

src/extract.jl:

"""
Estimate the importance of a single rule as its `weight` times the total
absolute difference between the `then` and `otherwise` outcomes.
"""
function _rule_importance(weight::Number, rule::Rule)
    then_values = rule.then::Vector{Float64}
    otherwise_values = rule.otherwise::Vector{Float64}
    # Left fold keeps the exact accumulation order of a plain loop.
    return foldl(zip(then_values, otherwise_values); init=0.0) do acc, (t, o)
        acc + weight * abs(t - o)
    end
end

"""
    feature_importance(
        model::StableRules,
        feature_name::AbstractString
    )

Estimate the importance of the given `feature_name`.
The aim of this function is to satisfy the following property:

> Given two features X and Y, if X has more effect on the outcome, then
> feature_importance(model, X) > feature_importance(model, Y).

This function provides only an estimation of the importance because
the effect on the outcome depends on the data.
"""
function feature_importance(
        model::StableRules,
        feature_name::String
    )
    total = 0.0
    # A rule contributes once per subclause that mentions the feature.
    for (rule_index, rule) in enumerate(model.rules)
        for subclause::SubClause in _subclauses(rule)
            _feature_name(subclause)::String == feature_name || continue
            total += _rule_importance(model.weights[rule_index], rule)
        end
    end
    return total
end

# Dispatch any `AbstractString` to the `String`-typed method.
function feature_importance(model::StableRules, feature_name::AbstractString)
    name = string(feature_name)::String
    return feature_importance(model, name)
end

and test/extract.jl:

"Load the Haberman dataset as an `(X, y)` pair of features and target."
function _haberman_data()
    data = haberman()
    features = MLJBase.table(MLJBase.matrix(data[:, Not(:survival)]))
    target = categorical(data.survival)
    return (features, target)
end

# Load the Haberman features and survival target for the smoke test below.
X, y = _haberman_data()

# Small, seeded model so the fit is deterministic and fast enough for tests.
classifier = StableRulesClassifier(; max_depth=2, max_rules=8, n_trees=1000, rng=_rng())
mach = machine(classifier, X, y)
fit!(mach)

# The fitted result is the rule-based model that feature_importance consumes.
model = mach.fitresult::StableRules

# Exercise the String path of feature_importance on the first feature column.
importance = SIRUS.feature_importance(model, "x1")