ezmsg-org / ezmsg

Pure-Python DAG-based high-performance SHM-backed pub-sub and multi-processing pattern
https://ezmsg.readthedocs.io/en/latest/
MIT License
15 stars 6 forks source link

For Discussion: custom `slice_along_axis` 9.3x faster than np.moveaxis back-and-forth #78

Closed: cboulay closed this issue 9 months ago.

cboulay commented 10 months ago
import typing
import numpy as np
import numpy.typing as npt
import timeit

def slice_along_axis_1(in_arr: npt.NDArray, sl: typing.Union[slice, int], axis: int) -> npt.NDArray:
    """Index ``in_arr`` with ``sl`` along a single ``axis``.

    Equivalent to ``in_arr[:, ..., sl, ...]`` with ``sl`` placed at position
    ``axis``. The index tuple is built directly, so no intermediate
    ``np.moveaxis`` views are created.

    Args:
        in_arr: Array to index.
        sl: Index applied along ``axis``. A slice keeps the axis; an integer
            removes it.
        axis: Axis to index. Negative values count from the end, as in NumPy.

    Returns:
        ``in_arr`` indexed along ``axis``. A slice ``sl`` yields a view of
        ``in_arr`` (basic indexing).

    Raises:
        IndexError: If ``axis`` is outside ``[-in_arr.ndim, in_arr.ndim)``.
    """
    ndim = in_arr.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(f"axis {axis} is out of bounds for array of dimension {ndim}")
    if axis < 0:
        # Without normalization a negative axis gives an empty prefix and an
        # over-long suffix of slices, i.e. "too many indices".
        axis += ndim
    # Axes after ``axis`` need no explicit ``slice(None)``; NumPy takes any
    # trailing axes whole.
    return in_arr[(slice(None),) * axis + (sl,)]

def slice_along_axis_2(arr: npt.NDArray, sl: typing.Union[slice, int], axis: int) -> npt.NDArray:
    """Index ``arr`` with ``sl`` along ``axis`` via ``np.moveaxis``.

    Moves ``axis`` to the front, indexes it with ``sl``, then moves it back.
    This is the reference implementation that ``slice_along_axis_1`` is
    benchmarked against.

    Args:
        arr: Array to index.
        sl: Index applied along ``axis``. A slice keeps the axis; an integer
            removes it.
        axis: Axis to index. Negative values are accepted (``np.moveaxis``
            semantics).

    Returns:
        ``arr`` indexed along ``axis``, with all remaining axes in their
        original order.
    """
    moved = np.moveaxis(arr, axis, 0)[sl]
    if isinstance(sl, (int, np.integer)) and not isinstance(sl, bool):
        # An integer index removes the leading (moved) axis. The remaining
        # axes are already in their original relative order, so nothing should
        # be moved back. Moving back would scramble axes or fail when ``axis``
        # was the last axis.
        return moved
    return np.moveaxis(moved, 0, axis)

if __name__ == "__main__":
    # Micro-benchmark: take every 10th element along each axis of a
    # 100x100x100 array with both implementations. This uses the stdlib
    # ``timeit`` module so the script runs under plain ``python``. IPython's
    # ``%timeit`` magic is a SyntaxError outside IPython/Jupyter.
    # Numbers originally reported with %timeit: ~351-365 ns for
    # slice_along_axis_1 vs ~3.29-3.42 µs for slice_along_axis_2, on every axis.
    arr_shape = (100, 100, 100)
    test_arr = np.arange(np.prod(arr_shape)).reshape(arr_shape)
    step = np.s_[::10]

    for axis in range(test_arr.ndim):
        # The implementations must agree before their timings are compared.
        np.testing.assert_array_equal(
            slice_along_axis_1(test_arr, step, axis),
            slice_along_axis_2(test_arr, step, axis),
        )
        for func in (slice_along_axis_1, slice_along_axis_2):
            # Bind loop variables as defaults so the callable does not depend
            # on late binding.
            timer = timeit.Timer(lambda f=func, ax=axis: f(test_arr, step, ax))
            # autorange() picks a loop count totalling >= 0.2 s; the best of
            # several repeats reduces scheduling noise.
            number, _ = timer.autorange()
            best = min(timer.repeat(repeat=5, number=number))
            print(f"{func.__name__}(axis={axis}): {best / number * 1e9:,.0f} ns per call")