Mojo-Numerics-and-Algorithms-group / NuMojo

NuMojo is a library for numerical computing in Mojo 🔥, similar to NumPy in Python.
Apache License 2.0
86 stars, 15 forks

Add stats functions sumall, prodall, meanall; improve performance of matmul function #42

Closed: forFudan closed this pull request 2 months ago

forFudan commented 2 months ago

Add functions to stats

Add the stats functions sumall, prodall, and meanall, which reduce all items in an array, regardless of its shape, to a single scalar: their sum, product, and mean, respectively.
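To make the "all items to one scalar" semantics concrete, here is a minimal Python sketch. It is not NuMojo code: the `Array` class (a flat row-major buffer plus a shape) is a hypothetical stand-in for a NuMojo array, and only the function names sumall, prodall, and meanall come from this PR.

```python
from math import prod

# Hypothetical stand-in for an N-D array: a flat row-major buffer plus a shape.
class Array:
    def __init__(self, data, shape):
        if len(data) != prod(shape):
            raise ValueError("data length does not match shape")
        self.data = list(data)
        self.shape = tuple(shape)

def sumall(a):
    # Sum over every element; the shape is ignored.
    return sum(a.data)

def prodall(a):
    # Product over every element (1 for an empty array).
    return prod(a.data)

def meanall(a):
    # Arithmetic mean of every element; undefined for an empty array.
    if not a.data:
        raise ValueError("meanall of an empty array")
    return sumall(a) / len(a.data)

a = Array([1, 2, 3, 4, 5, 6], shape=(2, 3))
print(sumall(a), prodall(a), meanall(a))  # 21 720 3.5
```

Because the data is stored flat, a whole-array reduction is a single pass over the buffer and never needs to consult the shape.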

Set the default axis of sum, prod, and mean to 0.
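For a 2-D array, a default of axis 0 means the first dimension is collapsed, giving one result per column, as with NumPy's `axis=0`. The Python sketch below illustrates that behaviour on nested lists. The names `sum_axis`, `prod_axis`, and `mean_axis` are hypothetical (chosen to avoid shadowing Python built-ins) and only 2-D inputs are handled.

```python
from math import prod

def _reduce_2d(rows, op, axis):
    # rows: a list of equal-length lists representing a 2-D array.
    if axis == 0:
        # Collapse the first dimension: one result per column.
        return [op(col) for col in zip(*rows)]
    if axis == 1:
        # Collapse the second dimension: one result per row.
        return [op(row) for row in rows]
    raise ValueError("axis must be 0 or 1 for a 2-D array")

def sum_axis(rows, axis=0):
    return _reduce_2d(rows, sum, axis)

def prod_axis(rows, axis=0):
    return _reduce_2d(rows, prod, axis)

def mean_axis(rows, axis=0):
    return _reduce_2d(rows, lambda v: sum(v) / len(v), axis)

rows = [[1, 2, 3],
        [4, 5, 6]]
print(sum_axis(rows))          # [5, 7, 9]
print(sum_axis(rows, axis=1))  # [6, 15]
print(mean_axis(rows))         # [2.5, 3.5, 4.5]
```

Unlike sumall, these return one value per remaining index, so a further reduction is needed to reach a single scalar.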

Add function matmul_parallelized_simd

Compared to matmul_parallelized, this function increases the SIMD vector width from the default to 16, so that each vectorized step processes more elements at once. For large matrices, it reduces execution time by roughly 50% compared to both matmul_parallelized and matmul_tiled_unrolled_parallelized (see the benchmark below).
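The Python sketch below illustrates the loop structure such a kernel typically has, under the assumption (not confirmed by this PR) that it follows the row-parallel, i-k-j layout common in Mojo matmul examples. Rows of the output are split across workers, and the innermost loop walks the columns in chunks of `width` elements. In Mojo, each chunk would be a single SIMD multiply-add, so a wider chunk means fewer iterations. Here the chunk is an ordinary Python loop, and the threads only show the work split (the GIL prevents a real speedup). `matmul_rows_chunked`, `WIDTH`, and `workers` are hypothetical names.

```python
from concurrent.futures import ThreadPoolExecutor

WIDTH = 16  # elements per chunk; stands in for the SIMD vector width

def matmul_rows_chunked(A, B, width=WIDTH, workers=4):
    n, k, m = len(A), len(B), len(B[0])
    C = [[0.0] * m for _ in range(n)]

    def compute_row(i):
        # Each worker owns whole rows of C, so no two threads write the same cell.
        row_c = C[i]
        for p in range(k):
            a_ip = A[i][p]
            row_b = B[p]
            # Walk the columns in chunks of `width`; the last chunk may be
            # shorter when m is not a multiple of width.
            for j0 in range(0, m, width):
                for j in range(j0, min(j0 + width, m)):
                    row_c[j] += a_ip * row_b[j]

    with ThreadPoolExecutor(max_workers=workers) as pool:
        list(pool.map(compute_row, range(n)))  # list() surfaces worker errors
    return C

A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(matmul_rows_chunked(A, B))  # [[19.0, 22.0], [43.0, 50.0]]
```

Chunking changes only how the j loop is grouped, not the order of the additions into each C[i][j]. The result therefore matches a plain triple loop, which is consistent with the identical output matrices shown below.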

Example: multiplying two 1000x1000 matrices, repeated 100 times:

==================================================
Matmul paralelled CxC
0.13595043000000001 s

[[      -23.483741841786355     -15.701597076540736     26.473091851878582      ...    -31.718022704728931      1.6964271388761831      28.632800765314453      ]
 [      -43.233593257424936     17.049138931287565      32.618187250485185      ...    18.783659991176876       27.92076483390781       3.9659833391638917      ]
 [      47.14442157400314       -60.806358543680417     36.401417494961457      ...    -8.6906117952914155      -12.97988096709012      4.8113582294888371      ]
...
 [      55.72750917086443       -31.878635011324135     5.0798819981552237      ...    0.69397495367225526      13.068364470346129      -23.380195913699559     ]
 [      22.835549164834685      -24.886999375511106     -13.959557015048651     ...    43.575505939403563       56.757641404795976      -2.6661609432295559     ]
 [      -42.676020059673355     11.693076330239709      6.6692060208009742      ...    -16.198446613171239      -0.96172564005890637    -38.374990575773573     ]]
Shape: [1000, 1000]  DType: float64

==================================================
Matmul paralelled tiled unrolling CxC
0.14608930000000001 s

[[      -23.483741841786355     -15.701597076540736     26.473091851878582      ...    -31.718022704728931      1.6964271388761831      28.632800765314453      ]
 [      -43.233593257424936     17.049138931287565      32.618187250485185      ...    18.783659991176876       27.92076483390781       3.9659833391638917      ]
 [      47.14442157400314       -60.806358543680417     36.401417494961457      ...    -8.6906117952914155      -12.97988096709012      4.8113582294888371      ]
...
 [      55.72750917086443       -31.878635011324135     5.0798819981552237      ...    0.69397495367225526      13.068364470346129      -23.380195913699559     ]
 [      22.835549164834685      -24.886999375511106     -13.959557015048651     ...    43.575505939403563       56.757641404795976      -2.6661609432295559     ]
 [      -42.676020059673355     11.693076330239709      6.6692060208009742      ...    -16.198446613171239      -0.96172564005890637    -38.374990575773573     ]]
Shape: [1000, 1000]  DType: float64

==================================================
Matmul paralelled with larger SIMD vector CxC
0.069114350000000005 s

[[      -23.483741841786355     -15.701597076540736     26.473091851878582      ...    -31.718022704728931      1.6964271388761831      28.632800765314453      ]
 [      -43.233593257424936     17.049138931287565      32.618187250485185      ...    18.783659991176876       27.92076483390781       3.9659833391638917      ]
 [      47.14442157400314       -60.806358543680417     36.401417494961457      ...    -8.6906117952914155      -12.97988096709012      4.8113582294888371      ]
...
 [      55.72750917086443       -31.878635011324135     5.0798819981552237      ...    0.69397495367225526      13.068364470346129      -23.380195913699559     ]
 [      22.835549164834685      -24.886999375511106     -13.959557015048651     ...    43.575505939403563       56.757641404795976      -2.6661609432295559     ]
 [      -42.676020059673355     11.693076330239709      6.6692060208009742      ...    -16.198446613171239      -0.96172564005890637    -38.374990575773573     ]]
Shape: [1000, 1000]  DType: float64
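The "~50%" figure can be checked against the timings printed above. The snippet below copies those three timings. Mapping the printed labels ("Matmul paralelled", "tiled unrolling", "larger SIMD vector") to the function names is an inference from the PR text. The output does not say whether each timing is a per-run average or a total, but the ratio is the same either way.

```python
# Timings copied from the benchmark output above (seconds).
timings = {
    "matmul_parallelized": 0.13595043000000001,
    "matmul_tiled_unrolled_parallelized": 0.14608930000000001,
    "matmul_parallelized_simd": 0.069114350000000005,
}

simd = timings["matmul_parallelized_simd"]
for name in ("matmul_parallelized", "matmul_tiled_unrolled_parallelized"):
    base = timings[name]
    reduction = 1 - simd / base  # fraction of execution time saved
    speedup = base / simd        # how many times faster the SIMD variant is
    print(f"{name}: {reduction:.1%} less time, {speedup:.2f}x faster")
```

This prints a 49.2% reduction (1.97x) against matmul_parallelized and a 52.7% reduction (2.11x) against matmul_tiled_unrolled_parallelized, which matches the approximately 50% claimed.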