rust-ml / linfa

A Rust machine learning framework.
Apache License 2.0

Method `fit` exists for struct `SvmParams<F, F>`, but its trait bounds were not satisfied #362

Open · ronniec95 opened this issue 1 week ago

ronniec95 commented 1 week ago

I have a simple struct that holds some parameters and creates an SVR model:

```rust
use linfa::prelude::*;
use linfa_svm::{Svm, SvmParams};
use ndarray::Array;

struct SVRModel {
    params: SvmParams<f64, f64>,
    model: Option<Svm<f64, f64>>,
}

impl SVRModel {
    fn new() -> Self {
        Self {
            // nu-SVR with a Gaussian kernel
            params: Svm::<f64, _>::params()
                .nu_eps(0.5, 0.01)
                .gaussian_kernel(95.0),
            model: None,
        }
    }

    fn train(&mut self, x_train: &[&[f64]], y_train: &[f64]) {
        // Number of features per sample (columns of the records matrix)
        let n_features = x_train.first().map_or(0, |row| row.len());
        // Flatten the rows into a single row-major buffer
        let records = x_train
            .iter()
            .flat_map(|row| row.iter().copied())
            .collect::<Vec<_>>();
        let targets = y_train.to_vec();

        let dataset = DatasetBase::new(
            // Shape is (n_samples, n_features), not (n_samples, total element count)
            Array::from_shape_vec([targets.len(), n_features], records).unwrap(),
            Array::from_shape_vec([targets.len()], targets).unwrap(),
        );

        self.model = Some(self.params.fit(&dataset).unwrap());
    }
}
```

The above works fine, but making the struct generic over `F: Float`, like this:

```rust
use linfa::prelude::*;
use linfa_svm::{Svm, SvmParams};
use ndarray::Array;

struct SVRModel<F: Float> {
    params: SvmParams<F, F>,
    model: Option<Svm<F, F>>,
}

impl<F> SVRModel<F>
where
    F: linfa::Float,
{
    fn new() -> Self {
        Self {
            // Same nu-SVR settings, with the f64 literals converted to F
            params: Svm::<F, F>::params()
                .nu_eps(F::from_f64(0.5).unwrap(), F::from_f64(0.01).unwrap())
                .gaussian_kernel(F::from_f64(95.0).unwrap()),
            model: None,
        }
    }

    fn train(&mut self, x_train: &[&[F]], y_train: &[F]) {
        // Number of features per sample (columns of the records matrix)
        let n_features = x_train.first().map_or(0, |row| row.len());
        // Flatten the rows into a single row-major buffer
        let records = x_train
            .iter()
            .flat_map(|row| row.iter().copied())
            .collect::<Vec<_>>();
        let targets = y_train.to_vec();

        let dataset = DatasetBase::new(
            // Shape is (n_samples, n_features)
            Array::from_shape_vec([targets.len(), n_features], records).unwrap(),
            Array::from_shape_vec([targets.len()], targets).unwrap(),
        );

        // This call is what fails to compile for a generic F
        self.model = Some(self.params.fit(&dataset).unwrap());
    }
}
```

errors with:

```text
the method `fit` exists for struct `SvmParams<F, F>`, but its trait bounds were not satisfied
the following trait bounds were not satisfied:
`SvmValidParams<F, F>: linfa::prelude::Fit<_, _, _>`
    which is required by `SvmParams<F, F>: linfa::prelude::Fit<_, _, _>`
hyperparams.rs(37, 1): doesn't satisfy `SvmValidParams<F, F>: linfa::prelude::Fit<_, _, _>`
hyperparams.rs(69, 1): doesn't satisfy `SvmParams<F, F>: linfa::prelude::Fit<_, _, _>`
```

How do I express the missing trait bounds, or otherwise get this to work with generic numeric types?

EDIT: Rewrote the example as a minimal, self-contained sample.
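One reading of the diagnostic, offered as an assumption rather than something checked against the linfa-svm source: `SvmParams<F, F>` gets `fit` only when its checked form `SvmValidParams<F, F>` implements `Fit`, and that impl may exist only for concrete element types such as `f64`, not for every `F: Float`. The stdlib-only model below uses hypothetical stand-ins (`ParamGuard`, `Fit`, `Params`, `ValidParams`, `Model`, `Wrapper`), not linfa's real definitions, to show how the missing bound can be stated in a `where` clause on the generic impl.

```rust
// Simplified, hypothetical model of the trait structure the diagnostic points at.
// None of these items are linfa's real definitions.

/// Unchecked hyperparameters that can be validated into a "checked" form.
trait ParamGuard {
    type Checked;
    fn check_ref(&self) -> Result<&Self::Checked, String>;
}

/// Fitting against a slice of targets produces a model (`Object`).
trait Fit<T> {
    type Object;
    fn fit(&self, targets: &[T]) -> Result<Self::Object, String>;
}

/// Stand-in for `SvmParams<F, F>`: unchecked params wrapping the checked ones.
struct Params<F>(ValidParams<F>);

/// Stand-in for `SvmValidParams<F, F>`.
struct ValidParams<F> {
    eps: F,
}

/// Stand-in for the fitted `Svm<F, F>`.
#[derive(Debug, PartialEq)]
struct Model<F> {
    mean: F,
}

impl<F> ParamGuard for Params<F> {
    type Checked = ValidParams<F>;
    fn check_ref(&self) -> Result<&ValidParams<F>, String> {
        Ok(&self.0)
    }
}

// Blanket impl: unchecked params get `fit` only if their checked form has it.
impl<T, P> Fit<T> for P
where
    P: ParamGuard,
    P::Checked: Fit<T>,
{
    type Object = <P::Checked as Fit<T>>::Object;
    fn fit(&self, targets: &[T]) -> Result<Self::Object, String> {
        self.check_ref()?.fit(targets)
    }
}

// `Fit` exists for one concrete type only, not for every `F`.
impl Fit<f64> for ValidParams<f64> {
    type Object = Model<f64>;
    fn fit(&self, targets: &[f64]) -> Result<Model<f64>, String> {
        if self.eps <= 0.0 || targets.is_empty() {
            return Err("invalid eps or no targets".to_string());
        }
        // Toy "model": just the mean of the targets.
        let mean = targets.iter().sum::<f64>() / targets.len() as f64;
        Ok(Model { mean })
    }
}

/// Generic wrapper, analogous to `SVRModel<F>` in the issue.
struct Wrapper<F> {
    params: Params<F>,
    model: Option<Model<F>>,
}

impl<F> Wrapper<F>
where
    // The missing bound, stated explicitly. Without it, `self.params.fit(..)`
    // does not compile for a generic `F`, because `ValidParams<F>: Fit<F>`
    // cannot be proven.
    Params<F>: Fit<F, Object = Model<F>>,
{
    fn train(&mut self, targets: &[F]) -> Result<(), String> {
        self.model = Some(self.params.fit(targets)?);
        Ok(())
    }
}

fn main() {
    let mut w = Wrapper { params: Params(ValidParams { eps: 0.01_f64 }), model: None };
    w.train(&[1.0, 2.0, 3.0]).unwrap();
    let mean = w.model.as_ref().map(|m| m.mean);
    println!("{:?}", mean); // Some(2.0)
}
```

With the bound on the impl, the generic code compiles for any `F`, and the requirement moves to the caller: `Wrapper<f64>` can call `train`, while a type without a concrete `Fit` impl is rejected where it is used. For `SVRModel<F>` the analogous bound would presumably look something like `SvmParams<F, F>: Fit<Array2<F>, Array1<F>, SvmError, Object = Svm<F, F>>`; that exact form is an untested assumption about linfa's `Fit` signature.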

ronniec95 commented 1 week ago

After eventually deciphering the error message: it is essentially saying that the `ParamGuard` trait is not implemented for `SvmValidParams`.

The fix is (probably) in `hyperparams.rs`:

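```rust
// Proposed addition to linfa-svm's hyperparams.rs, where `Float`,
// `ParamGuard` and `SvmError` are assumed to be in scope already.
impl<F: Float, O> ParamGuard for SvmValidParams<F, O> {
    // Already-validated params "check" to themselves
    type Checked = SvmValidParams<F, O>;
    type Error = SvmError;

    fn check_ref(&self) -> Result<&Self::Checked, SvmError> {
        // No validation yet: hand back a reference to self
        Ok(self)
    }

    fn check(self) -> Result<Self::Checked, SvmError> {
        self.check_ref()?;
        Ok(self)
    }
}
```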

This is not a true fix: really it should check all the nested parameters (kernel, solver, Platt) as well, and those would need their own implementations of `ParamGuard`, but I'm not clever enough to write those.
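The comment's point that a complete check would validate each nested parameter set can be sketched with simplified stand-ins. Everything here is hypothetical: the trait is a cut-down imitation of `ParamGuard`, and the types (`KernelParams`, `SolverParams`, `PlattParams`, `SvmLikeParams`), the `require` helper and the limits they enforce are made up for illustration. It models only the delegation pattern and says nothing about whether such an impl would resolve the `fit` compile error.

```rust
// Hypothetical stand-ins for the nested parameter sets mentioned in the comment.
// The trait, types, and limits are made up; they are not linfa-svm's API.
trait ParamGuard {
    type Checked;
    type Error;
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error>;
}

/// Which parameter set failed, and why.
struct ParamError {
    param: &'static str,
    reason: String,
}

// Turns a failed condition into a ParamError for the named parameter set.
fn require(ok: bool, param: &'static str, reason: &str) -> Result<(), ParamError> {
    if ok {
        Ok(())
    } else {
        Err(ParamError { param, reason: reason.to_string() })
    }
}

struct KernelParams { width: f64 }
struct SolverParams { eps: f64 }
struct PlattParams { max_iter: usize }

impl ParamGuard for KernelParams {
    type Checked = Self;
    type Error = ParamError;
    fn check_ref(&self) -> Result<&Self, ParamError> {
        require(self.width > 0.0, "kernel", "width must be positive")?;
        Ok(self)
    }
}

impl ParamGuard for SolverParams {
    type Checked = Self;
    type Error = ParamError;
    fn check_ref(&self) -> Result<&Self, ParamError> {
        require(self.eps > 0.0, "solver", "eps must be positive")?;
        Ok(self)
    }
}

impl ParamGuard for PlattParams {
    type Checked = Self;
    type Error = ParamError;
    fn check_ref(&self) -> Result<&Self, ParamError> {
        require(self.max_iter > 0, "platt", "max_iter must be at least 1")?;
        Ok(self)
    }
}

/// Outer params that own the nested sets, like the SVM params in the comment.
struct SvmLikeParams {
    kernel: KernelParams,
    solver: SolverParams,
    platt: PlattParams,
}

impl ParamGuard for SvmLikeParams {
    type Checked = Self;
    type Error = ParamError;
    fn check_ref(&self) -> Result<&Self, ParamError> {
        // Delegate to each nested guard; `?` returns the first failure.
        self.kernel.check_ref()?;
        self.solver.check_ref()?;
        self.platt.check_ref()?;
        Ok(self)
    }
}

fn main() {
    let params = SvmLikeParams {
        kernel: KernelParams { width: 95.0 },
        solver: SolverParams { eps: 0.0 }, // invalid on purpose
        platt: PlattParams { max_iter: 100 },
    };
    match params.check_ref() {
        Ok(_) => println!("parameters ok"),
        Err(e) => println!("invalid {} parameters: {}", e.param, e.reason),
    }
}
```

Because each nested check is chained with `?`, the outer check reports only the first set that fails; collecting every failure would need something like a `Vec<ParamError>` instead.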