Xtra-Computing / thundersvm

ThunderSVM: A Fast SVM Library on GPUs and CPUs
Apache License 2.0

GridSearchCV fit error #280

Open thegreatesthoneybee opened 4 months ago

thegreatesthoneybee commented 4 months ago
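```python
# Imports inferred from the names used below; the original snippet omitted them.
from datetime import datetime

from joblib import dump
from sklearn.model_selection import GridSearchCV
from thundersvm import SVC
from tqdm import tqdm

# weight_instances, c_list, gamma_list, dimension_list, kappa_scorer, folder_name,
# x_train and the y_train_<strategy> arrays are defined elsewhere in the script (not shown).
strategy_list = ['st1', 'st2', 'st3', 'st4', 'st5', 'st6', 'st7', 'st8', 'st9', 'st10']

for strategy in tqdm(strategy_list):
    class_weight = weight_instances(globals()['y_train_' + strategy])

    # One sub-grid per kernel (thundersvm names the polynomial kernel 'polynomial')
    params = [
        {'C': c_list, 'kernel': ['linear'], 'class_weight': [class_weight], 'probability': [1]},
        {'C': c_list, 'kernel': ['rbf'], 'gamma': gamma_list, 'class_weight': [class_weight], 'probability': [1]},
        {'C': c_list, 'kernel': ['polynomial'], 'gamma': gamma_list, 'degree': dimension_list, 'class_weight': [class_weight], 'probability': [1]}
    ]

    gs = GridSearchCV(SVC(n_jobs=-1), params, cv=10, refit=True,
                      scoring=kappa_scorer, verbose=3, error_score='raise')
    gs.fit(x_train, globals()['y_train_' + strategy])

    # Save the refitted best model and log when this strategy finished
    dump(gs.best_estimator_, f'{folder_name}/{strategy}.svc')
    print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    print(f'best params: {gs.best_params_}')
```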
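Because `params` is a list of dicts, `GridSearchCV` explores each dict separately and crosses only the value lists inside that dict, so the linear sub-grid never varies `gamma` or `degree`. The sketch below mimics that expansion with `itertools.product`; `expand_grid` is a hypothetical, simplified stand-in for scikit-learn's `ParameterGrid`, and the list values are made up because the issue does not show `c_list`, `gamma_list` or `dimension_list`. The single-element `class_weight` and `probability` lists are left out since they do not change the count.

```python
from itertools import product


def expand_grid(param_grid):
    """Simplified stand-in for sklearn's ParameterGrid: yield every candidate dict."""
    for sub_grid in param_grid:
        keys = sorted(sub_grid)  # iterate keys in a fixed (sorted) order
        # Cartesian product of the value lists inside this one sub-grid only
        for values in product(*(sub_grid[k] for k in keys)):
            yield dict(zip(keys, values))


# Hypothetical value lists; the real ones are not shown in the issue.
c_list, gamma_list, dimension_list = [1, 10], [0.1, 1], [2, 3]
params = [
    {'C': c_list, 'kernel': ['linear']},
    {'C': c_list, 'kernel': ['rbf'], 'gamma': gamma_list},
    {'C': c_list, 'kernel': ['polynomial'], 'gamma': gamma_list, 'degree': dimension_list},
]
candidates = list(expand_grid(params))
print(len(candidates), len(candidates) * 10)  # 14 candidates, 140 fits with cv=10
```

With `refit=True`, one more fit on the full training set follows the search for each strategy.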

I got the following error:

```
Traceback (most recent call last):
  File "d:\SynologyDrive\withfox\Project\Strategy_portfolio\make_model_thundersvm.py", line 144, in <module>
    gs.fit(x_train, globals()['y_train_' + strategy])
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\base.py", line 1473, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\model_selection\_search.py", line 968, in fit
    self._run_search(evaluate_candidates)
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\model_selection\_search.py", line 1543, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\model_selection\_search.py", line 914, in evaluate_candidates
    out = parallel(
          ^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\utils\parallel.py", line 67, in __call__
    return super().__call__(iterable_with_config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\joblib\parallel.py", line 1918, in __call__
    return output if self.return_generator else list(output)
                                                ^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\joblib\parallel.py", line 1847, in _get_sequential_output
    res = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearnex\utils\parallel.py", line 46, in __call__
    return self.function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\model_selection\_validation.py", line 910, in _fit_and_score
    test_scores = _score(
                  ^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\model_selection\_validation.py", line 971, in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\metrics\_scorer.py", line 279, in __call__
    return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\metrics\_scorer.py", line 371, in _score
    y_pred = method_caller(
             ^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\metrics\_scorer.py", line 89, in _cached_call
    result, _ = _get_response_values(
                ^^^^^^^^^^^^^^^^^^^^^
  File "C:\miniconda3\envs\ml\Lib\site-packages\sklearn\utils\_response.py", line 199, in _get_response_values
    classes = estimator.classes_
              ^^^^^^^^^^^^^^^^^^
AttributeError: 'SVC' object has no attribute 'classes_'. Did you mean: 'n_classes'?
```
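The last frames show where this fails: scikit-learn's scorer helper `_get_response_values` reads `estimator.classes_`, while thundersvm's `SVC` apparently only exposes `n_classes`. One possible workaround, sketched here and untested against thundersvm itself, is a small mixin that sets `classes_` after `fit`. `ClassesMixin`, `StubSVC` and `PatchedStubSVC` are hypothetical names; the stub stands in for `thundersvm.SVC` so the example runs without it.

```python
import numpy as np


class ClassesMixin:
    """Hypothetical mixin: record classes_ after fitting, as scikit-learn scorers expect."""

    def fit(self, X, y, *args, **kwargs):
        # Let the wrapped estimator (e.g. thundersvm's SVC) train as usual.
        super().fit(X, y, *args, **kwargs)
        # scikit-learn classifiers store the sorted unique labels as an ndarray.
        self.classes_ = np.unique(y)
        return self


# Stand-in for thundersvm.SVC, mirroring what the traceback suggests:
# an n_classes attribute after fitting, but no classes_.
class StubSVC:
    def fit(self, X, y):
        self.n_classes = len(set(y))
        return self


class PatchedStubSVC(ClassesMixin, StubSVC):
    pass


model = PatchedStubSVC().fit([[0.0], [1.0], [2.0]], [2, 0, 1])
print(model.n_classes, model.classes_.tolist())  # prints: 3 [0, 1, 2]
```

Against the real library this would become `class PatchedSVC(ClassesMixin, SVC): pass` with `SVC` imported from `thundersvm`, used as `GridSearchCV(PatchedSVC(n_jobs=-1), ...)`. Because the mixin defines no `__init__`, scikit-learn's `clone` should still read the constructor parameters from the inherited `SVC.__init__`.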

When I removed error_score='raise', I got a 'score=nan' message instead.
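The nan scores are consistent with scikit-learn's default `error_score=np.nan`: in recent versions, when fitting or scoring a candidate fails and `error_score` is not `'raise'`, the search records that value for the fold and emits a warning instead of stopping. The function below is a simplified, hypothetical sketch of that behaviour, not scikit-learn's code; `score_fold`, `needs_classes` and `NoClasses` are made-up names.

```python
import math
import warnings


def score_fold(scorer, estimator, X, y, error_score=math.nan):
    """Simplified sketch: turn a scoring failure into error_score unless it is 'raise'."""
    try:
        return scorer(estimator, X, y)
    except Exception as exc:
        if error_score == "raise":
            raise  # error_score='raise': the original exception propagates
        warnings.warn(f"Scoring failed ({exc!r}); using {error_score}")
        return error_score  # default NaN, logged as score=nan


class NoClasses:
    """Estimator stand-in without a classes_ attribute."""


def needs_classes(estimator, X, y):
    # Like scikit-learn's scorer, read classes_ before doing anything else.
    return len(estimator.classes_)


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    print(score_fold(needs_classes, NoClasses(), None, None))  # prints: nan
```

So the NaN scores hide the same `AttributeError`; they are not a separate problem.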

Please help me solve this error.

Thanks