Mojo-Numerics-and-Algorithms-group / NuMojo

NuMojo is a library for numerical computing in Mojo 🔥 similar to numpy in Python.
Apache License 2.0

[lib+test+doc] Get item by List[Int] or NDArray[DType.index] #77

Closed forFudan closed 1 month ago

forFudan commented 1 month ago

This PR resolves Effort 1(b) of #68. It allows users to get items by List[Int] or NDArray[DType.index], and the result is now a non-flattened array. The behavior is aligned with numpy: the returned array has the same ndim as the input array and differs only in the size of the first dimension, which equals the number of indices given. See the examples below.

Example:

> var A = nm.NDArray[nm.i8](3,random=True)
> print(A)
[       14      97      -59     ]
1-D array  Shape: [3]  DType: int8
>
> print(A[List[Int](2,1,0,1,2)])
[       -59     97      14      97      -59     ]
1-D array  Shape: [5]  DType: int8
>
> var B = nm.NDArray[nm.i8](3, 3,random=True)
> print(B)
[[      -4      112     -94     ]
[      -48     -40     66      ]
[      -2      -94     -18     ]]
2-D array  Shape: [3, 3]  DType: int8
>
> print(B[List[Int](2,1,0,1,2)])
[[      -2      -94     -18     ]
[      -48     -40     66      ]
[      -4      112     -94     ]
[      -48     -40     66      ]
[      -2      -94     -18     ]]
2-D array  Shape: [5, 3]  DType: int8
>
> var C = nm.NDArray[nm.i8](3, 3, 3,random=True)
> print(C)
[[[     -126    -88     -79     ]
[     14      78      99      ]
[     -32     3       -42     ]]
[[     56      -45     -71     ]
[     -13     18      -102    ]
[     4       83      26      ]]
[[     61      -73     86      ]
[     -125    -84     66      ]
[     32      21      53      ]]]
3-D array  Shape: [3, 3, 3]  DType: int8
>
> print(C[List[Int](2,1,0,1,2)])
[[[     61      -73     86      ]
[     -125    -84     66      ]
[     32      21      53      ]]
[[     56      -45     -71     ]
[     -13     18      -102    ]
[     4       83      26      ]]
[[     -126    -88     -79     ]
[     14      78      99      ]
[     -32     3       -42     ]]
[[     56      -45     -71     ]
[     -13     18      -102    ]
[     4       83      26      ]]
[[     61      -73     86      ]
[     -125    -84     66      ]
[     32      21      53      ]]]
3-D array  Shape: [5, 3, 3]  DType: int8
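
The rule shown in the examples above (select along the first dimension, keep the remaining dimensions unchanged) can be sketched in plain Python with nested lists. This is only an illustration of the semantics, not NuMojo's implementation: `take_rows` and `shape` are hypothetical helper names that do not exist in NuMojo or numpy, and the data is copied from arrays A and B above.

```python
def take_rows(array, indices):
    """Select items along the first axis; trailing dimensions are untouched.

    Hypothetical helper for illustration only. Selected rows are the same
    list objects as in the input, not copies.
    """
    return [array[i] for i in indices]


def shape(array):
    """Return the shape of a rectangular nested list (hypothetical helper)."""
    dims = []
    while isinstance(array, list):
        dims.append(len(array))
        if not array:
            break
        array = array[0]
    return dims


# Values copied from arrays A and B in the examples above.
A = [14, 97, -59]
B = [[-4, 112, -94], [-48, -40, 66], [-2, -94, -18]]
idx = [2, 1, 0, 1, 2]

print(take_rows(A, idx))         # [-59, 97, 14, 97, -59]
print(shape(take_rows(B, idx)))  # [5, 3]
```

In both cases the result has `len(idx)` items along the first dimension and the same trailing dimensions as the input, which matches the shapes `[5]` and `[5, 3]` printed above.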

A test file is also added.