rust-ml / linfa

A Rust machine learning framework.
Apache License 2.0

Investigate discrepancy between non-BLAS and BLAS versions of `linfa-pls` #278

Open YuhanLiin opened 1 year ago

YuhanLiin commented 1 year ago

According to the benchmark results from this comment, linfa-pls is slightly slower without BLAS. Specifically, the Regression-Nipals benchmarks are slightly slower at a sample size of 100000, and the Regression-Svd and Canonical-Svd benchmarks are also slower at a sample size of 100000.

oojo12 commented 1 year ago

Flamegraphs from profiling are attached (Linfa_pls.zip). Each profile was run for 1 min. I haven't studied profiling or flamegraphs in depth yet, so I can't offer much insight.

oojo12 commented 1 year ago

`par_azip!` could be used in place of `Zip` in our code, since that is where a significant amount of time is spent.
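For illustration, here is a std-only sketch of what a parallel zip does: both operands are split into matching chunks and each chunk is processed on its own thread. This is not ndarray code. ndarray's `par_azip!` (available with its `rayon` feature) distributes the work through rayon rather than hand-spawned threads, and the function names below are hypothetical. Whether it actually helps here would need to be confirmed by re-running the 100000-sample benchmarks, since thread overhead can outweigh the gain on small inputs.

```rust
use std::thread;

/// Sequential elementwise update `out[i] += alpha * x[i]`
/// (the kind of loop a `Zip`/`azip!` call expresses).
fn zip_add_scaled(out: &mut [f64], x: &[f64], alpha: f64) {
    assert_eq!(out.len(), x.len());
    for (o, &xi) in out.iter_mut().zip(x) {
        *o += alpha * xi;
    }
}

/// Hypothetical parallel version: split both slices into matching chunks
/// and let one scoped thread handle each chunk pair.
fn par_zip_add_scaled(out: &mut [f64], x: &[f64], alpha: f64, n_threads: usize) {
    assert_eq!(out.len(), x.len());
    let n_threads = n_threads.max(1);
    // Ceiling division so every element lands in a chunk; `chunks` panics on 0.
    let chunk = ((out.len() + n_threads - 1) / n_threads).max(1);
    thread::scope(|s| {
        for (o_chunk, x_chunk) in out.chunks_mut(chunk).zip(x.chunks(chunk)) {
            s.spawn(move || zip_add_scaled(o_chunk, x_chunk, alpha));
        }
    });
}

fn main() {
    let x: Vec<f64> = (0..100_000).map(|i| i as f64).collect();
    let mut seq = vec![1.0; x.len()];
    let mut par = seq.clone();
    zip_add_scaled(&mut seq, &x, 0.5);
    par_zip_add_scaled(&mut par, &x, 0.5, 4);
    // Each element is computed identically, so the results match exactly.
    assert_eq!(seq, par);
    println!("sequential and parallel results match");
}
```

Because each output element depends only on the matching input elements, no synchronization is needed beyond joining the threads at the end of the scope.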

Regression-Nipals-5feats-100_000 would benefit from the above; the screenshots are attached, with a black border around the area of interest.

(Screenshot: nepals-5feat-100_000)

Similarly, for Regression-Svd it appears that a `par_azip` should be added to linfa_linalg, based on the call to `ndarray::linalg::impl_linalg::general_mat_vec_mul_impl`, which leads to `ndarray::zip...`
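The same idea carries over to a matrix-vector product: each output row is independent, so blocks of rows can be computed on separate threads. Below is a minimal std-only sketch, assuming a row-major matrix stored in a flat slice. It does not reproduce ndarray's `general_mat_vec_mul_impl`, and `mat_vec`/`par_mat_vec` are hypothetical names.

```rust
use std::thread;

/// Sequential y = A * x for a row-major `a` with `n_cols` columns.
fn mat_vec(a: &[f64], n_cols: usize, x: &[f64], y: &mut [f64]) {
    assert_eq!(x.len(), n_cols, "x needs one entry per column");
    assert_eq!(a.len(), y.len() * n_cols, "a must be y.len() x n_cols");
    if n_cols == 0 {
        y.fill(0.0); // empty rows sum to zero; also avoids chunks(0)
        return;
    }
    for (row, yi) in a.chunks(n_cols).zip(y.iter_mut()) {
        *yi = row.iter().zip(x).map(|(aij, xj)| aij * xj).sum();
    }
}

/// Hypothetical row-parallel version: each thread owns a block of rows
/// of `a` and the matching slice of `y`.
fn par_mat_vec(a: &[f64], n_cols: usize, x: &[f64], y: &mut [f64], n_threads: usize) {
    assert_eq!(x.len(), n_cols);
    assert_eq!(a.len(), y.len() * n_cols);
    if n_cols == 0 || y.is_empty() {
        y.fill(0.0);
        return;
    }
    let n_threads = n_threads.max(1);
    let rows_per_thread = (y.len() + n_threads - 1) / n_threads; // >= 1 here
    thread::scope(|s| {
        let a_blocks = a.chunks(rows_per_thread * n_cols);
        for (a_block, y_block) in a_blocks.zip(y.chunks_mut(rows_per_thread)) {
            s.spawn(move || mat_vec(a_block, n_cols, x, y_block));
        }
    });
}

fn main() {
    let (n_rows, n_cols) = (10_000, 5);
    let a: Vec<f64> = (0..n_rows * n_cols).map(|i| (i % 7) as f64).collect();
    let x = vec![1.0, -1.0, 0.5, 2.0, 0.25];
    let mut y_seq = vec![0.0; n_rows];
    let mut y_par = vec![0.0; n_rows];
    mat_vec(&a, n_cols, &x, &mut y_seq);
    par_mat_vec(&a, n_cols, &x, &mut y_par, 4);
    // Each row is summed in the same order, so the results are identical.
    assert_eq!(y_seq, y_par);
}
```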

`linfa_pls::utils::outer` spends a significant amount of time in `ndarray::zip::Zip<P,D>::inner`; is there a faster alternative? (Screenshot 2022-11-16 031715)
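One candidate to benchmark against the current `outer` is to fill the result row by row over contiguous memory, so each inner loop is just `b` scaled by one entry of `a`, a shape the compiler may be able to auto-vectorize. This is a hypothetical std-only sketch over flat row-major buffers, not the linfa-pls implementation; whether it beats the existing `Zip`-based code would have to be measured.

```rust
/// Hypothetical row-major outer product: out[i * b.len() + j] = a[i] * b[j].
fn outer(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = vec![0.0; a.len() * b.len()];
    if b.is_empty() {
        return out; // chunks_mut(0) would panic
    }
    // Each row of `out` is a contiguous slice: row i = a[i] * b.
    for (row, &ai) in out.chunks_mut(b.len()).zip(a) {
        for (o, &bj) in row.iter_mut().zip(b) {
            *o = ai * bj;
        }
    }
    out
}

fn main() {
    let a = [1.0, 2.0];
    let b = [3.0, 4.0, 5.0];
    let m = outer(&a, &b);
    for row in m.chunks(b.len()) {
        println!("{:?}", row);
    }
}
```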

Canonical-Svd looks like it would benefit from optimizing the `linfa::param_guard::<impl linfa::traits::Fit<R,T,E> for P>::fit` method. (Screenshot 2022-11-16 031522)

oojo12 commented 1 year ago

Our `param_guard` code:

```rust
/// Performs checking step and calls `fit` on the checked hyperparameters. If checking failed, the
/// checking error is converted to the original error type of `Fit` and returned.
impl<R: Records, T, E, P: ParamGuard> Fit<R, T, E> for P
where
    P::Checked: Fit<R, T, E>,
    E: Error + From<crate::error::Error> + From<P::Error>,
{
    type Object = <<P as ParamGuard>::Checked as Fit<R, T, E>>::Object;

    fn fit(&self, dataset: &crate::DatasetBase<R, T>) -> Result<Self::Object, E> {
        let checked = self.check_ref()?;
        checked.fit(dataset)
    }
}
```

and its implementation for PLS:

```rust
// Excerpt from inside a macro: `[<Pls $name Params>]` is `paste` crate syntax that
// concatenates identifiers (e.g. `$name = Regression` gives `PlsRegressionParams`).
impl<F: Float> ParamGuard for [<Pls $name Params>]<F> {
    type Checked = [<Pls $name ValidParams>]<F>;
    type Error = PlsError;

    // Constant-time validation: only the tolerance and max_iter fields are inspected.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        if self.0.0.tolerance.is_negative() || self.0.0.tolerance.is_nan() || self.0.0.tolerance.is_infinite() {
            Err(PlsError::InvalidTolerance(self.0.0.tolerance.to_f32().unwrap()))
        } else if self.0.0.max_iter == 0 {
            Err(PlsError::ZeroMaxIter)
        } else {
            Ok(&self.0)
        }
    }

    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}
```

YuhanLiin commented 1 year ago

The ParamGuard code only validates the hyperparameters, so it shouldn't be the bottleneck.
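To make this concrete, here is a self-contained, simplified mirror of the pattern above. The trait and type names are hypothetical stand-ins, not the real linfa traits. The blanket `fit` performs a couple of constant-time comparisons and then delegates, so its own cost does not grow with the dataset; all data-dependent work happens in the delegated `fit`.

```rust
// Hypothetical, simplified mirror of linfa's ParamGuard/Fit pattern.
#[derive(Debug, PartialEq)]
enum PlsError {
    InvalidTolerance(f64),
    ZeroMaxIter,
}

trait ParamGuard {
    type Checked;
    fn check_ref(&self) -> Result<&Self::Checked, PlsError>;
}

trait Fit {
    type Object;
    fn fit(&self, data: &[f64]) -> Result<Self::Object, PlsError>;
}

/// Already-validated hyperparameters.
struct ValidParams {
    tolerance: f64,
    max_iter: usize,
}

/// Unchecked wrapper around the valid form.
struct Params(ValidParams);

impl ParamGuard for Params {
    type Checked = ValidParams;
    // O(1): a few comparisons, independent of the data size.
    fn check_ref(&self) -> Result<&ValidParams, PlsError> {
        let tol = self.0.tolerance;
        if tol < 0.0 || !tol.is_finite() {
            Err(PlsError::InvalidTolerance(tol))
        } else if self.0.max_iter == 0 {
            Err(PlsError::ZeroMaxIter)
        } else {
            Ok(&self.0)
        }
    }
}

impl Fit for ValidParams {
    type Object = f64;
    // Stand-in for the expensive algorithm: touches every data element.
    fn fit(&self, data: &[f64]) -> Result<f64, PlsError> {
        Ok(data.iter().map(|v| v * v).sum())
    }
}

/// Blanket impl: validate, then delegate to the checked parameters.
impl<P: ParamGuard> Fit for P
where
    P::Checked: Fit,
{
    type Object = <P::Checked as Fit>::Object;
    fn fit(&self, data: &[f64]) -> Result<Self::Object, PlsError> {
        let checked = self.check_ref()?;
        checked.fit(data)
    }
}

fn main() {
    let data: Vec<f64> = (0..1_000).map(|i| i as f64).collect();
    let params = Params(ValidParams { tolerance: 1e-6, max_iter: 500 });
    println!("{:?}", params.fit(&data));
}
```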

oojo12 commented 1 year ago

Hmm, did I interpret the flamegraph wrong? I thought that since it was at the top, most of the CPU time was spent there. Or is it possible that it's saying most of the time is spent in the `::fit` method implementation for `ParamGuard`?

YuhanLiin commented 1 year ago

That's what I think is happening; I'll have to look at it to make sure. All I know is that there's no way the ParamGuard code can be the bottleneck.
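A likely explanation: a flamegraph bar's width is the frame's inclusive time, i.e. every sample in which that frame was anywhere on the stack, including time spent in its callees. Because the blanket `fit` wraps the entire PLS fit, its bar will be about as wide as the fit itself even if its self time is near zero. Release builds may also inline the delegated `fit` into the wrapper, in which case samples can be attributed to the wrapper's frame. The sketch below computes inclusive (`total`) and self (`self_`) sample counts from folded stack lines (`frame;frame;frame count`, the text format flamegraph tools read); the frame names are made up for illustration.

```rust
use std::collections::{HashMap, HashSet};

/// `total`: samples with the frame anywhere on the stack (bar width).
/// `self_`: samples where the frame was the innermost (leaf) function.
#[derive(Debug, Default, PartialEq)]
struct Counts {
    total: u64,
    self_: u64,
}

fn frame_counts(folded: &str) -> HashMap<String, Counts> {
    let mut out: HashMap<String, Counts> = HashMap::new();
    for line in folded.lines() {
        // Each line is "stack count"; split at the last space.
        let Some((stack, count)) = line.trim().rsplit_once(' ') else { continue };
        let Ok(count) = count.parse::<u64>() else { continue };
        let frames: Vec<&str> = stack.split(';').collect();
        // Count each distinct frame once per stack so recursion isn't double-counted.
        let mut seen = HashSet::new();
        for f in &frames {
            if seen.insert(*f) {
                out.entry(f.to_string()).or_default().total += count;
            }
        }
        if let Some(leaf) = frames.last() {
            out.entry(leaf.to_string()).or_default().self_ += count;
        }
    }
    out
}

fn main() {
    // Made-up stacks: the wrapper is on every stack but never the leaf.
    let folded = "main;param_guard_fit;pls_fit;outer 70\n\
                  main;param_guard_fit;pls_fit;mat_vec 25\n\
                  main;param_guard_fit;check_ref 1";
    let counts = frame_counts(folded);
    let mut names: Vec<&String> = counts.keys().collect();
    names.sort();
    for name in names {
        println!("{name}: {:?}", counts[name]);
    }
}
```

In this example `param_guard_fit` has the widest bar of any non-root frame (96 of 96 samples) while its self count is 0, which matches the reading that the time is spent beneath it rather than in the validation code.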