lsst-epo / citizen-science-notebooks

A collection of Jupyter notebooks that can be used to associate Rubin Science Platform data with a Zooniverse citizen science project.
3 stars 1 forks source link

Calexp images are misaligned in WCS #87

Open beckynevin opened 5 months ago

beckynevin commented 5 months ago

Describe the bug: This is an issue with the WCS in the variable stars notebook (02) example, where we send a flipbook of images to Zooniverse. I'm currently sending five images, and while they all appear to have the same number of pixels and the same pixel centers, the WCS coordinates are misaligned. This is due to an … (sentence truncated in the original report).

[Two screenshots of the misaligned WCS axis labels — not reproduced here.]

To reproduce: here are the helper functions from the 02 notebook:

def update_wcs_center(wcs, new_center_sky):
    """
    Build an astropy WCS whose reference sky position (CRVAL) is replaced.

    Parameters
    ----------
    wcs: world coordinate system from calexp image from lsst.butler
        Anything exposing ``getFitsMetadata()`` returning a mutable,
        header-like mapping.
    new_center_sky: new coordinate center
        Object exposing ``getLongitude()`` / ``getLatitude()`` angles with
        an ``asDegrees()`` method (e.g. a ``geom.SpherePoint``).

    Returns
    -------
    updated WCS
        ``WCS`` constructed from the metadata with CRVAL1/CRVAL2 replaced.
    """
    lon_deg = new_center_sky.getLongitude().asDegrees()
    lat_deg = new_center_sky.getLatitude().asDegrees()

    fits_header = wcs.getFitsMetadata()
    fits_header['CRVAL1'] = lon_deg
    fits_header['CRVAL2'] = lat_deg
    return WCS(fits_header)

def set_wcs_ticks_labels(ax, wcs, spacing=None, formatter='d.dddd'):
    """
    Explicitly set tick positions, spacing and labels for the WCS axes.

    The formatter follows astropy's convention: ``d`` is the integer-degree
    part and each ``d`` after the ``.`` adds one decimal place of degrees.
    The label precision must be finer than the tick spacing. The old
    ``'d.ddd'`` only resolves 0.001 deg (3.6 arcsec), which is coarser than
    the 2 arcsec ticks. Neighbouring ticks were then rounded to uneven or
    repeated labels, which can make cutouts look misaligned even when their
    WCSs agree. The default ``'d.dddd'`` resolves 0.0001 deg (0.36 arcsec).

    Parameters
    ----------
    ax: axes object
        WCSAxes instance (created with ``projection=<WCS>``).
    wcs: world coordinate system from calexp image from lsst.butler
        Currently unused; kept for backward compatibility.
    spacing: astropy Quantity [optional]
        Tick spacing applied to both coordinates. Defaults to 2 arcsec.
    formatter: str [optional]
        Major tick label format applied to both coordinates.

    Returns
    -------
    None; ``ax`` is modified in place.
    """
    if spacing is None:
        spacing = 2 * u.arcsec

    for coord, label in ((ax.coords[0], 'Right Ascension'),
                         (ax.coords[1], 'Declination')):
        coord.set_major_formatter(formatter)
        # ticks and labels on the bottom and left edges
        coord.set_ticks_position('bl')
        coord.set_axislabel(label)
        coord.set_ticks(spacing=spacing)

# from tutorial 03a:
def cutout_calexp(butler,
                  ra,
                  dec,
                  visit,
                  detector,
                  cutoutsidelength=51,
                  **kwargs):
    """
    Produce a cutout from a calexp at the given ra, dec position.

    Adapted from cutout_coadd which was adapted from a DC2 tutorial
    notebook by Michael Wood-Vasey.

    Parameters
    ----------
    butler: lsst.daf.butler.Butler
        Helper object providing access to a data repository. It is called
        as ``butler.get(name, parameters=..., visit=..., detector=...)``.
    ra: float
        Right ascension of the center of the cutout, in degrees
    dec: float
        Declination of the center of the cutout, in degrees
    visit: int
        Visit id of the calexp's visit
    detector: int
        Detector for the calexp
    cutoutsidelength: int [optional]
        Side length of the square cutout, in pixels. Values are passed
        through ``int()`` because the pixel box is integer-valued
        (``geom.ExtentI`` does not accept floats).
    **kwargs
        Accepted for backward compatibility; ignored.

    Returns
    -------
    cutout_image
        The calexp returned by ``butler.get('calexp', ...)`` restricted to
        the requested bounding box (an Exposure: it carries ``getWcs()``,
        ``getBBox()`` and ``.image``). The bounding box is expressed in the
        parent calexp's pixel frame.
    """
    dataid = {'visit': visit, 'detector': detector}
    print('ra', ra, 'dec', dec)
    radec = geom.SpherePoint(ra,
                             dec,
                             geom.degrees)
    side = int(cutoutsidelength)
    cutoutsize = geom.ExtentI(side, side)
    # Read only the WCS component to find the target's pixel position.
    calexp_wcs = butler.get('calexp.wcs',
                            **dataid)
    xy = geom.PointI(calexp_wcs.skyToPixel(radec))
    # Box of `side` pixels per axis, starting half a side below/left of xy.
    bbox = geom.BoxI(xy - cutoutsize // 2,
                     cutoutsize)
    parameters = {'bbox': bbox}
    cutout_image = butler.get('calexp',
                              parameters=parameters,
                              **dataid)
    return cutout_image

