dimforge / simba

Set of mathematical traits to facilitate the use of SIMD-based AoSoA (Array of Struct of Array) storage pattern.
Apache License 2.0

lanes associated const for SimdValue #53

Closed: aentity closed this issue 4 months ago

aentity commented 8 months ago

Is it possible to have an associated const LANES for the wide types? For example, in SimdValue:

trait SimdValue {
  const LANES: usize;
}

This would be very useful in generic contexts and when working with arrays.

For example, if we have:

type WIDE = simba::simd::WideF32x4;
type WIDET = f32;

for slice in data.chunks(WIDE::LANES) {
  let arr: [WIDET; WIDE::LANES] = slice.try_into().unwrap();
  let wide = WIDE::from(arr);
}

We could then easily switch the wide type, use it in generic code, or change the lane count to experiment with performance. Today, simba::simd::WideF32x4::lanes() is a trait fn, so it cannot be const.
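A minimal std-only sketch of the proposed loop, assuming a hypothetical Lanes trait (standing in for the proposed SimdValue::LANES) and a hypothetical F32x4 wrapper (not simba's WideF32x4). It shows that when the wide type is concrete, an associated const already works as an array length on stable Rust:

```rust
use std::convert::TryInto;

// Hypothetical trait standing in for the proposed SimdValue::LANES.
trait Lanes {
    const LANES: usize;
}

// Hypothetical 4-lane f32 vector (not simba's WideF32x4).
#[derive(Clone, Copy, Debug, PartialEq)]
struct F32x4([f32; 4]);

impl Lanes for F32x4 {
    const LANES: usize = 4;
}

impl From<[f32; 4]> for F32x4 {
    fn from(arr: [f32; 4]) -> Self {
        F32x4(arr)
    }
}

// Swap this alias (and provide matching impls) to try another width.
type Wide = F32x4;

// Packs data into wide values; chunks_exact skips any trailing partial chunk.
fn pack(data: &[f32]) -> Vec<Wide> {
    data.chunks_exact(Wide::LANES)
        .map(|slice| {
            // Allowed on stable: Wide is a concrete type, so Wide::LANES is a
            // known constant and can serve as an array length.
            let arr: [f32; Wide::LANES] = slice.try_into().unwrap();
            Wide::from(arr)
        })
        .collect()
}

fn main() {
    let data: Vec<f32> = (0..10).map(|i| i as f32).collect();
    let packed = pack(&data);
    println!("{:?}", packed);
}
```

With ten inputs, only two full vectors are produced; real code would also process the leftover elements, for example via ChunksExact::remainder().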

If so, I will attempt a PR. Thank you!

audunska commented 8 months ago

For what it's worth, I've tried this approach in my own non-simba code, and it does not work very well. This is because Rust's support for associated constants in generic code is limited: on stable Rust, an array length such as [T; Self::LANES] may not depend on a generic parameter.

In fact, even using the associated constant in the trait definition fails: https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=786e5dca82fc4bc38e7c7b6db68dc537
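A small sketch of where the line falls, using a hypothetical Lanes trait and hypothetical marker types X4 and X8 (not simba API): inside a generic function, T::LANES works as an ordinary runtime value such as a chunk size, but the commented-out line, which uses it as an array length, does not compile on stable Rust.

```rust
// Hypothetical trait standing in for the proposed SimdValue::LANES.
trait Lanes {
    const LANES: usize;
}

// Hypothetical lane-count markers, not simba types.
struct X4;
struct X8;

impl Lanes for X4 {
    const LANES: usize = 4;
}

impl Lanes for X8 {
    const LANES: usize = 8;
}

// Counts how many full T::LANES-wide vectors fit in data.
fn full_chunks<T: Lanes>(data: &[f32]) -> usize {
    // Fine: T::LANES is used as an ordinary usize value.
    let count = data.chunks_exact(T::LANES).count();

    // Rejected on stable Rust (it needs the unstable generic_const_exprs
    // feature): an array length cannot depend on the generic parameter T.
    // let arr: [f32; T::LANES] = [0.0; T::LANES];

    count
}

fn main() {
    let data = [0.0f32; 20];
    println!("{} {}", full_chunks::<X4>(&data), full_chunks::<X8>(&data));
}
```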

I ended up passing the lane count around as a const generic parameter instead. It's clumsier, but it works. It goes something like this:

use num_traits::Num;
use simba::scalar::Field;
use simba::simd::{AutoSimd, SimdRealField};

pub trait HasWide<const WIDTH: usize>: Sized + Clone + Copy + Num {
    type Wide: Sized
        + Clone
        + Copy
        + SimdRealField<Element = Self>
        + std::fmt::Debug
        + Field
        + From<[Self; WIDTH]>
        + Into<[Self; WIDTH]>;
}

impl HasWide<4> for f32 {
    type Wide = AutoSimd<[f32; 4]>; // Or WideF32x4, for example
}

Use it like this (just some arbitrary computation that probably makes no sense):

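use simba::scalar::RealField;
use simba::simd::SimdComplexField;

// R: RealField provides R::pi() and R::from_subset; SimdComplexField provides
// simd_sin and simd_horizontal_sum on the wide type.
fn wide_computation<const WIDTH: usize, R: HasWide<WIDTH> + RealField>(input: R) -> R {
  let arr: [R; WIDTH] = std::array::from_fn(|n| R::from_subset(&(n as f64)) * R::pi() * input);
  let val: R::Wide = arr.into();
  val.simd_sin().simd_horizontal_sum()
}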

fn main() {
  let r = wide_computation::<4, f32>(1.0);
  println!("{r}");
}
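The same const-generic pattern can be sketched with only the standard library, assuming a hypothetical F32xN<W> wrapper in place of AutoSimd or WideF32x4, and lane-wise multiplication as the only wide operation. Because W is a const generic parameter rather than an associated constant, it is accepted as an array length, and trying a different width is just a different turbofish argument:

```rust
use std::convert::TryInto;
use std::ops::{Add, Mul};

// Hypothetical portable wide type, standing in for AutoSimd / WideF32x4.
#[derive(Clone, Copy, Debug, PartialEq)]
struct F32xN<const W: usize>([f32; W]);

impl<const W: usize> From<[f32; W]> for F32xN<W> {
    fn from(arr: [f32; W]) -> Self {
        F32xN(arr)
    }
}

impl<const W: usize> From<F32xN<W>> for [f32; W] {
    fn from(v: F32xN<W>) -> Self {
        v.0
    }
}

// Lane-wise multiplication.
impl<const W: usize> Mul for F32xN<W> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        F32xN(std::array::from_fn(|i| self.0[i] * rhs.0[i]))
    }
}

// Std-only analogue of the HasWide trait above: the lane count is a const
// generic parameter, not an associated constant.
trait HasWide<const W: usize>: Copy {
    type Wide: Copy + From<[Self; W]> + Into<[Self; W]> + Mul<Output = Self::Wide>;
}

// A single blanket impl covers every width.
impl<const W: usize> HasWide<W> for f32 {
    type Wide = F32xN<W>;
}

// Sums the squares of all full W-wide chunks of data; a trailing partial
// chunk is ignored.
fn sum_of_squares<const W: usize, R>(data: &[R]) -> R
where
    R: HasWide<W> + Default + Add<Output = R>,
{
    let mut total = R::default();
    for chunk in data.chunks_exact(W) {
        // W is a const generic parameter, so it is a valid array length here.
        let arr: [R; W] = chunk.try_into().unwrap();
        let wide: R::Wide = arr.into();
        let squared: [R; W] = (wide * wide).into();
        for x in squared {
            total = total + x;
        }
    }
    total
}

fn main() {
    let data: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let a = sum_of_squares::<4, f32>(&data);
    let b = sum_of_squares::<8, f32>(&data);
    println!("{} {}", a, b);
}
```

With data 1.0 through 8.0, widths 4 and 8 both return 204.0. Since one blanket impl of HasWide covers every W, the width is chosen entirely at the call site.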