Michael-T-McCann closed this 1 month ago.
Clever use of jax.numpy.empty. What's the computational cost of redoing the JIT on each call? Have you checked that it's not considerably slower than the version it replaces?
Timing in IPython with:
import scico
import numpy as np
%timeit -n 1 -r 1 scico.numpy.util.indexed_shape((1000, 1000, 1000, 1000), (np.newaxis, ..., slice(0, 10), slice(15, 0, -1)))
Proposed version: ~10 ms; existing version: ~10 µs. So the new version is roughly 1000x slower, although both times are small in absolute terms.
I hesitate to merge this PR because of the difference in execution time, and also because the new version of indexed_shape will potentially consume large amounts of memory if JIT is disabled for debugging. I propose the compromise in branch brendt/index_shape: keep the original function, and add this one as a potentially useful reference version that is not used by other package components.
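To put the memory concern in perspective, here is a back-of-the-envelope estimate for the shape used in the timing example. It assumes that with JIT disabled the empty array is materialized eagerly, and that JAX's default float32 dtype (4 bytes per element) is used; note that jax.numpy.empty returns a zero-filled array, since XLA cannot create uninitialized arrays.

```python
import math

# Shape from the timing example above.
shape = (1000, 1000, 1000, 1000)
# Assumption: JAX's default dtype (float32, 4 bytes per element).
bytes_per_element = 4

n_elements = math.prod(shape)
n_bytes = n_elements * bytes_per_element
print(f"{n_elements:.0e} elements, {n_bytes / 1e12:.0f} TB")  # 1e+12 elements, 4 TB
```

So computing the shape of an indexed array of this size would, under these assumptions, require allocating about 4 TB if the empty array is actually created.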
Follow-up to #517: this version has fewer lines of code and is more flexible, in that any indexing expression that is valid for JAX arrays should work automatically.
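A minimal sketch of the idea, assuming the approach is to index an empty array of the given shape and read off the shape of the result. The helper name indexed_shape_sketch is hypothetical, and the use of jax.eval_shape is an assumption rather than necessarily what the PR does: it traces the expression with abstract values only, so no array is allocated and nothing is compiled.

```python
import jax
import jax.numpy as jnp
import numpy as np


def indexed_shape_sketch(shape, idx):
    """Return the shape of jnp.empty(shape)[idx] (hypothetical helper).

    jax.eval_shape traces index_empty with abstract values, so the empty
    array is never materialized, even for very large shapes.
    """

    def index_empty():
        return jnp.empty(shape)[idx]

    return jax.eval_shape(index_empty).shape


# The indexing expression from the timing example.
print(indexed_shape_sketch(
    (1000, 1000, 1000, 1000),
    (np.newaxis, ..., slice(0, 10), slice(15, 0, -1)),
))  # (1, 1000, 1000, 10, 15)

# Integer-array (advanced) indexing works without special-case code.
print(indexed_shape_sketch((5, 6), (np.array([0, 2, 4]), slice(None, None, 2))))  # (3, 3)
```

Because the shape logic is delegated to JAX's own indexing rules, new kinds of index expressions are handled without extra code; the trade-off is that the expression is traced on every call.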