uclamii / model_tuner

A library to tune the hyperparameters of common ML models. Supports calibration and custom pipelines.
Apache License 2.0

created todo comment for stratification fix #29

Closed. lshpaner closed this issue 2 months ago.

lshpaner commented 2 months ago

Description:

Currently, the train_val_test_split method allows stratification either by y (stratify_y) or by specified columns (stratify_cols), but not both at the same time: if both are set, stratify_cols takes precedence and y is not used for stratification. There are use cases where stratifying by both the target variable (y) and specific columns is necessary to ensure a balanced and representative split across different data segments.

Proposed Enhancement:

Modify the method to support simultaneous stratification by both y and stratify_cols. This could be achieved by combining y and the stratify_cols values into a single stratification key (one label per row, per combination of values), or by implementing other logic that ensures both y and the specified columns are considered during the stratification process.
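
A minimal sketch of the "combined key" approach, assuming X is a pandas DataFrame with a unique index and that train_size + validation_size + test_size == 1. The names build_stratify_key and train_val_test_split_combined are hypothetical and not part of model_tuner; only scikit-learn's train_test_split is a real API here. Each row's stratify_cols values and its y value are joined into one string label, and that label is passed as stratify. For the second split, the key is looked up by index rather than rebuilt from X_valid_test, so it stays aligned with the remaining rows.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def build_stratify_key(X, y, stratify_cols=None, stratify_y=False):
    """Combine stratify_cols and (optionally) y into one label per row.

    Hypothetical helper for illustration; not part of model_tuner.
    Returns None when no stratification is requested.
    """
    parts = []
    if stratify_cols:
        parts.append(X[stratify_cols].astype(str))
    if stratify_y:
        # np.asarray avoids pandas re-aligning y on a mismatched index
        parts.append(
            pd.DataFrame({"__target__": np.asarray(y).astype(str)}, index=X.index)
        )
    if not parts:
        return None
    combined = pd.concat(parts, axis=1)
    # Join the values row-wise, e.g. sex="F" and y=0 -> "F_0"
    key = combined.iloc[:, 0]
    for i in range(1, combined.shape[1]):
        key = key + "_" + combined.iloc[:, i]
    return key


def train_val_test_split_combined(
    X,
    y,
    train_size,
    validation_size,
    test_size,
    random_state=None,
    stratify_cols=None,
    stratify_y=False,
):
    """Two-stage split stratified on stratify_cols and y together (sketch)."""
    key = build_stratify_key(X, y, stratify_cols, stratify_y)
    X_train, X_valid_test, y_train, y_valid_test = train_test_split(
        X, y, test_size=1 - train_size, stratify=key, random_state=random_state
    )
    # Look the key up by index so it matches the rows left after the first split
    key_val_test = None if key is None else key.loc[X_valid_test.index]
    proportion = test_size / (validation_size + test_size)
    X_valid, X_test, y_valid, y_test = train_test_split(
        X_valid_test,
        y_valid_test,
        test_size=proportion,
        stratify=key_val_test,
        random_state=random_state,
    )
    return X_train, X_valid, X_test, y_train, y_valid, y_test
```

Because the key is built once, up front, this design would also tolerate the stratification columns being dropped from X afterwards (as the drop_strat_feat step does). One caveat: every combination of column values and target needs enough rows in each split, otherwise scikit-learn raises a ValueError, so sparse combinations may need to be merged first.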

Current Method Implementation:

```python
from sklearn.model_selection import train_test_split


def train_val_test_split(
    self,
    X,
    y,
    stratify_y,
    train_size,
    validation_size,
    test_size,
    random_state,
    stratify_cols,
    calibrate,
):

    # if calibrate:
    #     X = X.join(self.dropped_strat_cols)
    # Determine the stratify parameter based on stratify_y and stratify_cols.
    # stratify_cols takes precedence: when both are set, y is not used.
    if stratify_cols:
        # Creating stratification columns out of stratify_cols list
        stratify_key = X[stratify_cols]
    elif stratify_y:
        stratify_key = y
    else:
        stratify_key = None

    # Optionally remove the stratification features from X (kept for later use)
    if self.drop_strat_feat:
        self.dropped_strat_cols = X[self.drop_strat_feat]
        X = X.drop(columns=self.drop_strat_feat)

    # First split: hold out (1 - train_size) of the rows for validation + test
    X_train, X_valid_test, y_train, y_valid_test = train_test_split(
        X,
        y,
        test_size=1 - train_size,
        stratify=stratify_key,  # Use stratify_key here
        random_state=random_state,
    )

    # Determine the proportion of validation to test size in the remaining dataset
    proportion = test_size / (validation_size + test_size)

    # Rebuild the stratification key for the held-out rows (same precedence)
    if stratify_cols:
        strat_key_val_test = X_valid_test[stratify_cols]
    elif stratify_y:
        strat_key_val_test = y_valid_test
    else:
        strat_key_val_test = None

    # Further split (validation + test) set into validation and test sets
    X_valid, X_test, y_valid, y_test = train_test_split(
        X_valid_test,
        y_valid_test,
        test_size=proportion,
        stratify=strat_key_val_test,
        random_state=random_state,
    )

    return X_train, X_valid, X_test, y_train, y_valid, y_test
```

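
The two-stage split reproduces the requested fractions only when train_size + validation_size + test_size = 1, which the method assumes but does not check. Writing t, v and s for train_size, validation_size and test_size, the first split holds out 1 - t = v + s of the rows, and the second split sends the fraction p (the proportion variable) of those rows to test:

```latex
p = \frac{s}{v + s}, \qquad
(1 - t)\,p = (v + s)\,\frac{s}{v + s} = s, \qquad
(1 - t)(1 - p) = (v + s)\,\frac{v}{v + s} = v
```

For example, train_size=0.6, validation_size=0.2 and test_size=0.2 give p = 0.5: 40% of the rows are held out and divided evenly between validation and test.
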
panas89 commented 2 months ago

Line 1207: in the function get_cross_validate(), the variable stratify is redundant.

panas89 commented 2 months ago

We need to make a note in the documentation that stratify_cols cannot be used when using cross_validation

panas89 commented 2 months ago

Checked the code changes with the debugger. Works!

lshpaner commented 2 months ago

> We need to make a note in the documentation that stratify_cols cannot be used when using cross_validation

Done.