Ibotta / sk-dist

Distributed scikit-learn meta-estimators in PySpark
Apache License 2.0

Multi-Metric Evaluation with DistGridSearchCV results in NameError #41

Open davcem opened 4 years ago

davcem commented 4 years ago

Describe the bug
Using multi-metric scoring in DistGridSearchCV results in a NameError:

    File "/home//.local/lib/python3.6/site-packages/skdist/distribute/search.py", line 315, in fit
        not isinstance(self.refit, six.string_types) or
    NameError: name 'six' is not defined

To Reproduce
Steps to reproduce the behavior: create a DistGridSearchCV:

from sklearn.naive_bayes import GaussianNB
from skdist.distribute.search import DistGridSearchCV

GS_EVALUATION_METRICS_DICT = {'accuracy': 'accuracy', 'roc_auc': 'roc_auc'}

model = GaussianNB()
model_param_grid = {'var_smoothing': [1e-08, 0.0001, 0.01]}

grid_search = DistGridSearchCV(estimator=model, 
                           param_grid=model_param_grid,
                           sc=sc, 
                           scoring=GS_EVALUATION_METRICS_DICT, 
                           n_jobs=6, 
                           pre_dispatch=6,
                           cv=3, 
                           refit='roc_auc',
                           verbose=1,
                           error_score=0,
                           return_train_score=True,
                           )

Expected behavior
No NameError is raised.

Additional context
I think the error is easily fixable: add the missing import of the six library in skdist/distribute/search.py.
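For illustration, here is a minimal sketch of the kind of check the traceback points at, written so it does not depend on six. The helper name `check_refit` and the error message are hypothetical and not part of sk-dist; only the `isinstance` test on `refit` comes from the traceback. Using `str` in place of `six.string_types` assumes Python 3 only; simply adding `import six` at the top of search.py would be the smaller change.

```python
def check_refit(refit, scorers):
    """Validate `refit` for multi-metric scoring without using `six`.

    Hypothetical helper for illustration only; not sk-dist's actual code.
    """
    # On Python 3, `str` plays the role that `six.string_types` had.
    # The `or` short-circuits, so `refit not in scorers` is only
    # evaluated when `refit` is already known to be a string.
    if refit is not False and (
        not isinstance(refit, str) or refit not in scorers
    ):
        raise ValueError(
            "For multi-metric scoring, refit must be False or the name "
            "of one of the scorers; got %r" % (refit,)
        )
    return refit


scorers = {"accuracy": "accuracy", "roc_auc": "roc_auc"}
check_refit("roc_auc", scorers)  # accepted: names one of the scorers
check_refit(False, scorers)      # accepted: refitting disabled
```

With this shape, `refit='roc_auc'` from the report passes validation, while a value such as `True` or an unknown metric name raises a ValueError instead of a NameError.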

denver1117 commented 3 years ago

Thanks for raising this issue. Would you be open to submitting a PR for this fix and adding test coverage? I'd happily and quickly approve a PR that resolves it. Thanks.