ablaom opened 4 years ago
The other idea, mentioned on the call (which does not assume you have existing structs `MLJFair.TruePositive` and so forth), would be a wrapper. So the user does something like `m = FairnessMetric(measure=TruePositive(), group_name=:gender)`, and then `m(yhat, y, X)` will return the dictionary of true-positive counts, keyed on gender.
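A minimal sketch of that wrapper in plain Julia, assuming `X` is a column table given as a `NamedTuple` of vectors (which Tables.jl accepts as a column table), deterministic `Bool` predictions, and `true` as the positive class. `FairnessMetric` and this `TruePositive` are hypothetical illustrations, not existing MLJ or MLJFair types. The wrapper works with any inner callable `(yhat, y) -> value`.

```julia
# Hypothetical inner measure (not MLJFair's): counts true positives,
# assuming Bool labels with `true` as the positive class.
struct TruePositive end
(::TruePositive)(yhat, y) = count(p && t for (p, t) in zip(yhat, y))

# Hypothetical wrapper: evaluates `measure` separately on each class of the group column.
Base.@kwdef struct FairnessMetric{M}
    measure::M
    group_name::Symbol
end

function (m::FairnessMetric)(yhat, y, X)
    groups = getproperty(X, m.group_name)   # e.g. X.gender when group_name = :gender
    # For each distinct group class, restrict yhat and y to that group's rows.
    return Dict(g => m.measure(yhat[groups .== g], y[groups .== g]) for g in unique(groups))
end

# Usage: a NamedTuple of vectors serves as the feature table.
X = (gender = ["male", "female", "male", "male", "binary"], age = [23, 41, 35, 29, 50])
y    = [true, true, false, true, true]
yhat = [true, true, true, true, false]

m = FairnessMetric(measure=TruePositive(), group_name=:gender)
result = m(yhat, y, X)   # male => 2, female => 1, binary => 0
```

Because the grouping logic lives in the wrapper, the inner measures need no knowledge of `X`.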
It would be nice if one could use the measures provided here in MLJ performance evaluation and elsewhere. This means implementing the MLJ measures API documented here, which does not appear to be the case at present.
Note that an MLJ measure does not have to return a numerical value. We regard, for example, `confmat` as a measure.

So here's one tentative suggestion for implementing the MLJ API.
In MLJ one can already have a measure `m` with signature `m(yhat, y, X)`, where `X` represents the full table of input features, which we can suppose is a Tables.jl table. In your case, you only care about one particular column of `X` (let's call it the group column) whose classes you want to filter on (e.g., a column like `["male", "female", "male", "male", "binary"]`). One could:

1. Introduce a new parameter for each MLJFair metric, called `group_name` or whatever, which specifies the name of the group column. So one would instantiate the measure like this: `m = MLJFair.TruePositive(group_name=:gender)`.
2. Overload calling of the metric appropriately, so that `m(yhat, y, X)` returns a dictionary of numerical values keyed on the group class, e.g., `Dict("male" => 2, "female" => 3, "binary" => 1)`. Or I suppose you could return a struct of some kind, but I think a dict would be the most user-friendly.

To complete the API you may have to overload some measure traits, for example:
If you did this, then things like `evaluate(model, X, y, measure=MLJFair.TruePositive(group_name=:gender), resampling=CV())` would work.

How does this sound?
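For concreteness, a minimal sketch of the per-metric `group_name` proposal, in which the group column's name lives on the metric itself rather than on a wrapper. `MLJFairSketch` is a hypothetical stand-in module for MLJFair, and none of these names are real MLJ API; it assumes deterministic `Bool` predictions, `true` as the positive class, and a `NamedTuple` column table for `X`. The MLJ trait overloads mentioned above are not shown.

```julia
# Hypothetical module standing in for MLJFair; nothing here is real MLJ API.
module MLJFairSketch

# A true-positive metric carrying the name of the group column as a parameter.
Base.@kwdef struct TruePositive
    group_name::Symbol
end

# Number of positions where both prediction and ground truth are `true`.
tp_count(yhat, y) = count(p && t for (p, t) in zip(yhat, y))

# Overloaded call: a Dict of true-positive counts keyed on each group class.
function (m::TruePositive)(yhat, y, X)
    groups = getproperty(X, m.group_name)   # the column of X named by `group_name`
    return Dict(g => tp_count(yhat[groups .== g], y[groups .== g]) for g in unique(groups))
end

end # module

# Usage with the example group column from above.
X = (gender = ["male", "female", "male", "male", "binary"],)
y    = [true, true, false, true, true]
yhat = [true, true, true, true, false]

m = MLJFairSketch.TruePositive(group_name=:gender)
result = m(yhat, y, X)   # male => 2, female => 1, binary => 0
```

Summing the dictionary's values recovers the overall true-positive count, so no information is lost by splitting on the group column.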