spacetelescope / stcal

https://stcal.readthedocs.io/en/latest/

Optimization: use binary masks instead of indices during ramp fitting #168

Closed braingram closed 1 year ago

braingram commented 1 year ago

While looking at the ramp fitting code, I noticed several instances where a binary mask was created, then numpy.where was used to generate indices, which were in turn used to set values in an array, as in the following example: https://github.com/spacetelescope/stcal/blob/3321069e833ce61d39a86b88ef4630517b22b286/src/stcal/ramp_fitting/utils.py#L1501-L1505

The binary mask can be used directly:

data_sect[np.bitwise_and(gdq_sect, ramp_data.flags_saturated).astype(bool)] = np.nan

which results in faster code (on my machine, the modified code takes ~15% of the run time of the original code in the test case included below).
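As a small illustration of why the two forms are interchangeable, the sketch below applies both to a tiny array. The flag value `SATURATED` and the array contents are made up for the example and are not taken from stcal.

```python
import numpy as np

SATURATED = 0b10  # hypothetical flag value, for illustration only

# Small group-DQ array; the pixels with bit 0b10 set are 2, 3 and 6.
gdq = np.array([[0, 2, 3], [4, 0, 6]], dtype=np.uint8)
data_a = np.ones(gdq.shape, dtype=np.float32)
data_b = data_a.copy()

# Index-based form: build index arrays with numpy.where, then assign.
idx = np.where(np.bitwise_and(gdq, SATURATED))
data_a[idx] = np.nan

# Mask-based form: turn the flag test into a boolean mask and assign directly.
mask = np.bitwise_and(gdq, SATURATED).astype(bool)
data_b[mask] = np.nan

# Both forms set exactly the same elements to NaN.
np.testing.assert_array_equal(data_a, data_b)
```

The `.astype(bool)` step matters: indexing with the raw integer result of `bitwise_and` would be treated as integer (fancy) indexing rather than as a mask.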

import time
import numpy

bad_flag = 0b1
n_bad = 100
shape = (10, 512, 512)

# generate a fixed number of bad pixels at random
# locations
size = numpy.prod(shape)
flat_arr = numpy.zeros(size, dtype='uint16')
flat_arr[:n_bad] |= bad_flag
numpy.random.shuffle(flat_arr)
dq = flat_arr.reshape(shape)
data = numpy.zeros(shape, dtype='float32')

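def f1(data, dq):
    # original approach: convert the flag test to indices, then assign
    ws = numpy.where(numpy.bitwise_and(dq, bad_flag))
    data[ws] = numpy.nan
    del ws
    return data

def f2(data, dq):
    # proposed approach: use the flag test directly as a boolean mask
    data[numpy.bitwise_and(dq, bad_flag).astype(bool)] = numpy.nan
    return data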

def timeit(f):
    t0 = time.perf_counter()
    f()
    t1 = time.perf_counter()
    return t1 - t0

print("f1")
data[:] = 0
t = timeit(lambda: f1(data, dq))
print(f"f1 took {t}")

f1_data = data.copy()

print("f2")
data[:] = 0
t = timeit(lambda: f2(data, dq))
print(f"f2 took {t}")

numpy.testing.assert_array_equal(f1_data, data)
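
A single timed call can be noisy, so the comparison above could also be repeated with the standard-library `timeit` module. The following is only a sketch under the same assumptions as the test case (one flag bit, 100 flagged pixels, a 10 x 512 x 512 array); the helper names `set_with_where`, `set_with_mask`, and `best_time` are invented for this example.

```python
import timeit
import numpy as np

BAD_FLAG = 0b1
shape = (10, 512, 512)

# Flag 100 distinct, randomly chosen pixels (seeded for repeatability).
rng = np.random.default_rng(0)
flat_dq = np.zeros(int(np.prod(shape)), dtype=np.uint16)
flat_dq[rng.choice(flat_dq.size, size=100, replace=False)] = BAD_FLAG
dq = flat_dq.reshape(shape)

def set_with_where(data, dq):
    # index-based assignment via numpy.where
    data[np.where(np.bitwise_and(dq, BAD_FLAG))] = np.nan

def set_with_mask(data, dq):
    # direct boolean-mask assignment
    data[np.bitwise_and(dq, BAD_FLAG).astype(bool)] = np.nan

def best_time(func, repeats=5, number=3):
    # Return the best per-call time in seconds over several repeats.
    data = np.zeros(shape, dtype=np.float32)
    times = timeit.repeat(lambda: func(data, dq), repeat=repeats, number=number)
    return min(times) / number

t_where = best_time(set_with_where)
t_mask = best_time(set_with_mask)
print(f"where: {t_where:.6f} s, mask: {t_mask:.6f} s")
```

Taking the minimum over repeats reduces the influence of unrelated system activity; the actual ratio will depend on the machine, the NumPy version, and the fraction of flagged pixels.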