rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction
https://sirus.jl.huijzer.xyz/
MIT License
31 stars 2 forks source link

Add API to obtain rules for visualizations #66

Closed rikhuijzer closed 10 months ago

rikhuijzer commented 11 months ago

Add an API to obtain rules for plotting. Currently, the following code:

"""
    _rule_index(model::StableRules, feature_name::AbstractString)

Return the index of the first rule in `model.rules` whose split is on `feature_name`,
or `nothing` when no rule matches.

Every rule is expected to hold exactly one split; `only` throws otherwise.
"""
function _rule_index(model::StableRules, feature_name::AbstractString)
    # `findfirst` yields `nothing` when no rule matches, like the manual scan it replaces.
    return findfirst(model.rules) do rule
        only(rule.path.splits).splitpoint.feature_name == feature_name
    end
end

# Renamed to `SIRUS.sum_weights(fitresults::Vector{StableRules}, name::AbstractString)`.
"""
    _sum_weights(fitresults::Vector{<:StableRules}, name::AbstractString)

Sum, over all models in `fitresults`, the weight of the rule that `_rule_index` finds
for feature `name`. Models without a rule on `name` contribute zero.

Only the first matching rule of each model is counted, because `_rule_index` returns a
single index. `fitresults` must be non-empty; reducing an empty collection throws.
"""
function _sum_weights(fitresults::Vector{<:StableRules}, name::AbstractString)
    return sum(fitresults) do fitresult
        weights = fitresult.weights
        index = _rule_index(fitresult, name)
        # Use a zero of the weight type so the sum never mixes `Int` with floats.
        isnothing(index) ? zero(eltype(weights)) : weights[index]
    end
end

"""
    _remove_nato_name(name::AbstractString) -> String

Strip the last space-separated token from `name` when `name` contains a `'('`;
otherwise return `name` unchanged (converted to `String`).

For example, `"Alpha (A)"` becomes `"Alpha"`. A name that contains `'('` but no space
becomes the empty string, since its only token is dropped.
"""
function _remove_nato_name(name::AbstractString)
    # Names without a parenthesis carry no suffix to drop.
    occursin('(', name) || return String(name)
    parts = split(name, ' ')
    return join(parts[1:end-1], ' ')
end

"""
    _threshold(rule)

Return the `value` of the splitpoint of `rule`'s single split.

`only` throws when `rule.path.splits` does not hold exactly one element.
"""
_threshold(rule) = only(rule.path.splits).splitpoint.value

# Collect, across all models and in model order, every `(rule, weight)` pair whose
# rule splits on `feature_name`.
function _rules_weights_for_feature(fitresults, feature_name::AbstractString)
    rules_weights = Tuple{SIRUS.Rule,Float64}[]
    for fitresult in fitresults
        for (rule, weight) in zip(fitresult.rules, fitresult.weights)
            if only(rule.path.splits).splitpoint.feature_name == feature_name
                push!(rules_weights, (rule, weight))
            end
        end
    end
    return rules_weights
end

"""
    odds_plot(e::PerformanceEvaluation, data::DataFrame, pretty_name::Function;
              max_features::Integer=15)

Draw one row per feature for the most important features found in the rules of
`e.fitted_params_per_fold`, and return the `Figure`.

Features are ranked by `_sum_weights` (descending) and at most `max_features` are shown.
Each row holds two axes:

- left: one marker per rule on the feature, placed at
  `log(last(rule.else_probs) / last(rule.then_probs))` and sized from the rule weight;
- right: a histogram of `data[:, feature_name]` with a dashed line at every rule
  threshold.

# Arguments
- `e`: evaluation whose `fitted_params_per_fold` entries each expose a `fitresult` with
  `rules` and `weights`.
- `data`: table with a column for every plotted feature name.
- `pretty_name`: maps a feature name to the label shown on the y-axis.

# Keywords
- `max_features`: upper bound on the number of rows; fewer are drawn when fewer features
  occur in the rules.

# Throws
- `ArgumentError` when `max_features` is negative.
"""
function odds_plot(
        e::PerformanceEvaluation,
        data::DataFrame,
        pretty_name::Function;
        max_features::Integer=15
    )
    max_features >= 0 ||
        throw(ArgumentError("max_features must be non-negative, got $max_features"))

    w, h = (800, 1000)
    fig = Figure(; resolution=(w, h))
    grid = fig[1, 1:3] = GridLayout()

    fitresults = getproperty.(e.fitted_params_per_fold, :fitresult)
    feature_names = String[]
    for fitresult in fitresults
        for rule in fitresult.rules
            push!(feature_names, only(rule.path.splits).splitpoint.feature_name)
        end
    end

    names = sort(unique(feature_names))
    importances = _sum_weights.(Ref(fitresults), names)

    matching_rules = DataFrame(; names, importance=importances)
    sort!(matching_rules, :importance; rev=true)
    # Clamp the cut-off so fewer than `max_features` features do not raise a BoundsError.
    names = matching_rules.names[1:min(max_features, end)]
    l = length(names)

    pretty_names = [pretty_name(n) for n in names]

    for (i, feature_name) in enumerate(names)
        yticks = (1:1, [pretty_names[i]])
        # Only the bottom row carries the x-axis label.
        ax = i == l ?
            Axis(grid[i, 1:2]; yticks, xlabel="Ratio") :
            Axis(grid[i, 1:2]; yticks)
        vlines!(ax, [0]; color=:gray, linestyle=:dash)
        # Markers whose log-ratio lies outside these limits are not visible.
        xlims!(ax, -1.01, 1.01)

        rw = _rules_weights_for_feature(fitresults, feature_name)
        thresholds = _threshold.(first.(rw))

        # NOTE(review): the accompanying issue text reports that all points end up on the
        # left side of this axis; the cause is not established here. Confirm the then/else
        # orientation of the ratio and the axis limits.
        for (rule, weight) in rw
            left = last(rule.then_probs)::Float64
            right = last(rule.else_probs)::Float64
            ratio = log(right / left)
            # area = πr², so the marker area scales with the weight.
            markersize = 50 * sqrt(weight / π)
            scatter!(ax, [ratio], [1]; color=:black, markersize)
        end
        hideydecorations!(ax; ticklabels=false)

        axr = i == l ?
            Axis(grid[i, 3:5]; xlabel="Location") :
            Axis(grid[i, 3:5])
        D = data[:, feature_name]
        hist!(axr, D; scale_to=1, color=:white, strokewidth=1, strokecolor=:black)
        vlines!(axr, thresholds; color=:black, linestyle=:dash)

        if i < l
            hidexdecorations!(ax)
        else
            hidexdecorations!(ax; ticks=false, ticklabels=false)
        end

        hideydecorations!(axr)
        hidexdecorations!(axr; ticks=false, ticklabels=false)
    end

    rowgap!(grid, 5)
    return fig
end;

produces the following plot:

image

Apart from the bug that causes all points to be on the left, this should provide a good basis for the API together with https://github.com/rikhuijzer/SIRUS.jl/issues/44.