ashryaagr / Fairness.jl

Julia Toolkit with fairness metrics and bias mitigation algorithms
https://ashryaagr.github.io/Fairness.jl/dev/
MIT License

Implementing the MLJ measures API #1

Open ablaom opened 4 years ago

ablaom commented 4 years ago

It would be nice if one could use the measures provided here in MLJ performance evaluation and elsewhere. This means implementing the API documented here, which does not appear to have been done yet.

Note that an MLJ measure does not have to return a numerical value. We regard, for example, confmat as a measure:

using MLJ
X, y = @load_crabs
model = @load DecisionTreeClassifier
e = evaluate(model, X, y, measure=confmat, operation=predict_mode)
julia> e.per_fold[1][1]
              ┌───────────────────────────┐
              │       Ground Truth        │
┌─────────────┼─────────────┬─────────────┤
│  Predicted  │      B      │      O      │
├─────────────┼─────────────┼─────────────┤
│      B      │     31      │      0      │
├─────────────┼─────────────┼─────────────┤
│      O      │      3      │      0      │
└─────────────┴─────────────┴─────────────┘

So here's one tentative suggestion for implementing the MLJ API.

In MLJ one can already have a measure m with signature m(yhat, y, X) where X represents the full table of input features, which we can suppose is a Tables.jl table. In your case, you only care about one particular column of X - let's call it the group column - whose classes you want to filter on (eg, a column like ["male", "female", "male", "male", "binary"]). One could:

  1. Introduce a new parameter for each MLJFair metric, called group_name, or whatever, which specifies the name of the group column. So one would instantiate the measure like this: m = MLJFair.TruePositive(group_name=:gender).

  2. Overload calling of the metric appropriately, so that m(yhat, y, X) returns a dictionary of numerical values keyed on the "group" class, eg, Dict("male" => 2, "female" =>3, "binary" => 1). Or I suppose you could return a struct of some kind, but I think a dict would be the most user-friendly.

  3. To complete the API you may have to overload some measure traits, for example:

MLJBase.name(::Type{<:MLJFair.TruePositive}) = "TruePositive"
MLJBase.target_scitype(...) = OrderedFactor{2}
MLJBase.supports_weights(...) = false # for now
MLJBase.prediction_type(...) = :deterministic
MLJBase.orientation(::Type) = :other # other options are :score, :loss
MLJBase.reports_each_observation(::Type) = false
MLJBase.aggregation(::Type) = Sum()
MLJBase.is_feature_dependent(::Type) = true              <---- Important

If you did this, then things like evaluate(model, X, y, measure=MLJFair.TruePositive(group_name=:gender), resampling=CV()) would work.
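To make steps 1 and 2 concrete, here is a minimal sketch, assuming a hypothetical `GroupedTruePositive` struct and hard-coded `1`/positive labels (the actual MLJFair type names and label handling may differ):

```julia
# Hypothetical metric struct carrying the name of the group column (step 1),
# with calling overloaded to return a Dict keyed on group class (step 2).
struct GroupedTruePositive
    group_name::Symbol
end

function (m::GroupedTruePositive)(yhat, y, X)
    groups = getproperty(X, m.group_name)  # the group column of the feature table
    result = Dict{eltype(groups),Int}()
    for g in unique(groups)
        mask = groups .== g
        # true positives within this group: predicted positive and actually positive
        result[g] = sum((yhat[mask] .== 1) .& (y[mask] .== 1))
    end
    return result
end

# Example usage on toy data:
X = (gender = ["male", "female", "male", "male"],)
yhat = [1, 1, 1, 0]
y    = [1, 1, 0, 0]
m = GroupedTruePositive(:gender)
m(yhat, y, X)  # true-positive count per gender class
```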

How does this sound?

ablaom commented 4 years ago

The other idea, mentioned on the call (which does not assume you have existing structs such as MLJFair.TruePositive), would be a wrapper. So the user does something like m = FairnessMetric(measure=TruePositive(), group_name=:gender), and then m(yhat, y, X) returns the dictionary of true positive counts, keyed on gender.
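A rough sketch of this wrapper idea, under the assumption that the wrapped measure is a plain callable of the form measure(yhat, y) (FairnessMetric and its keyword constructor are hypothetical names from this suggestion, not an existing API):

```julia
# Hypothetical wrapper: pairs any per-subset measure with a group column name.
struct FairnessMetric{M}
    measure::M
    group_name::Symbol
end

# Keyword constructor matching m = FairnessMetric(measure=..., group_name=...)
FairnessMetric(; measure, group_name) = FairnessMetric(measure, group_name)

# Calling the wrapper applies the wrapped measure once per group class
# and collects the results in a Dict keyed on that class.
function (fm::FairnessMetric)(yhat, y, X)
    groups = getproperty(X, fm.group_name)
    return Dict(g => fm.measure(yhat[groups .== g], y[groups .== g])
                for g in unique(groups))
end

# Example usage with a toy true-positive count standing in for TruePositive():
tp(yhat, y) = sum((yhat .== 1) .& (y .== 1))
X = (gender = ["male", "female", "male", "male"],)
m = FairnessMetric(measure=tp, group_name=:gender)
m([1, 1, 1, 0], [1, 1, 0, 0], X)  # per-gender true-positive counts
```

One advantage of the wrapper over step 1 above is that it works for any existing measure without defining a new grouped struct per metric.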