One drawback is that model.over_sampler.feature is exposed to the user but shouldn't be altered (it should always be :target).
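One possible way around that (a sketch only; OverSampler, WrappedOverSampler and build_balancer are hypothetical names, not the POC's actual types) is for the wrapper to expose only the safe hyperparameters and construct the inner over-sampler itself, with feature pinned to :target:

# hypothetical over-sampler with the problematic user-visible `feature` field:
mutable struct OverSampler
    feature::Symbol
    ratio::Float64
end
OverSampler(; feature=:target, ratio=1.0) = OverSampler(feature, ratio)

# sketch of a fix: the wrapper stores only the hyperparameters that are safe
# to change and builds the inner over-sampler privately, pinning `feature`,
# so the user never sees or alters it:
mutable struct WrappedOverSampler
    model              # the supervised model being wrapped
    ratio::Float64     # exposed hyperparameter
end
build_balancer(w::WrappedOverSampler) = OverSampler(feature=:target, ratio=w.ratio)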
Any movement on this? This sounds like a data-preparation utility that could be provided by a third-party ML utility package.
We have some functionality in ClassImbalance.jl, but there is not yet an MLJ interface for that package. Help is welcome.
https://github.com/bcbi/ClassImbalance.jl/issues/85
show stopper
Yeah, the package needs to be updated and modernized.
I've implemented SMOTE in Resample.jl. It has a very basic API, but it is built with speed in mind and uses the Tables interface.
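For readers who haven't met SMOTE before, the core idea is to synthesize new minority-class points by interpolating between a minority observation and one of its nearest minority-class neighbours. Here is a minimal, self-contained sketch of that idea (illustration only, not Resample.jl's API or implementation):

using Random

# illustrative SMOTE core: synthesize `n_new` points from the minority-class
# observations stored as the rows of the matrix `X`
function smote_points(X::AbstractMatrix, n_new::Integer; k::Integer=5,
                      rng=Random.default_rng())
    n = size(X, 1)
    new_points = Matrix{Float64}(undef, n_new, size(X, 2))
    for i in 1:n_new
        p = rand(rng, 1:n)                     # pick a random minority observation
        # brute-force squared distances to the other minority observations:
        dists = [float(sum(abs2, X[p, :] .- X[q, :])) for q in 1:n]
        dists[p] = Inf                         # exclude the point itself
        neighbours = partialsortperm(dists, 1:min(k, n - 1))
        q = rand(rng, neighbours)              # choose one of the k nearest
        λ = rand(rng)                          # interpolation weight in [0, 1)
        new_points[i, :] = X[p, :] .+ λ .* (X[q, :] .- X[p, :])
    end
    return new_points
end

# example: 5 synthetic points from 20 two-dimensional minority observations
smote_points(rand(20, 2), 5)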
It's been a while since I posted the above POC. Here's an updated version, based on more recent versions of the packages and with some other mild changes. You'll need MLJBase >= 0.21.12 and MLJDecisionTreeInterface in your environment.
using MLJ, Tables
import MLJBase, StatsBase

## A QUICK AND DIRTY OVERSAMPLER FOR ILLUSTRATION

mutable struct NaiveOversampler <: Static
    ratio::Float64
end
NaiveOversampler(; ratio=1.0) = NaiveOversampler(ratio)

function MLJBase.transform(oversampler::NaiveOversampler, verbosity, X, y)
    # class counts, sorted so the minority class comes first:
    d = StatsBase.countmap(y)
    counts = sort(collect(d), by=pair->last(pair))
    minority_class = first(counts) |> first
    dominant_class = last(counts) |> first
    # number of extra minority rows needed to reach the requested ratio:
    nextras = max(
        0,
        round(Int, oversampler.ratio*d[dominant_class] - d[minority_class]),
    )
    # duplicate randomly chosen minority rows and append them:
    all_indices = eachindex(y)
    minority_indices = all_indices[y .== minority_class]
    extra_indices = rand(minority_indices, nextras)
    over_indices = vcat(all_indices, extra_indices)
    Xover = Tables.subset(X, over_indices) |> Tables.materializer(X)
    yover = y[over_indices]
    return Xover, yover
end
# demonstration:
X = (x1=1:4, x2=5:8)
y = coerce([true, false, true, true], Multiclass)
StatsBase.countmap(y)
# Dict{CategoricalArrays.CategoricalValue{Bool, UInt32}, Int64} with 2 entries:
# false => 1
# true => 3
naive = NaiveOversampler()
mach = machine(naive) # static transformers have no training arguments
Xover, yover = transform(mach, X, y)
StatsBase.countmap(yover)
# Dict{CategoricalArrays.CategoricalValue{Bool, UInt32}, Int64} with 2 entries:
# false => 3
# true => 3
## COMPOSITE FOR WRAPPING A CLASSIFIER WITH OVERSAMPLING

# default component models for the wrapper:
naive = NaiveOversampler()
dummy = ConstantClassifier()

# we restrict wrapping to `Probabilistic` models and so use
# `ProbabilisticNetworkComposite` for the "exported" learning network type:
struct BalancedModel <: ProbabilisticNetworkComposite
    model::Probabilistic
    balancer # oversampler or undersampler
end
BalancedModel(; model=dummy, balancer=naive) =
    BalancedModel(model, balancer)
BalancedModel(model; kwargs...) = BalancedModel(; model, kwargs...)

function MLJBase.prefit(over_sampled_model::BalancedModel, verbosity, _X, _y)
    # the learning network:
    X = source(_X)
    y = source(_y)
    mach1 = machine(:balancer) # `Static`, so no training arguments here
    data = transform(mach1, X, y)
    # `first` and `last` are overloaded for nodes, so we can do:
    X_over = first(data)
    y_over = last(data)
    # we use the oversampled data for training:
    mach2 = machine(:model, X_over, y_over)
    # but consume new production data from the source:
    yhat = predict(mach2, X)
    # return the learning network interface:
    return (; predict=yhat)
end
## DEMONSTRATION
# synthesize some data and make it imbalanced by deleting most of one class:
Xraw, yraw = make_moons(1000);
for_deletion = eachindex(yraw)[yraw .== 0][1:400]
to_keep = setdiff(eachindex(yraw), for_deletion)
X = Tables.rowtable(Xraw)[to_keep]
y = coerce(yraw[to_keep], OrderedFactor)
train, test = partition(eachindex(y), 0.6)
model = (@load DecisionTreeClassifier pkg=DecisionTree)()
balanced_model = BalancedModel(model)
# BalancedModel(
#   model = DecisionTreeClassifier(
#         max_depth = -1,
#         min_samples_leaf = 1,
#         min_samples_split = 2,
#         min_purity_increase = 0.0,
#         n_subfeatures = 0,
#         post_prune = false,
#         merge_purity_threshold = 1.0,
#         display_depth = 5,
#         feature_importance = :impurity,
#         rng = Random._GLOBAL_RNG()),
#   balancer = NaiveOversampler(
#         ratio = 1.0))
mach = machine(balanced_model, X, y)
fit!(mach, rows=train)
predict(mach, rows=test[1:3])
# 3-element UnivariateFiniteVector{OrderedFactor{2}, String, UInt32, Float64}:
# UnivariateFinite{OrderedFactor{2}}(0=>1.0, 1=>0.0)
# UnivariateFinite{OrderedFactor{2}}(0=>0.0, 1=>1.0)
# UnivariateFinite{OrderedFactor{2}}(0=>0.0, 1=>1.0)
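Since BalancedModel behaves like any other MLJ model, it can also be evaluated or tuned in the usual way. A quick sketch (assuming balanced_accuracy is among the measures exported by your MLJ version; because predictions are probabilistic, point-prediction measures need operation=predict_mode):

# cross-validated estimate of balanced accuracy for the wrapped model:
evaluate(balanced_model, X, y,
         resampling=CV(nfolds=5, shuffle=true),
         operation=predict_mode,
         measure=balanced_accuracy)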
A large number of oversampling/undersampling strategies, with MLJ interfaces, are now provided by Imbalance.jl, and a wrapper, BalancedModel(model, ...), allowing insertion into supervised learning pipelines, is provided by MLJBalancing.jl.
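A rough usage sketch (the registry names and the balancer1 keyword below are assumptions; consult the Imbalance.jl and MLJBalancing.jl documentation for the exact API):

using MLJ, MLJBalancing

# load an oversampling strategy from Imbalance.jl and a classifier:
SMOTE = @load SMOTE pkg=Imbalance
Tree = @load DecisionTreeClassifier pkg=DecisionTree

# wrap the classifier so that balancing is applied only to training data:
balanced = BalancedModel(Tree(), balancer1=SMOTE())

mach = machine(balanced, X, y) |> fit!
predict(mach, rows=1:3)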
Closing as complete.
cc @EssamWissam
https://imbalanced-learn.readthedocs.io/en/stable/over_sampling.html#over-sampling
edit (July 2023): An updated version of the POC below appears later in this thread.
This is just to kick off a discussion. I see oversampling/undersampling as transformers plus model wrappers. Here's a rough POC for this:
cc @DilumAluthge