fre-hu / mdarray

Multidimensional array for Rust
Apache License 2.0

Quality of life features for dealing with shapes? #6

Open grothesque opened 1 week ago

grothesque commented 1 week ago

Shapes in mdarray are not as simple as in a completely dynamic library like Python's NumPy or Rust's ndarray. The "same" shape (2, 3) can be represented by the types DynRank, (Dyn, Dyn), (Const<2>, Dyn), (Dyn, Const<3>), (Const<2>, Const<3>). This is a good thing but IMHO it invites some quality-of-life features.

For example in the following

    let a = array![[0.0, 1.0], [1.0, 1.0]];
    let a = a.reshape((Const::<2>, Dyn(2)));

the shape must be repeated needlessly, which creates potential for error.

Perhaps the following could be made to work?

    let a = array![[0.0, 1.0], [1.0, 1.0]];
    let a: View<f64, (Const::<2>, Dyn)> = a.into();

Additionally/alternatively, perhaps a placeholder value could be added (like -1 in NumPy) that represents any dimension that is necessary for the reshape to work, something like this (note the Dyn::Any)?

    let a = array![[0.0, 1.0], [1.0, 1.0]];
    let a = a.reshape((Const::<2>, Dyn::Any));

(Not sure about how this would best work in practice though. That placeholder dimension should not be allowed in an actual shape.)

Here is a related case: It's easy to create a static rank array:

    let mut b = tensor![[0.0; 2]; 3];

But to do the same with a DynRank the simplest I can come up with is

    let mut b: mdarray::Tensor<f64> = tensor![[0.0; 2]; 3].reshape(DynRank::from_dims(&[2, 3])).into();

(is there a better way?).

How about adding a conversion that would allow writing

    let mut b: mdarray::Tensor<f64> = tensor![[0.0; 2]; 3].into();

?

fre-hu commented 6 days ago

Having a placeholder dimension in reshape is useful, and I previously thought about using .. as the placeholder. This is simpler though and avoids more trait/type definitions, so I implemented it. It uses usize::MAX as the placeholder, which can also be written as !0.
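To illustrate how such a placeholder could be resolved, here is a minimal sketch. It is not mdarray's implementation: the constant `ANY` and the helper `resolve_dims` are hypothetical names introduced only for this example, and plain `usize` slices stand in for mdarray's shape types. The only fact taken from the comment above is that the placeholder value is `usize::MAX`, which equals `!0`.

```rust
/// Hypothetical placeholder constant; `!0` is the same value as `usize::MAX`.
const ANY: usize = !0;

/// Hypothetical helper (not part of mdarray): replaces at most one `ANY`
/// entry so that the product of the dimensions equals `len`.
/// Returns `None` if the request cannot be satisfied.
fn resolve_dims(len: usize, dims: &[usize]) -> Option<Vec<usize>> {
    // Only a single dimension can be inferred.
    let placeholders = dims.iter().filter(|&&d| d == ANY).count();
    if placeholders > 1 {
        return None;
    }

    // Product of the explicitly given dimensions (1 if there are none).
    let known: usize = dims.iter().filter(|&&d| d != ANY).product();

    if placeholders == 1 {
        // The inferred size must divide the element count exactly.
        if known == 0 || len % known != 0 {
            return None;
        }
        let inferred = len / known;
        Some(dims.iter().map(|&d| if d == ANY { inferred } else { d }).collect())
    } else if known == len {
        // No placeholder: the shape only needs to match the element count.
        Some(dims.to_vec())
    } else {
        None
    }
}
```

With this sketch, `resolve_dims(4, &[2, ANY])` yields `Some(vec![2, 2])`, while `resolve_dims(6, &[4, ANY])` and `resolve_dims(4, &[ANY, ANY])` both yield `None`.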

I'm looking into whether the conversion functions can be improved. There is a limitation: implementing From to only change the shape collides with the reflexive variant From<T> for T. But it should be possible from/to primitive arrays and perhaps between different array types.
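The collision mentioned above can be reproduced in a few lines. This is a sketch under stated assumptions: `Arr`, `Rank2`, `AnyRank` and `with_shape` are hypothetical names that only mimic an array type parameterized by its shape; none of them are mdarray's actual types or methods.

```rust
use std::marker::PhantomData;

/// Hypothetical stand-in for an array type whose shape is a type parameter.
struct Arr<S> {
    data: Vec<f64>,
    shape: PhantomData<S>,
}

/// Hypothetical marker types playing the role of two different shape types.
struct Rank2;
struct AnyRank;

// A blanket shape-changing conversion such as
//
//     impl<S, R> From<Arr<R>> for Arr<S> { /* ... */ }
//
// is rejected with error E0119: when `S` and `R` are the same type, it
// overlaps with the standard library's reflexive `impl<T> From<T> for T`.

// A conversion between two distinct concrete shape types does not overlap.
impl From<Arr<Rank2>> for Arr<AnyRank> {
    fn from(a: Arr<Rank2>) -> Self {
        Arr { data: a.data, shape: PhantomData }
    }
}

impl<S> Arr<S> {
    fn new(data: Vec<f64>) -> Self {
        Arr { data, shape: PhantomData }
    }

    /// A named method avoids `From` altogether and can stay fully generic.
    fn with_shape<R>(self) -> Arr<R> {
        Arr { data: self.data, shape: PhantomData }
    }
}
```

The trade-off is that `From` impls have to be written per pair of concrete types (or for types that can never coincide), whereas a named method can remain generic over both the source and the target shape.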

grothesque commented 6 days ago

> Having a placeholder dimension in reshape is useful, and I previously thought about using .. as the placeholder. This is simpler though and avoids more trait/type definitions, so I implemented it. It uses usize::MAX as the placeholder, which can also be written as !0.

Very nice! The !0 seems like a good idea.

> I'm looking into whether the conversion functions can be improved. There is a limitation: implementing From to only change the shape collides with the reflexive variant From<T> for T. But it should be possible from/to primitive arrays and perhaps between different array types.

Here's another example where some form of conversion might be useful. The following program works, but uncommenting the commented-out section obviously doesn't compile. Not sure what amount of implicit/simple conversions would be appropriate here, but it feels like passing an array with shape DynRank to a function that expects an argument of shape (Dyn, Dyn) should be straightforward.

    use mdarray::{view, array, tensor, Slice, Dim, Const, Dyn, expr::Expression};

    // Indexing convention: C_ij <- A_ik * B_kj
    fn matmul<D0: Dim, D1: Dim, D2: Dim>(
        a: &Slice<f64, (D0, D1)>,
        b: &Slice<f64, (D1, D2)>,
        c: &mut Slice<f64, (D0, D2)>) {
        for (mut ci, ai) in c.rows_mut().zip(a.rows()) {
            for (aik, bk) in ai.zip(b.rows()) {
                for (cij, bkj) in ci.expr_mut().zip(bk) {
                    *cij = aik.mul_add(*bkj, *cij);
                }
            }
        }
    }

    fn main() {
        let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
        let b = array![[0.0, 1.0], [1.0, 1.0]];
        let b = b.reshape((Const::<2>, Dyn(2)));

        let mut c = tensor![[0.0; 2]; 3];

        // use mdarray::{DynRank, Shape};
        // let mut c: mdarray::Tensor<f64> = tensor![[0.0; 2]; 3].reshape(DynRank::from_dims(&[2, 3])).into();

        dbg!(std::any::type_name_of_val(&a));
        dbg!(std::any::type_name_of_val(&b));
        dbg!(std::any::type_name_of_val(&c));

        matmul(&a.reshape((Dyn(3), Const::<2>)), &b, &mut c);

        assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]);
    }
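One way to think about the DynRank to (Dyn, Dyn) case is as a fallible, run-time-checked conversion, since the rank is only known at run time. The sketch below shows the idea with plain slices and arrays from the standard library; `as_rank2` is a hypothetical helper and says nothing about how mdarray does or should expose this.

```rust
use std::convert::TryFrom;

/// Hypothetical helper: checks at run time that a dynamic-rank shape
/// (a slice of dimensions) has exactly rank 2.
fn as_rank2(dims: &[usize]) -> Option<[usize; 2]> {
    // The standard library's `TryFrom<&[T]> for [T; N]` fails unless the
    // slice has exactly N elements.
    <[usize; 2]>::try_from(dims).ok()
}
```

Here `as_rank2(&[3, 2])` returns `Some([3, 2])`, while `as_rank2(&[3, 2, 1])` returns `None`. A conversion of this kind would presumably fit `TryFrom` better than `From`, because it can fail when the ranks differ.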