yeomko22 / TIL

Today I learned
1 star · 0 forks · source link

ML model configuration test #123

Status: Open — yeomko22 opened this issue 2 years ago

yeomko22 commented 2 years ago
class PipelineWithConfig(SimplePipeline):
    """SimplePipeline variant whose model hyperparameters come from a config dict.

    The config is stored at construction time and read by :meth:`train`
    when the estimator is instantiated.
    """

    # Config keys forwarded to the estimator constructor as keyword arguments.
    # Subclasses may override this tuple to forward a different set of params.
    MODEL_PARAM_KEYS = ('solver', 'multi_class')

    def __init__(self, config):
        """Store *config* for later use by :meth:`train`.

        Args:
            config: Dict-like mapping of estimator hyperparameters, e.g.
                ``{'solver': 'lbfgs', 'multi_class': 'auto'}``.
        """
        # Call the inherited SimplePipeline __init__ method first.
        super().__init__()
        self.config = config

    def train(self, algorithm=LogisticRegression):
        """Instantiate *algorithm* with params from the config and fit it.

        Only keys listed in ``MODEL_PARAM_KEYS`` that are actually present in
        the config are passed to the estimator. Absent keys fall back to the
        estimator's own defaults. Previously they were forwarded as ``None``,
        which scikit-learn's LogisticRegression rejects for both ``solver``
        and ``multi_class``.

        Args:
            algorithm: Estimator class to instantiate; defaults to
                ``LogisticRegression``.
        """
        # NOTE(review): recent scikit-learn releases deprecate the
        # `multi_class` parameter of LogisticRegression — confirm against the
        # installed version before relying on it in the config.
        params = {
            key: self.config[key]
            for key in self.MODEL_PARAM_KEYS
            if key in self.config
        }
        self.model = algorithm(**params)
        # X_train / y_train are presumably populated by SimplePipeline before
        # train() is called — TODO confirm against SimplePipeline.
        self.model.fit(self.X_train, self.y_train)
class TestIrisConfig(unittest.TestCase):
    """Check that the pipeline's model is configured with an allowed solver."""

    def setUp(self):
        """Build a PipelineWithConfig from a fixed config and run it."""
        # We prepare the pipeline for use in the tests.
        config = {'solver': 'lbfgs', 'multi_class': 'auto'}
        self.pipeline = PipelineWithConfig(config=config)
        # run_pipeline() presumably trains self.pipeline.model — TODO confirm.
        self.pipeline.run_pipeline()

    def test_pipeline_config(self):
        """The model's ``solver`` param must be one of ENABLED_MODEL_SOLVERS."""
        # Given: the model config, fetched with sklearn's get_params().
        # https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html#sklearn.base.BaseEstimator.get_params
        model_params = self.pipeline.model.get_params()

        # Then: assertIn reports both the value and the allowed set on
        # failure, unlike assertTrue(... in ...), which only says
        # "False is not true".
        self.assertIn(model_params['solver'], ENABLED_MODEL_SOLVERS)