Add an API to obtain rules for plotting. Currently, plots are made with the following code:
"""
    _rule_index(model::StableRules, feature_name::AbstractString)

Return the index of the first rule in `model.rules` whose split is on `feature_name`,
or `nothing` when no rule uses that feature.

Accepts any `AbstractString` (e.g. a `SubString`), so that callers which themselves
take `AbstractString` names can forward them without a `MethodError`.
`only` throws if a rule does not have exactly one split.
"""
function _rule_index(model::StableRules, feature_name::AbstractString)
    splits_on_feature(rule) =
        only(rule.path.splits).splitpoint.feature_name == feature_name
    return findfirst(splits_on_feature, model.rules)
end
# Renamed to `SIRUS.sum_weights(fitresults::Vector{StableRules}, name::AbstractString)`.
function _sum_weights(fitresults::Vector{<:StableRules}, name::AbstractString)
indexes = _rule_index.(fitresults, Ref(name))
return sum([isnothing(index) ? 0 : fitresults[i].weights[index] for (i, index) in enumerate(indexes)])
end
"""
    _remove_nato_name(name::AbstractString) -> String

Strip a parenthesised suffix from `name`, e.g. `"Sukhoi Su-27 (Flanker)"` becomes
`"Sukhoi Su-27"`. Judging by the function name, the suffix is presumably a NATO
reporting name; confirm against the `pretty_name` data.

Names that contain no `'('` are returned unchanged, converted to `String`.
"""
function _remove_nato_name(name::AbstractString)
    contains(name, '(') || return String(name)
    # Keep everything before the first '('. Unlike splitting on spaces and dropping
    # the last word, this also handles "Su-27(Flanker)" (no space before the
    # parenthesis) and multi-word suffixes such as "MiG-29 (Fulcrum A)".
    head = first(split(name, '('; limit=2))
    return String(rstrip(head))
end
"""
    _threshold(rule)

Return the split value of the one and only split in `rule`'s path.
`only` throws if the path does not contain exactly one split.
"""
_threshold(rule) = only(rule.path.splits).splitpoint.value
# Name of the feature used by the single split of `rule` (`only` throws otherwise).
_rule_feature_name(rule) = only(rule.path.splits).splitpoint.feature_name

# All `(rule, weight)` pairs whose split is on `feature_name`, collected from every
# model in `fitresults`: models in order, and rules in their original order within
# each model. Weights are converted to `Float64`.
function _feature_rules_weights(fitresults, feature_name::AbstractString)
    return Tuple{SIRUS.Rule,Float64}[
        (rule, weight)
        for fitresult in fitresults
        for (rule, weight) in zip(fitresult.rules, fitresult.weights)
        if _rule_feature_name(rule) == feature_name
    ]
end

"""
    odds_plot(e::PerformanceEvaluation, data::DataFrame, pretty_name::Function;
              max_features::Integer=15)

Plot the rules of the per-fold SIRUS models in `e` for the `max_features` features with
the largest summed rule weight (see `_sum_weights`). There is one row per feature, in
decreasing importance.

- Left column: one dot per rule on that feature, at
  `log(last(rule.else_probs) / last(rule.then_probs))`. The marker area is
  proportional to the rule weight. The x-axis is fixed to `[-1.01, 1.01]`, so
  ratios outside that range fall outside the visible area.
- Right column: a histogram of `data[:, feature]`, with every rule threshold drawn
  as a dashed vertical line.

`pretty_name` maps a feature (column) name to the y tick label. When fewer than
`max_features` features occur in the rules, all of them are shown. Returns the
Makie `Figure`.
"""
function odds_plot(
        e::PerformanceEvaluation,
        data::DataFrame,
        pretty_name::Function;
        max_features::Integer=15
    )
    w, h = (800, 1000)
    fig = Figure(; resolution=(w, h))
    grid = fig[1, 1:3] = GridLayout()
    fitresults = getproperty.(e.fitted_params_per_fold, :fitresult)

    # Rank every feature that occurs in at least one rule by its summed weight.
    feature_names = String[
        _rule_feature_name(rule) for fitresult in fitresults for rule in fitresult.rules
    ]
    names = sort(unique(feature_names))
    importances = _sum_weights.(Ref(fitresults), names)
    matching_rules = DataFrame(; names, importance=importances)
    sort!(matching_rules, :importance; rev=true)
    # `first` caps at the number of available features instead of throwing a
    # BoundsError when fewer than `max_features` features occur in the rules.
    names = first(matching_rules.names, max_features)

    l = length(names)
    pretty_names = [pretty_name(n) for n in names]
    for (i, feature_name) in enumerate(names)
        yticks = (1:1, [pretty_names[i]])
        ax = i == l ?
            Axis(grid[i, 1:2]; yticks, xlabel="Ratio") :
            Axis(grid[i, 1:2]; yticks)
        vlines!(ax, [0]; color=:gray, linestyle=:dash)
        xlims!(ax, -1.01, 1.01)

        rw = _feature_rules_weights(fitresults, feature_name)
        thresholds = _threshold.(first.(rw))
        for (rule, weight) in rw
            left = last(rule.then_probs)::Float64
            right = last(rule.else_probs)::Float64
            ratio = log(right / left)
            # Marker area ∝ weight: area = πr², so r ∝ sqrt(weight / π).
            markersize = 50 * sqrt(weight / π)
            scatter!(ax, [ratio], [1]; color=:black, markersize)
        end
        hideydecorations!(ax; ticklabels=false)

        axr = i == l ?
            Axis(grid[i, 3:5]; xlabel="Location") :
            Axis(grid[i, 3:5])
        D = data[:, feature_name]
        hist!(axr, D; scale_to=1, color=:white, strokewidth=1, strokecolor=:black)
        vlines!(axr, thresholds; color=:black, linestyle=:dash)

        # Only the bottom row keeps its x ticks and tick labels.
        if i < l
            hidexdecorations!(ax)
        else
            hidexdecorations!(ax; ticks=false, ticklabels=false)
        end
        hideydecorations!(axr)
        hidexdecorations!(axr; ticks=false, ticklabels=false)
    end
    rowgap!(grid, 5)
    return fig
end;
Add an API to obtain rules for plotting. Currently, the code above
produces the following plot:
Apart from the bug that causes all points to be plotted on the left, this should provide a good basis for the API, together with https://github.com/rikhuijzer/SIRUS.jl/issues/44.