Open stefan-k opened 10 months ago
Fruit from a Discord discussion with @stefan-k:
One way to avoid having to specify many argmin_math trait bounds seems to be (example for nalgebra, unsure about ndarray) constraining the CostFunction and Gradient associated types like so:
fn generic<T, F>(cost: T, init_param: na::DVector<F>) -> na::DVector<F>
where
    F: argmin::core::ArgminFloat + na::RealField + argmin_math::ArgminZero + std::iter::Sum + argmin_math::ArgminMul<na::DVector<F>, na::DVector<F>>,
    T: argmin::core::CostFunction<Output = F, Param = na::DVector<F>> + argmin::core::Gradient<Gradient = na::DVector<F>, Param = na::DVector<F>>,
{
    let linesearch = MoreThuenteLineSearch::new()
        .with_c(F::from(1e-4).unwrap(), F::from(0.9).unwrap())
        .unwrap();
    let solver = LBFGS::new(linesearch, 7);
    let res = Executor::new(cost, solver)
        .configure(|state| state.param(init_param).max_iters(100))
        .run()
        .unwrap();
    res.state().get_prev_best_param().unwrap().clone()
}
Omitting any of the associated type constraints (Output = F, Param = ... or Gradient = ..., Param = ...) causes compiler errors indicating 1) that F or Param (some composite of F) cannot be matched against the CostFunction or Gradient associated types and 2) that a lot of argmin_math (and more) bounds are unsatisfied.
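To make the mechanism concrete, here is a minimal std-only sketch of why pinning the associated types helps. Everything in it (`Cost`, `Scale`, `V`, `step`, `generic`, `SumSquares`) is a hypothetical stand-in and not argmin's API. The assumption is that the solver states its math bounds in terms of the problem's associated types, and that the backend implements the math trait generically over the element type, as nalgebra does.

```rust
// Hypothetical stand-in for a cost-function trait with associated types.
trait Cost {
    type Param;
    type Output;
    fn cost(&self, p: &Self::Param) -> Self::Output;
}

// Hypothetical math trait, playing the role of one argmin_math trait.
trait Scale<F> {
    fn scale(&self, factor: F) -> Self;
}

// Toy vector type with ONE generic impl of the math trait.
#[derive(Debug, Clone, PartialEq)]
struct V<F>(Vec<F>);

impl<F: Copy + std::ops::Mul<Output = F>> Scale<F> for V<F> {
    fn scale(&self, factor: F) -> Self {
        V(self.0.iter().map(|x| *x * factor).collect())
    }
}

// "Solver" step whose bounds are phrased on the associated types.
fn step<C>(problem: &C, p: C::Param, factor: C::Output) -> (C::Param, C::Output)
where
    C: Cost,
    C::Param: Scale<C::Output>,
{
    let next = p.scale(factor);
    let value = problem.cost(&next);
    (next, value)
}

// Pinning Param = V<F> and Output = F lets the compiler prove
// `C::Param: Scale<C::Output>` from the generic impl above, so the
// wrapper does not have to repeat that bound.
fn generic<C, F>(problem: &C, init: V<F>, factor: F) -> (V<F>, F)
where
    F: Copy + std::ops::Mul<Output = F>,
    C: Cost<Param = V<F>, Output = F>,
{
    step(problem, init, factor)
}

// A concrete problem to call the wrapper with.
struct SumSquares;

impl Cost for SumSquares {
    type Param = V<f64>;
    type Output = f64;
    fn cost(&self, p: &V<f64>) -> f64 {
        p.0.iter().map(|x| x * x).sum()
    }
}
```

Calling `generic(&SumSquares, V(vec![1.0, 2.0]), 0.5)` goes through `step` without any extra math bounds on the wrapper. If `C: Cost` were left unconstrained, `step` could not be called because nothing would tie `C::Param` to `V<F>`.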
This only works for the nalgebra backend, not for ndarray or vec. The following does NOT compile:
ndarray:
fn optimize_generic5<T, F>(cost: T, init_param: Array1<F>) -> Array1<F>
where
    F: ArgminFloat + ArgminZero + std::iter::Sum + ArgminMul<Array1<F>, Array1<F>>,
    T: argmin::core::CostFunction<Output = F, Param = Array1<F>>
        + argmin::core::Gradient<Gradient = Array1<F>, Param = Array1<F>>,
{
    let linesearch = MoreThuenteLineSearch::new()
        .with_c(F::from(1e-4).unwrap(), F::from(0.9).unwrap())
        .unwrap();
    let solver = LBFGS::new(linesearch, 7);
    let res = Executor::new(cost, solver)
        .configure(|state| state.param(init_param).max_iters(100))
        .run()
        .unwrap();
    res.state().get_prev_best_param().unwrap().clone()
}
vec:
fn optimize_generic7<T, F>(cost: T, init_param: Vec<F>) -> Vec<F>
where
    F: ArgminFloat + ArgminZero + std::iter::Sum,
    T: argmin::core::CostFunction<Output = F, Param = Vec<F>>
        + argmin::core::Gradient<Gradient = Vec<F>, Param = Vec<F>>,
{
    let solver = ParticleSwarm::new((init_param, init_param), 40);
    let res = Executor::new(cost, solver)
        .configure(|state| state.max_iters(100))
        .run()?;
    res.state().get_prev_best_param().unwrap().clone()
}
A potential reason for the difference between the backends could be the way the math traits are implemented on the data types. For instance, compare the ArgminSignum implementation for nalgebra:
impl<N, R, C> ArgminSignum for OMatrix<N, R, C>
where
    N: SimdComplexField,
    R: Dim,
    C: Dim,
    DefaultAllocator: Allocator<N, R, C>,
{
    #[inline]
    fn signum(self) -> OMatrix<N, R, C> {
        self.map(|v| v.simd_signum())
    }
}
with the one for ndarray:
macro_rules! make_signum {
    ($t:ty) => {
        impl ArgminSignum for Array1<$t> {
            #[inline]
            fn signum(mut self) -> Array1<$t> {
                for a in &mut self {
                    *a = a.signum();
                }
                self
            }
        }
        impl ArgminSignum for Array2<$t> {
            #[inline]
            fn signum(mut self) -> Array2<$t> {
                let m = self.shape()[0];
                let n = self.shape()[1];
                for i in 0..m {
                    for j in 0..n {
                        self[(i, j)] = self[(i, j)].signum();
                    }
                }
                self
            }
        }
    };
}
// [...]
make_signum!(f32);
make_signum!(f64);
For ndarray, the traits are implemented separately for each concrete element type (f32, f64, ...), whereas for nalgebra the implementation is generic over N.
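The consequence for generic code can be reproduced without either backend. The sketch below uses hypothetical names throughout (`Signum`, `Scalar`, `GenericVec`, `MacroVec`, `signs_generic`, `signs_macro`). It implements the same trait once generically and once per concrete type via a macro, then calls both from functions that are generic over the float type.

```rust
// Hypothetical math trait, analogous in spirit to ArgminSignum.
trait Signum {
    fn signum(self) -> Self;
}

// Hypothetical scalar abstraction standing in for a float trait.
trait Scalar: Copy {
    fn sgn(self) -> Self;
}
impl Scalar for f32 {
    fn sgn(self) -> Self {
        f32::signum(self)
    }
}
impl Scalar for f64 {
    fn sgn(self) -> Self {
        f64::signum(self)
    }
}

// Style 1: a single impl that is generic over the element type.
#[derive(Debug, PartialEq)]
struct GenericVec<F>(Vec<F>);

impl<F: Scalar> Signum for GenericVec<F> {
    fn signum(self) -> Self {
        GenericVec(self.0.into_iter().map(F::sgn).collect())
    }
}

// Style 2: one impl per concrete element type, stamped out by a macro.
#[derive(Debug, PartialEq)]
struct MacroVec<F>(Vec<F>);

macro_rules! make_signum {
    ($t:ty) => {
        impl Signum for MacroVec<$t> {
            fn signum(mut self) -> Self {
                for a in &mut self.0 {
                    *a = <$t>::signum(*a);
                }
                self
            }
        }
    };
}
make_signum!(f32);
make_signum!(f64);

// With the generic impl, a bound on F alone is enough.
fn signs_generic<F: Scalar>(v: GenericVec<F>) -> GenericVec<F> {
    v.signum()
}

// With per-type impls, no bound on F implies `MacroVec<F>: Signum`,
// so the container bound has to be spelled out by the caller.
fn signs_macro<F>(v: MacroVec<F>) -> MacroVec<F>
where
    MacroVec<F>: Signum,
{
    v.signum()
}
```

Removing the `where MacroVec<F>: Signum` clause from `signs_macro` makes it fail to compile, which mirrors the pile of unsatisfied bounds seen in the ndarray and vec examples.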
I remember trying to implement the math traits for Vec<F> instead of Vec<f32>, Vec<f64>, Vec<Vec<f32>>, Vec<Vec<f64>>, ..., which didn't work because the implementation for Vec<Vec<F>> could potentially clash with Vec<F>. I don't recall the specifics though.
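One way such a clash can arise, offered as an assumption rather than a reconstruction of what actually happened in argmin: if the bound on F is a trait from another crate, I would expect rustc to reject the pair of impls as overlapping (E0119), since an upstream crate could later implement that trait for Vec<_>, at which point Vec<Vec<G>> would match both. With a crate-local scalar trait the compiler can rule that out. The sketch below uses hypothetical names (`Scalar`, `L1Norm`) and only shows the variant that I believe is accepted.

```rust
// Hypothetical crate-local scalar trait. Because it is local and has no
// impl for Vec<_>, the compiler can conclude that the two impls below
// never apply to the same type.
trait Scalar: Copy + Into<f64> {}
impl Scalar for f32 {}
impl Scalar for f64 {}

// Hypothetical math trait.
trait L1Norm {
    fn l1_norm(&self) -> f64;
}

// Generic impl for flat vectors.
impl<F: Scalar> L1Norm for Vec<F> {
    fn l1_norm(&self) -> f64 {
        self.iter().map(|x| Into::<f64>::into(*x).abs()).sum()
    }
}

// Generic impl for nested vectors, reusing the flat impl per row.
impl<F: Scalar> L1Norm for Vec<Vec<F>> {
    fn l1_norm(&self) -> f64 {
        self.iter().map(|row| row.l1_norm()).sum()
    }
}
```

Whether this helps in practice depends on whether a local marker trait is an acceptable bound for the backend. It trades the overlap problem for one more trait that users of custom float types would need to implement.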
Wrapping calls to the Executor inside a function is typically not an issue when the types are known:
However, if this function is to be generic over the float type, it becomes quite inconvenient:
I can only imagine that things get worse when the function is to be generic over the parameter vector itself.
I'm not sure at this point how to improve this situation. A few ideas I have:
Array1<F>: ArgminMathNdarray or something like that would suffice. However, due to F being generic, I'm not sure if this will work.
I'm open to further ideas on that topic :)
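As a rough feasibility check for the umbrella-bound idea, here is a std-only sketch. All names (`Signum`, `Scale`, `MathVec`, `signed_scale`) are hypothetical, and `MathVec` only plays the role that something like ArgminMathNdarray might play. It assumes per-type impls as in the ndarray backend and collects them behind one supertrait with a blanket impl.

```rust
// Two hypothetical math traits, implemented per concrete type below.
trait Signum {
    fn signum(self) -> Self;
}
trait Scale<F> {
    fn scale(self, factor: F) -> Self;
}

macro_rules! impl_math {
    ($t:ty) => {
        impl Signum for Vec<$t> {
            fn signum(self) -> Self {
                self.into_iter().map(|x: $t| x.signum()).collect()
            }
        }
        impl Scale<$t> for Vec<$t> {
            fn scale(self, factor: $t) -> Self {
                self.into_iter().map(|x| x * factor).collect()
            }
        }
    };
}
impl_math!(f32);
impl_math!(f64);

// Umbrella trait: the supertraits list everything a solver needs.
trait MathVec<F>: Sized + Signum + Scale<F> {}

// Blanket impl: anything with all the pieces is a MathVec.
impl<T, F> MathVec<F> for T where T: Signum + Scale<F> {}

// A single container bound replaces the list of individual math bounds.
// Supertraits of a where-clause bound are available in the body, so both
// `signum` and `scale` resolve even though F is generic.
fn signed_scale<F>(v: Vec<F>, factor: F) -> Vec<F>
where
    Vec<F>: MathVec<F>,
{
    v.signum().scale(factor)
}
```

In this toy setting the single bound is enough even with F generic. The bound still has to be written on every generic function, because nothing about F alone tells the compiler that `Vec<F>: MathVec<F>` holds. Whether the same approach scales to the full set of argmin_math traits is untested here.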