smartcorelib / smartcore

A comprehensive library for machine learning and numerical computing. The library provides a set of tools for linear algebra, numerical computing, optimization, and enables a generic, powerful yet still efficient approach to machine learning.
https://smartcorelib.org/
Apache License 2.0

KFold for SVC in 0.3.0 #247

Open Luosuu opened 1 year ago

Luosuu commented 1 year ago

I'm submitting a

Not sure whether I misused the function or this is a feature that is not supported yet.

Current Behaviour:

use smartcore::api::SupervisedEstimatorBorrow;
let results = cross_validate(
    SVC::new(),
    &x,
    &y,
    Default::default(),
    &KFold::default().with_n_splits(3),
    &accuracy,
)
.unwrap();

error

the trait `SupervisedEstimator<_, _, _>` is not implemented for `SVC<'_, _, _, _, _>`

Snapshot:

Screenshot from 2022-11-21 12-48-07

Environment:

* rustc version 1.65.0 (897e37553 2022-11-02)
* cargo version 1.65.0 (4bc8f24d3 2022-10-20)
* Pop!_OS 22.04

### Do you want to work on this issue?

Mec-iS commented 1 year ago

Hi, thanks for using smartcore and finding this.

It looks like cross_validate currently works only for types that implement the SupervisedEstimator trait. SVC implements SupervisedEstimatorBorrow instead, a variant that borrows its inputs to save memory allocations. Possible solutions:

  1. we implement SupervisedEstimator for SVC in place of SupervisedEstimatorBorrow;
  2. we change cross_validate to accept SupervisedEstimatorBorrow and implement that trait for the other models (linear regression, ...); this way SupervisedEstimator would be used only with cross_val_predict (because it actually needs a predict method);
  3. we develop a cross_validate_borrow function that takes types that implement SupervisedEstimatorBorrow.
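To make option 3 more concrete, here is a minimal, self-contained sketch of what a borrowing cross-validation helper could look like. It does not use smartcore: `BorrowEstimator`, `NearestNeighbour`, `accuracy`, and this `cross_validate_borrow` signature are all hypothetical stand-ins, and the sketch assumes the borrowing model exposes some way to predict. Because a model that borrows its training fold cannot outlive that fold, the sketch takes a `fit_predict` closure that fits on the fold and returns owned predictions, rather than a generic estimator type. Folds are contiguous and unshuffled for simplicity.

```rust
// Hypothetical stand-in for a borrowing estimator trait; NOT smartcore's real API.
// The fitted model holds references to the training data instead of copying it.
trait BorrowEstimator<'a>: Sized {
    fn fit(x: &'a [f64], y: &'a [i32]) -> Self;
    fn predict(&self, x: &[f64]) -> Vec<i32>;
}

// Toy 1-nearest-neighbour classifier on 1-D data that borrows its training set.
struct NearestNeighbour<'a> {
    x: &'a [f64],
    y: &'a [i32],
}

impl<'a> BorrowEstimator<'a> for NearestNeighbour<'a> {
    fn fit(x: &'a [f64], y: &'a [i32]) -> Self {
        NearestNeighbour { x, y }
    }

    fn predict(&self, x: &[f64]) -> Vec<i32> {
        x.iter()
            .map(|q| {
                // Index of the training point closest to the query.
                let mut best = 0;
                for i in 1..self.x.len() {
                    if (self.x[i] - q).abs() < (self.x[best] - q).abs() {
                        best = i;
                    }
                }
                self.y[best]
            })
            .collect()
    }
}

// Fraction of predictions that match the true labels.
fn accuracy(y_true: &[i32], y_pred: &[i32]) -> f64 {
    let hits = y_true.iter().zip(y_pred).filter(|(a, b)| a == b).count();
    hits as f64 / y_true.len() as f64
}

// Hypothetical helper: k-fold cross-validation over contiguous, unshuffled folds.
// `fit_predict` fits a (possibly borrowing) model on the training fold and
// returns owned predictions for the test fold, so no model escapes the loop.
fn cross_validate_borrow<F, S>(
    x: &[f64],
    y: &[i32],
    n_splits: usize,
    fit_predict: F,
    score: S,
) -> Vec<f64>
where
    F: Fn(&[f64], &[i32], &[f64]) -> Vec<i32>,
    S: Fn(&[i32], &[i32]) -> f64,
{
    assert_eq!(x.len(), y.len());
    assert!(n_splits >= 2 && n_splits <= x.len());
    let n = x.len();
    let mut scores = Vec::with_capacity(n_splits);
    for fold in 0..n_splits {
        let (start, end) = (fold * n / n_splits, (fold + 1) * n / n_splits);
        // Training data is everything outside the test range [start, end).
        let x_train: Vec<f64> = x[..start].iter().chain(&x[end..]).copied().collect();
        let y_train: Vec<i32> = y[..start].iter().chain(&y[end..]).copied().collect();
        let y_pred = fit_predict(&x_train, &y_train, &x[start..end]);
        scores.push(score(&y[start..end], &y_pred));
    }
    scores
}

// Usage: three folds over six samples with the toy classifier.
fn demo_scores() -> Vec<f64> {
    let x = [0.0, 0.1, 0.2, 1.0, 1.1, 1.2];
    let y = [0, 0, 0, 1, 1, 1];
    cross_validate_borrow(
        &x,
        &y,
        3,
        |xt, yt, xs| NearestNeighbour::fit(xt, yt).predict(xs),
        accuracy,
    )
}
```

The closure keeps the borrowed model local to each fold, which sidesteps the lifetime problem of naming a model type whose lifetime is tied to data created inside the function. A real implementation would have to decide how to express that against the library's actual traits.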

SVC was not meant to be used with cross_validate in the first place, as we didn't have a use case for that. At the moment cross_validate works only with estimators that also implement the Predictor trait (SupervisedEstimator is composed with Predictor, while the "borrow" version is not). There is a note about this in the documentation, but I admit it should also be referenced in the cross_validate documentation.
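The trait-bound error can be illustrated with a small sketch. The traits below reuse the names from this thread but are simplified, hypothetical stand-ins, not smartcore's real definitions; `MeanModel`, `BorrowedModel`, and `fit_and_predict` are invented for illustration. A function bounded on the owning trait accepts `MeanModel` but rejects a type that only implements the borrowing trait.

```rust
// Simplified, hypothetical stand-ins for the traits discussed above.
trait Predictor {
    fn predict(&self, x: &[f64]) -> Vec<f64>;
}

// The owning estimator is composed with Predictor.
trait SupervisedEstimator: Predictor + Sized {
    fn fit(x: &[f64], y: &[f64]) -> Self;
}

// The borrowing estimator keeps references to its inputs and, in this
// sketch, is not composed with Predictor.
trait SupervisedEstimatorBorrow<'a>: Sized {
    fn fit(x: &'a [f64], y: &'a [f64]) -> Self;
}

// Owning model: stores only the mean of the targets.
struct MeanModel {
    mean: f64,
}

impl Predictor for MeanModel {
    fn predict(&self, x: &[f64]) -> Vec<f64> {
        vec![self.mean; x.len()]
    }
}

impl SupervisedEstimator for MeanModel {
    fn fit(_x: &[f64], y: &[f64]) -> Self {
        MeanModel {
            mean: y.iter().sum::<f64>() / y.len() as f64,
        }
    }
}

// Borrowing model: holds references to the training data.
struct BorrowedModel<'a> {
    x: &'a [f64],
    y: &'a [f64],
}

impl<'a> SupervisedEstimatorBorrow<'a> for BorrowedModel<'a> {
    fn fit(x: &'a [f64], y: &'a [f64]) -> Self {
        BorrowedModel { x, y }
    }
}

impl<'a> BorrowedModel<'a> {
    // Number of borrowed training samples.
    fn n_samples(&self) -> usize {
        self.x.len().min(self.y.len())
    }
}

// Bounded on the owning trait, like cross_validate in this discussion.
fn fit_and_predict<E: SupervisedEstimator>(x: &[f64], y: &[f64], x_new: &[f64]) -> Vec<f64> {
    E::fit(x, y).predict(x_new)
}

// fit_and_predict::<MeanModel>(..) compiles, whereas
// fit_and_predict::<BorrowedModel>(..) is rejected by the compiler because
// the `SupervisedEstimator` bound is not satisfied for `BorrowedModel`.
```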

If you have any ideas about how to fix this, it would be great to hear them; you can also open a Pull Request to fix it. If it is indeed deemed useful and we find a solution, it will be released with the next version.

NOTE: the model_selection module is missing documentation; it would also be great if you could provide it.