def make_calexp_fig(cutout_image, out_name):
    """
    Create and save a figure of a calexp cutout with WCS-labelled axes.

    Should be followed with remove_figure to release memory.

    Parameters
    ----------
    cutout_image : cutout_image from butler.get
        Calexp cutout exposing ``getWcs()``, ``getBBox()`` and
        ``.image.array`` (e.g. the return value of cutout_calexp).
    out_name : str
        File name where the figure is saved.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The cutout figure (still open).
    """
    # Convert the LSST SkyWcs to an astropy WCS for WCSAxes. The previous
    # version rebuilt CRVAL from the same values (a no-op round trip); the
    # metadata is now used directly.
    new_wcs = WCS(cutout_image.getWcs().getFitsMetadata())

    fig = plt.figure()
    ax = fig.add_subplot(projection=new_wcs)

    bbox = cutout_image.getBBox()
    # The WCS is in the parent-calexp pixel frame, so the cutout is drawn at
    # its parent pixel coordinates. Pixel *centres* sit at integer
    # coordinates (LSST convention; astropy pixel origin 0), so pixel edges
    # are at +/-0.5. Using beginX..endX as edges would shift the image by
    # half a pixel relative to the WCS grid.
    calexp_extent = (bbox.beginX - 0.5,
                     bbox.endX - 0.5,
                     bbox.beginY - 0.5,
                     bbox.endY - 0.5)
    image_array = cutout_image.image.array
    im = ax.imshow(abs(image_array),
                   cmap='gray',
                   extent=calexp_extent,
                   origin='lower',
                   norm=matplotlib.colors.LogNorm(vmin=1e1, vmax=1e5)
                   )
    plt.colorbar(im, location='right', anchor=(0, 0.1))
    set_wcs_ticks_labels(ax, new_wcs)
    fig.savefig(out_name)
    print('shape of image', np.shape(image_array))
    return fig

def remove_figure(fig):
    """
    Tear down a matplotlib figure so the notebook's memory footprint stays small.

    Every image artist on every axes is detached, then the figure is cleared,
    closed through pyplot, and a garbage-collection pass is forced.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        Figure to be removed.

    Returns
    -------
    None
    """
    attached_images = [img
                       for axis in fig.get_axes()
                       for img in axis.get_images()]
    for img in attached_images:
        img.remove()

    fig.clf()
    plt.close(fig)
    gc.collect()

Run this cell:

import os

# Output locations. Built with os.path.join so the PNGs and the manifest
# land in the same directory (previously "…//images/" vs "…/images/").
batch_dir = './variable_stars_output/'
images_dir = os.path.join(batch_dir, 'images')
# savefig/to_csv do not create missing directories.
os.makedirs(images_dir, exist_ok=True)

figout_data = {"sourceId": diaobjectid}
cutouts = []

for i, idx in enumerate(idx_select):
    star_ra = sorted_sources['ra'][idx]
    star_dec = sorted_sources['decl'][idx]
    star_visitid = sorted_sources['visitId'][idx]
    star_detector = sorted_sources['detector'][idx]
    star_id = sorted_sources['diaObjectId'][idx]
    star_ccdid = sorted_sources['ccdVisitId'][idx]
    # One file name per (diaObjectId, ccdVisitId), reused for the PNG and
    # the manifest entries.
    image_name = str(star_id) + "_" + str(star_ccdid) + ".png"

    calexp_image = cutout_calexp(butler,
                                 star_ra,
                                 star_dec,
                                 star_visitid,
                                 star_detector,
                                 50)
    figout = make_calexp_fig(calexp_image,
                             os.path.join(images_dir, image_name))
    plt.show()
    remove_figure(figout)

    figout_data['location:image_' + str(i)] = image_name
    figout_data['diaObjectId:image_' + str(i)] = str(star_id)
    # NOTE(review): overwritten on every iteration, so only the last image's
    # name survives in the manifest -- confirm this is intended.
    figout_data['filename'] = image_name

df_manifest = pd.DataFrame(data=figout_data, index=[0])
outfile = os.path.join(images_dir, "manifest.csv")
df_manifest.to_csv(outfile, index=False, sep=',')

Expected behavior: WCS coordinate labels on the axes that align, or at least that I can convince myself align.

Actual behavior: the labels look like a coordinate mismatch, even though each calexp cutout should be centered on the same RA and Dec.

My solution for now is to turn off the axes and send blank axes to Zooniverse. I'm not sure if this is more of a scientist problem or the responsibility of this team to figure out this WCS issue. I need to decide what to do about this before submitting a PR for the variable stars notebook because this will determine which plotting utilities we add to the pipeline.