argmin-rs / argmin


Inconvenience when using Executor inside a function with generics #377

Open stefan-k opened 10 months ago

stefan-k commented 10 months ago

Wrapping calls to the Executor inside a function is typically not an issue when the types are known:

fn optimize(cost: Rosenbrock, init_param: Array1<f64>) -> Array1<f64> {
    let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

    let solver = LBFGS::new(linesearch, 7);

    let res = Executor::new(cost, solver)
        .configure(|state| state.param(init_param).max_iters(100))
        .add_observer(SlogLogger::term(), ObserverMode::Always)
        .run()
        .unwrap();

    // Return the best parameter vector found
    res.state().get_best_param().unwrap().clone()
}


However, if this function is to be generic over the float type, it becomes quite inconvenient:

fn optimize_generic<F>(cost: Rosenbrock2<F>, init_param: Array1<F>) -> Array1<F>
where
    F: ArgminFloat + ArgminZero + std::iter::Sum + ArgminMul<Array1<F>, Array1<F>>,
    Array1<F>: ArgminAdd<Array1<F>, Array1<F>>,
    Array1<F>: ArgminAdd<F, Array1<F>>,
    Array1<F>: ArgminSub<F, Array1<F>>,
    Array1<F>: ArgminSub<Array1<F>, Array1<F>>,
    Array1<F>: ArgminDot<Array1<F>, F>,
    Array1<F>: ArgminMul<Array1<F>, Array1<F>>,
    Array1<F>: ArgminMul<F, Array1<F>>,
    Array1<F>: ArgminL1Norm<F>,
    Array1<F>: ArgminL2Norm<F>,
    Array1<F>: ArgminSignum,
    Array1<F>: ArgminMinMax,
{
    let linesearch = MoreThuenteLineSearch::new()
        // Parameters passed to functions also need to be of type `F`!
        .with_c(F::from(1e-4).unwrap(), F::from(0.9).unwrap())
        .unwrap();

    let solver = LBFGS::new(linesearch, 7);

    let res = Executor::new(cost, solver)
        .configure(|state| state.param(init_param).max_iters(100))
        .add_observer(SlogLogger::term(), ObserverMode::Always)
        .run()
        .unwrap();

    // Return the best parameter vector found
    res.state().get_best_param().unwrap().clone()
}


I can only imagine that things get worse when the function is to be generic over the parameter vector itself.

I'm not sure at this point how to improve this situation. A few ideas I have:

  1. Create a "supertrait" per backend, such that a single bound like Array1<F>: ArgminMathNdarray would suffice. However, because F is generic, I'm not sure this will work.
  2. Create a general supertrait which covers all math traits. This probably isn't very useful, though, because not all backends implement all math traits.

I'm open to further ideas on that topic :)
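To make idea 1 concrete, here is a minimal std-only sketch. Everything in it is a stand-in: `Dot`, `Scale`, and `L2Norm` are small local traits playing the role of the argmin_math traits, and `MathBundle` is a hypothetical name for the per-backend supertrait. The one detail that seems to matter is that the bundle is a supertrait of the container type itself, since supertrait bounds on `Self` are implied at the use site while `where` clauses on other types are not. Whether this carries over to the real argmin_math traits is untested.

```rust
// Local stand-ins for math traits (NOT the argmin_math traits themselves).
trait Dot<T, U> {
    fn dot(&self, other: &T) -> U;
}
trait Scale<T, U> {
    fn scale(&self, factor: &T) -> U;
}
trait L2Norm<U> {
    fn l2_norm(&self) -> U;
}

// Per-type implementations generated by a macro, one per float type.
macro_rules! impl_math {
    ($t:ty) => {
        impl Dot<Vec<$t>, $t> for Vec<$t> {
            fn dot(&self, other: &Vec<$t>) -> $t {
                self.iter().zip(other).map(|(a, b)| a * b).sum()
            }
        }
        impl Scale<$t, Vec<$t>> for Vec<$t> {
            fn scale(&self, factor: &$t) -> Vec<$t> {
                self.iter().map(|a| a * factor).collect()
            }
        }
        impl L2Norm<$t> for Vec<$t> {
            fn l2_norm(&self) -> $t {
                self.iter().map(|a| a * a).sum::<$t>().sqrt()
            }
        }
    };
}
impl_math!(f32);
impl_math!(f64);

// Hypothetical bundle trait: every math bound is a supertrait of the container.
trait MathBundle<F>: Dot<Self, F> + Scale<F, Self> + L2Norm<F> + Sized {}

// Blanket impl: any container that has all the pieces gets the bundle for free.
impl<F, V> MathBundle<F> for V where V: Dot<V, F> + Scale<F, V> + L2Norm<F> {}

// A generic function now needs a single container bound instead of three.
// Computes the projection of `a` onto `b`.
fn project<F>(a: &Vec<F>, b: &Vec<F>) -> Vec<F>
where
    F: std::ops::Div<Output = F>,
    Vec<F>: MathBundle<F>,
{
    let coeff = a.dot(b) / b.dot(b);
    b.scale(&coeff)
}
```

With this, `project(&vec![3.0_f64, 4.0], &vec![1.0, 0.0])` works for both `f32` and `f64` inputs. The caller still has to write the `Vec<F>: MathBundle<F>` bound, so this shortens the list rather than removing it.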

jan-grimo commented 9 months ago

Fruit from a Discord discussion with @stefan-k:

One way to avoid spelling out many argmin_math trait bounds (shown here for nalgebra; unsure about ndarray) seems to be constraining the associated types of CostFunction and Gradient like so:

fn generic<T, F>(cost: T, init_param: na::DVector<F>) -> na::DVector<F>
where
    F: argmin::core::ArgminFloat + na::RealField + argmin_math::ArgminZero + std::iter::Sum + argmin_math::ArgminMul<na::DVector<F>, na::DVector<F>>,
    T: argmin::core::CostFunction<Output = F, Param = na::DVector<F>> + argmin::core::Gradient<Gradient = na::DVector<F>, Param = na::DVector<F>>,
{
    let linesearch = MoreThuenteLineSearch::new()
        .with_c(F::from(1e-4).unwrap(), F::from(0.9).unwrap())
        .unwrap();

    let solver = LBFGS::new(linesearch, 7);

    let res = Executor::new(cost, solver)
        .configure(|state| state.param(init_param).max_iters(100))
        .run()
        .unwrap();

    // Return the best parameter vector found
    res.state().get_best_param().unwrap().clone()
}

with these `Cargo.toml` entries:

```toml
argmin-math = { version = "0.3", features = ["nalgebra_latest-serde"] }
argmin = "0.8"
nalgebra = { version = "0.32", features = ["rand", "serde-serialize", "rayon"] }
nalgebra-lapack = "0.24"
```

Omitting any of the associated type constraints Output = F, Param = ... or Gradient = ..., Param = ... causes compiler errors indicating 1) that F or Param (some composite of F) cannot be matched against the CostFunction or Gradient associated types and 2) that a lot of argmin_math (and more) bounds are unsatisfied.
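The mechanism of pinning associated types can be shown in isolation with a std-only sketch. The `CostFunction` and `Gradient` traits below are local stand-ins shaped like the argmin ones, not the real traits, and `Quadratic` and `gradient_step` are hypothetical names. It only illustrates how `Output = F, Param = ...` ties the problem type to `F`; it says nothing about which argmin_math bounds the compiler can then derive.

```rust
use std::ops::{Mul, Sub};

// Local stand-ins shaped like argmin's problem traits (NOT the real ones).
trait CostFunction {
    type Param;
    type Output;
    fn cost(&self, param: &Self::Param) -> Self::Output;
}

trait Gradient {
    type Param;
    type Gradient;
    fn gradient(&self, param: &Self::Param) -> Self::Gradient;
}

// Hypothetical test problem: f(x) = sum of x_i^2.
struct Quadratic;

impl CostFunction for Quadratic {
    type Param = Vec<f64>;
    type Output = f64;
    fn cost(&self, param: &Vec<f64>) -> f64 {
        param.iter().map(|x| x * x).sum()
    }
}

impl Gradient for Quadratic {
    type Param = Vec<f64>;
    type Gradient = Vec<f64>;
    fn gradient(&self, param: &Vec<f64>) -> Vec<f64> {
        param.iter().map(|x| 2.0 * x).collect()
    }
}

// One plain gradient-descent step, generic over the problem `T` and float `F`.
// Pinning the associated types is what lets the body treat the cost as an `F`
// and the gradient as a `Vec<F>`.
fn gradient_step<T, F>(problem: &T, param: Vec<F>, step: F) -> (Vec<F>, F)
where
    F: Copy + Sub<Output = F> + Mul<Output = F>,
    T: CostFunction<Param = Vec<F>, Output = F>
        + Gradient<Param = Vec<F>, Gradient = Vec<F>>,
{
    let grad = problem.gradient(&param);
    let next: Vec<F> = param
        .iter()
        .zip(&grad)
        .map(|(&x, &g)| x - step * g)
        .collect();
    let cost = problem.cost(&next);
    (next, cost)
}
```

Calling `gradient_step(&Quadratic, vec![1.0, -2.0], 0.25)` infers `F = f64` from the problem's associated types. Removing `Output = F` from the bound makes the return type of `cost` opaque to the function body, which matches the first kind of compiler error described above.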

stefan-k commented 9 months ago

This only works for the nalgebra backend, not for ndarray or vec. Neither of the following functions compiles:


fn optimize_generic5<T, F>(cost: T, init_param: Array1<F>) -> Array1<F>
where
    F: ArgminFloat + ArgminZero + std::iter::Sum + ArgminMul<Array1<F>, Array1<F>>,
    T: argmin::core::CostFunction<Output = F, Param = Array1<F>>
        + argmin::core::Gradient<Gradient = Array1<F>, Param = Array1<F>>,
{
    let linesearch = MoreThuenteLineSearch::new()
        .with_c(F::from(1e-4).unwrap(), F::from(0.9).unwrap())
        .unwrap();

    let solver = LBFGS::new(linesearch, 7);

    let res = Executor::new(cost, solver)
        .configure(|state| state.param(init_param).max_iters(100))
        .run()
        .unwrap();

    // Return the best parameter vector found
    res.state().get_best_param().unwrap().clone()
}



fn optimize_generic7<T, F>(cost: T, init_param: Vec<F>) -> Vec<F>
where
    F: ArgminFloat + ArgminZero + std::iter::Sum,
    T: argmin::core::CostFunction<Output = F, Param = Vec<F>>
        + argmin::core::Gradient<Gradient = Vec<F>, Param = Vec<F>>,
{
    // Clone for the first bound so that `init_param` is not moved twice
    let solver = ParticleSwarm::new((init_param.clone(), init_param), 40);

    let res = Executor::new(cost, solver)
        .configure(|state| state.max_iters(100))
        .run()
        .unwrap();

    // [...]
}

stefan-k commented 9 months ago

A potential reason for the difference between the backends could be the way the math traits are implemented on the data types. For instance, comparing the ArgminSignum implementation for nalgebra:

impl<N, R, C> ArgminSignum for OMatrix<N, R, C>
where
    N: SimdComplexField,
    R: Dim,
    C: Dim,
    DefaultAllocator: Allocator<N, R, C>,
{
    fn signum(self) -> OMatrix<N, R, C> {
        self.map(|v| v.simd_signum())
    }
}

with ndarray:

macro_rules! make_signum {
    ($t:ty) => {
        impl ArgminSignum for Array1<$t> {
            fn signum(mut self) -> Array1<$t> {
                for a in &mut self {
                    *a = a.signum();
                }
                self
            }
        }

        impl ArgminSignum for Array2<$t> {
            fn signum(mut self) -> Array2<$t> {
                let m = self.shape()[0];
                let n = self.shape()[1];
                for i in 0..m {
                    for j in 0..n {
                        self[(i, j)] = self[(i, j)].signum();
                    }
                }
                self
            }
        }
    };
}
// [...]


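For ndarray, the traits are implemented separately for each concrete element type (f32, f64, ...) via a macro, whereas the nalgebra implementation is generic over N. With per-type implementations, a bound on a generic F tells the compiler nothing about Array1<F>, which would explain why the ndarray and vec versions still need the explicit container bounds.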
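The effect of per-type implementations can be reproduced without any backend. In this std-only sketch, `Signum` is a local stand-in for `ArgminSignum` and `signs_of` is a hypothetical caller. Because the trait is only implemented for `Vec<f32>` and `Vec<f64>`, the generic function has to state the container bound itself.

```rust
// Local stand-in for a math trait such as `ArgminSignum`.
trait Signum {
    fn signum(self) -> Self;
}

// Per-type implementations, generated once for each concrete float type.
macro_rules! make_signum {
    ($t:ty) => {
        impl Signum for Vec<$t> {
            fn signum(self) -> Self {
                // Calls the inherent `f32::signum` / `f64::signum` on each element.
                self.into_iter().map(|a| a.signum()).collect()
            }
        }
    };
}
make_signum!(f32);
make_signum!(f64);

// No bound on `F` alone can prove `Vec<F>: Signum` here, so the container
// bound has to be written out by every generic caller.
fn signs_of<F>(v: Vec<F>) -> Vec<F>
where
    Vec<F>: Signum,
{
    v.signum()
}
```

Dropping the `where` clause from `signs_of` makes the call to `signum` fail to resolve, even though every type that `F` is used with in practice has an implementation.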

I remember trying to implement the math traits for Vec<F> instead of Vec<f32>, Vec<f64>, Vec<Vec<f32>>, Vec<Vec<f64>> ..., which didn't work because the implementation for Vec<Vec<F>> could potentially clash with Vec<F>. I don't recall the specifics though.
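One plausible explanation for that clash, offered as a guess: if the generic impls are bounded by a trait from another crate (say a float trait from a numerics crate), the compiler has to assume that crate might one day implement its trait for `Vec<_>`, at which point `Vec<F>` and `Vec<Vec<F>>` would overlap, so it rejects the pair. The std-only sketch below shows a possible way around it using a crate-local marker trait. `Signum` is a stand-in for `ArgminSignum`, and `Scalar` and `signs` are hypothetical names. Since only the defining crate could ever implement `Scalar` for `Vec<_>`, the two impls should be accepted side by side.

```rust
// Local stand-in for a math trait such as `ArgminSignum`.
trait Signum {
    fn signum(self) -> Self;
}

// Hypothetical crate-local marker for scalar element types. It is never
// implemented for `Vec<_>`, which keeps the two impls below disjoint.
trait Scalar: Copy {
    fn scalar_signum(self) -> Self;
}

impl Scalar for f32 {
    fn scalar_signum(self) -> Self {
        f32::signum(self)
    }
}

impl Scalar for f64 {
    fn scalar_signum(self) -> Self {
        f64::signum(self)
    }
}

// Generic over the element type: covers Vec<f32> and Vec<f64>.
impl<F: Scalar> Signum for Vec<F> {
    fn signum(self) -> Self {
        self.into_iter().map(|x| x.scalar_signum()).collect()
    }
}

// Nested vectors delegate to the impl above, row by row.
impl<F: Scalar> Signum for Vec<Vec<F>> {
    fn signum(self) -> Self {
        self.into_iter().map(|row| row.signum()).collect()
    }
}

// A generic caller now only needs a bound on the scalar type.
fn signs<F: Scalar>(v: Vec<F>) -> Vec<F> {
    v.signum()
}
```

The trade-off is that every supported scalar type needs an explicit `Scalar` impl, so the per-type list moves from the containers to the scalars rather than disappearing.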