This implements the black box label-shift predictor as described here and also implements empirical risk minimization to fix a shift if it is detected. The important assumption is that the data exhibits a pure label shift while the label conditionals stay constant, that is p(covariates | label) = q(covariates | label), where p and q are the training and production distributions, respectively.
Shift detection
The API is best illustrated by test_simple_imputer_label_shift() inside test_simple_imputer.py.
In contrast to previous discussions, we do not need the test or training data to detect a label shift: the confusion matrices and marginals are sufficient, and they are already stored as part of the imputer metrics.
Thus, after a SimpleImputer has been trained, we can simply call the shift check:
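(The exact call is exercised in test_simple_imputer_label_shift(); the minimal sketch below assumes a method named check_for_label_shift that takes the unlabeled production data frame, so treat the name and argument as assumptions.)

# Hypothetical sketch of the detection call -- method name and argument are assumptions,
# see test_simple_imputer_label_shift() for the authoritative usage.
weights = simple_imputer.check_for_label_shift(production_df)
# e.g. weights == {'label_0': 2.72, 'label_1': 0.49}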
This will return a dictionary with labels as keys and weights for ERM as values. It will also send a message to the logger as follows:
The estimated true label marginals are [('label_0', 0.62), ('label_1', 0.38)]
Marginals in the training data are [('label_0', 0.23), ('label_1', 0.77)]
Reweighing factors for empirical risk minimization {'label_0': 2.72, 'label_1': 0.49}
The smallest eigenvalue of the confusion matrix is 0.21 (needs to be > 0).
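The numbers in this log follow the standard black box shift estimation recipe: invert the stored confusion matrix, apply it to the average of the model's hard predictions on production data to obtain the reweighting factors w = q(label) / p(label), and multiply by the training marginals to get the estimated true marginals. A minimal numpy sketch with illustrative numbers (not the imputer's internal code):

import numpy as np

# Joint confusion matrix C[i, j] = p(predicted label i, true label j),
# estimated on held-out data and stored with the imputer metrics.
C = np.array([[0.20, 0.05],
              [0.03, 0.72]])

# Average of hard predictions on production data, i.e. an estimate of q(predicted label).
mu_hat = np.array([0.56, 0.44])              # illustrative values

# BBSE: solve C w = mu_hat for the importance weights w = q(label) / p(label).
w = np.linalg.solve(C, mu_hat)               # ~ [2.67, 0.50]

p_train = C.sum(axis=0)                      # training label marginals, ~ [0.23, 0.77]
q_est = w * p_train                          # estimated true marginals, ~ [0.62, 0.38]

# The inversion is only well posed if the confusion matrix is invertible,
# which is what the eigenvalue check in the log message guards against.
min_eig = np.linalg.eigvals(C).real.min()    # must be > 0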
Empirical risk minimization and fixing label shift
To perform empirical risk minimization we need to retrain the model with the computed class weights. This is done by passing the weights dictionary as an argument to the .fit() method of the SimpleImputer:
simple_imputer.fit(train_df, class_weights)
Training the imputation model then performs empirical risk minimization by weighting every observation's contribution to the log-likelihood. Alternatively, we can pass a list with weights for every single instance of the training data, e.g. if we believe that particular observations are more reliable, as in the sketch below.
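A hedged sketch of that variant follows; passing the list in place of the class-weights dictionary mirrors the example above, and the exact .fit() parameter is an assumption rather than a confirmed signature.

# Hypothetical sketch: one weight per row of train_df.
# Passing the list where the class-weights dictionary went above is an assumption,
# not a confirmed .fit() signature.
instance_weights = [1.0] * len(train_df)
instance_weights[:100] = [2.0] * 100   # up-weight the first 100 rows we trust more (illustration)

simple_imputer.fit(train_df, instance_weights)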