gbrammer / grizli

Grizli: The "Grism redshift and line" analysis software
MIT License

NIRCam hot pixel flagging #223

Closed by gbrammer 4 months ago

gbrammer commented 4 months ago

Add a function grizli.jwst_utils.flag_nircam_hot_pixels to:

  1. Flag isolated hot pixels, defined as pixels that exceed an S/N threshold while all of their neighbors have S/N below some tolerance.
  2. Flag "plusses" (the four pixels immediately above, below, left and right) around some known hot pixels.
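The two criteria above can be sketched with NumPy and SciPy. This is a minimal illustration, not the grizli implementation: the function name flag_hot_pixels_sketch, the thresholds hot_threshold and neighbor_max, the use of the 8 surrounding pixels as the "neighbors", and treating pixels with the JWST DQ HOT bit (2048) as the "known" hot pixels are all assumptions made for the example.

```python
import numpy as np
from scipy import ndimage

HOT_BIT = 2048  # JWST DQ "HOT" flag (bit 11); assumed source of "known" hot pixels


def flag_hot_pixels_sketch(sci, err, dq, hot_threshold=8.0, neighbor_max=2.0):
    """
    Hypothetical sketch of the two flagging steps (not grizli's code).

    Returns a boolean array that is True for flagged pixels.
    """
    # S/N image, left at zero where the uncertainty is not positive
    sn = np.zeros(sci.shape, dtype=float)
    valid = err > 0
    sn[valid] = sci[valid] / err[valid]

    # Maximum S/N of the 8 surrounding pixels, excluding the pixel itself;
    # pixels outside the image are treated as S/N = 0
    ring = np.ones((3, 3), dtype=bool)
    ring[1, 1] = False
    neighbor_sn = ndimage.maximum_filter(sn, footprint=ring,
                                         mode='constant', cval=0.0)

    # 1. Isolated hot pixels: above the threshold, all neighbors below tolerance
    isolated = (sn > hot_threshold) & (neighbor_sn < neighbor_max)

    # 2. "Plus" around known hot pixels: the pixel and its 4 adjacent pixels
    known_hot = (dq & HOT_BIT) > 0
    plus = ndimage.generate_binary_structure(2, 1)  # 3x3 cross
    plus_mask = ndimage.binary_dilation(known_hot, structure=plus)

    return isolated | plus_mask
```

The neighbor condition is what keeps a compact real source, whose flux the PSF spreads over adjacent pixels, from being flagged as an isolated hot pixel.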

This flagging is run in utils.drizzle_from_visit (and is therefore inherited by aws.visit_processor.cutout_mosaic). If any such pixels are identified, the static bad pixel tables, which somewhat conservatively flag a cumulative list of bad pixels, are not added to the mask.
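The decision described above amounts to choosing between two masks. A minimal sketch, assuming all inputs are boolean arrays of the same shape with True marking a bad pixel; the function name combine_bad_pixel_masks is hypothetical and does not reflect grizli's internals.

```python
import numpy as np


def combine_bad_pixel_masks(dq_mask, hot_mask, static_mask):
    """
    Hypothetical sketch: use the exposure-specific hot pixel mask when it
    flags anything, otherwise fall back to the static bad pixel table.
    """
    if np.any(hot_mask):
        # Hot pixels were found in this exposure: skip the static table
        return dq_mask | hot_mask
    # Nothing flagged: use the conservative, cumulative static table
    return dq_mask | static_mask
```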

Also added is an option to include an expanded list of JWST bad pixel DQ flags in the mask. Both options are turned on by default and are controlled with the following keywords:

import grizli.utils

# JWST DQ flag names whose pixels are added to the mask
JWST_DQ_FLAGS = [
    "DO_NOT_USE",
    "OTHER_BAD_PIXEL",
    "UNRELIABLE_SLOPE",
    "UNRELIABLE_BIAS",
    "NO_SAT_CHECK",
    "NO_GAIN_VALUE",
    "HOT",
    "WARM",
    "DEAD",
    "RC",
    "LOW_QE",
]

# Set either option below to None to turn off
grizli.utils.drizzle_from_visit(...,
    jwst_dq_flags=JWST_DQ_FLAGS,
    nircam_hot_pixel_kwargs={},
)
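Each name in jwst_dq_flags corresponds to one bit of the JWST DQ array, so the list can be reduced to a single bitmask. A minimal sketch; the helper dq_flags_to_mask and the JWST_DQ_BITS table are hypothetical, with bit values transcribed from the JWST pipeline's DQ flag definitions (worth checking against the installed pipeline version, whose own flag table a real implementation would more likely use).

```python
import numpy as np

# Bit values of the JWST DQ flags listed above (hypothetical local table,
# transcribed from the JWST pipeline DQ flag definitions)
JWST_DQ_BITS = {
    "DO_NOT_USE": 2**0,
    "DEAD": 2**10,
    "HOT": 2**11,
    "WARM": 2**12,
    "LOW_QE": 2**13,
    "RC": 2**14,
    "NO_GAIN_VALUE": 2**19,
    "NO_SAT_CHECK": 2**21,
    "UNRELIABLE_BIAS": 2**22,
    "UNRELIABLE_SLOPE": 2**24,
    "OTHER_BAD_PIXEL": 2**30,
}


def dq_flags_to_mask(dq, flag_names):
    """Return True where any of the named DQ bits is set in ``dq``."""
    bitmask = 0
    for name in flag_names:
        bitmask |= JWST_DQ_BITS[name]
    # Cast to a signed 64-bit int so the bitwise AND is safe for uint32 DQ arrays
    return (np.asarray(dq).astype(np.int64) & bitmask) > 0
```

For example, a pixel with DQ = 2050 (HOT plus SATURATED) is masked because HOT is in the list, while a pixel with only SATURATED (2) set is not.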

Demo:

import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as pyfits

from grizli.jwst_utils import flag_nircam_hot_pixels

signal = np.zeros((48,48), dtype=np.float32)

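# Isolated hot pixel
signal[16,16] = 10

# "Plus" around the known hot pixel at [32,32], which is flagged in DQ below
for off in [-1,1]:
    signal[32+off, 32] = 10  # vertical neighbors
    signal[32, 32+off] = 7   # horizontal neighbors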

err = np.ones_like(signal)
np.random.seed(1)
noise = np.random.normal(size=signal.shape)*err

dq = np.zeros(signal.shape, dtype=int)
dq[32,32] = 2048  # JWST DQ "HOT" flag (bit 11)

header = pyfits.Header()
header['MDRIZSKY'] = 0.

hdul = pyfits.HDUList([
    pyfits.ImageHDU(data=signal+noise, name='SCI', header=header),
    pyfits.ImageHDU(data=err, name='ERR'),
    pyfits.ImageHDU(data=dq, name='DQ'),
])

sn, dq_flag, count = flag_nircam_hot_pixels(hdul)

fig, axes = plt.subplots(1,2,figsize=(8,4), sharex=True, sharey=True)

axes[0].imshow(signal + noise, vmin=-2, vmax=9, cmap='gray')
axes[0].set_xlabel('Simulated data')
axes[1].imshow(dq_flag, cmap='magma')
axes[1].set_xlabel('Flagged pixels')

for ax in axes:
    ax.set_xticklabels([])
    ax.set_yticklabels([])

fig.tight_layout(pad=1)

[Figure nircam_hotpix: simulated data (left) and flagged pixels (right)]