lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Replace indexed_shape by a version that uses jit #519

Closed Michael-T-McCann closed 1 month ago

Michael-T-McCann commented 1 month ago

Follow-up to #517. This version uses fewer lines of code and is also more flexible: any indexing expression that is valid for JAX arrays should work automatically.
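The core idea is to let JAX's own indexing rules decide the result shape instead of re-implementing them by hand. The sketch below is a minimal illustration of that idea, not the PR's code. It uses `jax.eval_shape`, which traces a function on an abstract input (`jax.ShapeDtypeStruct`) and returns only the output's shape and dtype, so it neither allocates nor compiles. That is a swapped-in technique; the PR itself is described as using `jit` with `jax.numpy.empty`. The helper name `jax_indexed_shape` is hypothetical, and the `float32` dtype is an arbitrary assumption because the dtype does not affect the shape.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jax_indexed_shape(shape, idx):
    """Hypothetical helper: shape of ``x[idx]`` for an array ``x`` of ``shape``.

    jax.eval_shape traces the lambda with an abstract argument, so no array
    with prod(shape) elements is ever created, and nothing is compiled.
    """
    # Assumption: float32 is used only as a placeholder dtype.
    abstract_x = jax.ShapeDtypeStruct(tuple(shape), jnp.float32)
    out = jax.eval_shape(lambda x: x[idx], abstract_x)
    return out.shape


# Same indexing expression as in the timing example from this thread.
print(jax_indexed_shape((1000, 1000, 1000, 1000),
                        (np.newaxis, ..., slice(0, 10), slice(15, 0, -1))))
# (1, 1000, 1000, 10, 15)
```

Because the shape is derived by tracing `x[idx]`, any index that JAX accepts (`np.newaxis`, `...`, negative-step slices, integers) is handled without extra code. This is the flexibility argument made above.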

bwohlberg commented 1 month ago

Clever use of jax.numpy.empty. What's the computational cost of redoing the jit on each call? Have you checked that it's not considerably slower than the version it replaces?
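The cost in question comes from tracing and compiling. `jax.jit` caches its work per function object, so a function that is wrapped in `jax.jit` afresh inside every call is traced again each time. The following is a minimal sketch that makes this visible. It assumes that per-call construction, which is what "redoing the jit on each call" suggests, and `make_indexer` is a hypothetical helper. A Python side effect inside a jitted function runs only while JAX traces it, so a counter records how many traces occur.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}


def make_indexer(idx):
    """Hypothetical helper returning a new closure that applies ``idx``."""
    def f(x):
        trace_count["n"] += 1  # Python side effect: runs only during tracing
        return x[idx]
    return f


x = jnp.zeros((4, 5))
idx = (slice(0, 2), slice(None, None, -1))

# Case 1: a fresh closure is jitted on every call, so every call is a cache miss.
for _ in range(3):
    jax.jit(make_indexer(idx))(x)
fresh = trace_count["n"]

# Case 2: one jitted function reused with the same input shape/dtype.
trace_count["n"] = 0
g = jax.jit(make_indexer(idx))
for _ in range(3):
    g(x)
reused = trace_count["n"]

print("traces with fresh jit per call:", fresh)
print("traces with a reused jitted function:", reused)
```

In the first case every call pays the full trace-and-compile overhead. In the second case only the first call does. A fixed per-call overhead of this kind would be consistent with the millisecond-scale timing reported below.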

Michael-T-McCann commented 1 month ago

Timing with

import scico
import numpy as np
%timeit -n 1 -r 1 scico.numpy.util.indexed_shape((1000, 1000, 1000, 1000), (np.newaxis, ..., slice(0, 10), slice(15, 0, -1)))

Proposed version: ~10 ms; existing version: ~10 µs. So the new version is roughly 1000x slower, but these are small numbers.

bwohlberg commented 1 month ago

I hesitate to merge this PR because of the difference in execution time, and also because the new version of indexed_shape could consume large amounts of memory if JIT is disabled for debugging. I propose a compromise, in branch brendt/index_shape: keep the original function, and add this one as a potentially useful reference version that is not used by other package components.
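The arithmetic below puts the memory concern in numbers. It assumes the default `float32` dtype of `jnp.empty`. If JIT is disabled, a materialized array with the shape from the timing example holds 10^12 elements × 4 bytes = 4 TB. For comparison, the sketch also shows one way to get the indexed shape with no large allocation and no JAX tracing: index a zero-stride NumPy view created by `np.broadcast_to`. This is an alternative offered only for illustration. It is neither the PR's approach nor the contents of the `brendt/index_shape` branch, and `numpy_indexed_shape` is a hypothetical name. It is allocation-free only for basic indexing (integers, slices, `None`, `...`). Advanced integer-array or boolean indices make NumPy copy the selected elements, and NumPy's indexing rules are assumed rather than JAX's.

```python
import math

import numpy as np

shape = (1000, 1000, 1000, 1000)
idx = (np.newaxis, ..., slice(0, 10), slice(15, 0, -1))

# Bytes a materialized float32 array of this shape would need (assumed dtype).
nbytes = math.prod(shape) * np.dtype(np.float32).itemsize
print(f"{nbytes / 1e12:.0f} TB")  # 4 TB


def numpy_indexed_shape(shape, idx):
    """Hypothetical helper: shape of x[idx] computed from a zero-stride view.

    np.broadcast_to returns a read-only view of a single element whose
    strides are all zero, so nothing proportional to prod(shape) is
    allocated; basic indexing of that view is again just a view.
    """
    view = np.broadcast_to(np.empty((), dtype=np.uint8), shape)
    return view[idx].shape


print(numpy_indexed_shape(shape, idx))  # (1, 1000, 1000, 10, 15)
```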