Closed: cboulay closed this issue 9 months ago.
import timeit
import typing

import numpy as np
import numpy.typing as npt


def slice_along_axis_1(in_arr: npt.NDArray, sl: typing.Union[slice, int], axis: int) -> npt.NDArray:
    """Index ``in_arr`` with ``sl`` along ``axis``, leaving every other axis whole.

    Builds an explicit indexing tuple ``(:, ..., :, sl, :, ..., :)`` so NumPy
    performs a single basic-indexing operation.

    Args:
        in_arr: Array to index.
        sl: Slice or integer applied to ``axis``. An integer removes that axis.
        axis: Axis to index. Negative values count from the last axis.

    Returns:
        ``in_arr`` indexed with ``sl`` along ``axis``.

    Raises:
        IndexError: If ``axis`` is outside ``[-in_arr.ndim, in_arr.ndim)``.
    """
    ndim = in_arr.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(f"axis {axis} is out of bounds for array of dimension {ndim}")
    # Normalize negative axes: for axis < 0, ``(slice(None),) * axis`` would be
    # empty and the trailing padding too long, producing too many indices.
    axis %= ndim
    all_slice = (slice(None),) * axis + (sl,) + (slice(None),) * (ndim - axis - 1)
    return in_arr[all_slice]


def slice_along_axis_2(arr: npt.NDArray, sl: typing.Union[slice, int], axis: int) -> npt.NDArray:
    """Index ``arr`` with ``sl`` along ``axis`` by moving that axis to the front.

    Alternative to :func:`slice_along_axis_1` using ``np.moveaxis``. It gives the
    same result but costs two ``moveaxis`` calls (see the benchmark below).

    Args:
        arr: Array to index.
        sl: Slice or integer applied to ``axis``. An integer removes that axis.
        axis: Axis to index. Negative values count from the last axis.

    Returns:
        ``arr`` indexed with ``sl`` along ``axis``.

    Raises:
        AxisError: (from ``np.moveaxis``; a subclass of ``IndexError`` and
            ``ValueError``) if ``axis`` is out of bounds.
    """
    moved = np.moveaxis(arr, axis, 0)
    if isinstance(sl, (int, np.integer)) and not isinstance(sl, bool):
        # Integer indexing drops the leading (moved) axis; the remaining axes are
        # already in their original order, so moving "back" would misplace one.
        return moved[sl]
    return np.moveaxis(moved[sl], 0, axis)


def _seconds_per_call(func: typing.Callable[..., typing.Any], *args: typing.Any, repeat: int = 5) -> float:
    """Return the best-of-``repeat`` average wall time, in seconds, of ``func(*args)``.

    Uses ``timeit.Timer.autorange`` to choose a loop count, mirroring what
    IPython's ``%timeit`` / ``python -m timeit`` do.
    """
    timer = timeit.Timer(lambda: func(*args))
    number, _ = timer.autorange()
    return min(timer.repeat(repeat=repeat, number=number)) / number


if __name__ == "__main__":
    arr_shape = (100, 100, 100)
    test_arr = np.arange(np.prod(arr_shape)).reshape(arr_shape)
    # Reference timings from the original IPython %timeit session (per call):
    #   axis 0: slice_along_axis_1 351 ns, slice_along_axis_2 3.29 us
    #   axis 1: slice_along_axis_1 365 ns, slice_along_axis_2 3.37 us
    #   axis 2: slice_along_axis_1 362 ns, slice_along_axis_2 3.42 us
    for axis in range(test_arr.ndim):
        for func in (slice_along_axis_1, slice_along_axis_2):
            seconds = _seconds_per_call(func, test_arr, np.s_[::10], axis)
            print(f"{func.__name__}(axis={axis}): {seconds * 1e9:,.0f} ns per call")