class PipelineWithConfig(SimplePipeline):
def __init__(self, config):
# Call the inherited SimplePipeline __init__ method first.
super().__init__()
# Pass in a config object which we use during the train method.
self.config = config
def train(self, algorithm=LogisticRegression):
# note that we instantiate the LogisticRegression classifier
# with params from the pipeline config
self.model = algorithm(solver=self.config.get('solver'),
multi_class=self.config.get('multi_class'))
self.model.fit(self.X_train, self.y_train)
SimplePipeline을 상속 받는 PipelineWithConfig 클래스를 만든다. 이 때 생성자에 config 객체를 전달해준다.
이는 PipelineWIthConfig 클래스가 필요로 하는 객체를 외부에서 전달받는 dependency injection이다.
class TestIrisConfig(unittest.TestCase):
def setUp(self):
# We prepare the pipeline for use in the tests
config = {'solver': 'lbfgs', 'multi_class': 'auto'}
self.pipeline = PipelineWithConfig(config=config)
self.pipeline.run_pipeline()
def test_pipeline_config(self):
# Given
# fetch model config using sklearn get_params()
# https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html#sklearn.base.BaseEstimator.get_params
model_params = self.pipeline.model.get_params()
# Then
self.assertTrue(model_params['solver'] in ENABLED_MODEL_SOLVERS)
이제 모델을 학습시킬 때 파라미터를 함께 전달해줘서 학습을 진행한다.
그 다음 model.get_params() 함수를 통해서 모델 파라미터를 가져오고, 여기서 원하는 config 값이 제대로 전달되었는 지를 테스트한다.