dstndstn / tractor

The Tractor: measuring astronomical sources via probabilistic inference

Sersic fit yields significant bias in flux #108

Open hbahk opened 1 month ago

hbahk commented 1 month ago

Hello,

I'm struggling to use the Tractor to do photometry on Galsim-simulated galaxy profiles. I found that the Tractor fit typically underestimates the input flux, and that the Tractor's Gaussian mixture model of a Sersic profile generally contains more flux than the Galsim profile with the same total flux. This bias seems to depend on the shear, the PSF size, the source size, the Sersic index, etc.

Below are some of the results. This is the profile of the Galsim and Tractor models with a Gaussian PSF of sigma = 5 pix.

[Figure: Galsim and Tractor model profiles, Gaussian PSF with sigma = 5 pix]

When this model is optimized, it gives a slightly lower flux.

[Figure: optimized model]

If I apply some shear, the flux of the Tractor model gets boosted and the optimized flux is underestimated.

[Figures: results with shear applied]

This bias gets even worse for a smaller PSF (I'm working with images whose PSF width is < 1 pix).

[Figures: results with a smaller PSF]
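One general reason a sub-pixel PSF is delicate: if a model is evaluated only at pixel centers, a narrow Gaussian no longer sums to its flux, whereas integrating it over each pixel does. The sketch below illustrates this with NumPy/SciPy; the function names are hypothetical, and it makes no claim about how the Tractor itself renders models.

```python
import numpy as np
from scipy.special import erf


def point_sampled_1d(n, center, sigma):
    """Unit-flux 1-D Gaussian evaluated at pixel centers x = 0..n-1."""
    x = np.arange(n)
    return np.exp(-0.5 * ((x - center) / sigma) ** 2) / (np.sqrt(2 * np.pi) * sigma)


def pixel_integrated_1d(n, center, sigma):
    """Unit-flux 1-D Gaussian integrated over each pixel [i - 0.5, i + 0.5]."""
    edges = np.arange(n + 1) - 0.5
    cdf = 0.5 * (1.0 + erf((edges - center) / (np.sqrt(2.0) * sigma)))
    return np.diff(cdf)


def total_flux_2d(profile_1d, n, center, sigma):
    """A circular Gaussian is separable, so the 2-D image is an outer product."""
    p = profile_1d(n, center, sigma)
    return np.outer(p, p).sum()


for sigma in (2.0, 0.3):
    print(sigma,
          total_flux_2d(point_sampled_1d, 64, 31.5, sigma),
          total_flux_2d(pixel_integrated_1d, 64, 31.5, sigma))
```

For sigma = 2 pix both sums are essentially 1. For sigma = 0.3 pix, with the source centered on a pixel corner as here, the point-sampled sum drops to roughly 0.44 while the pixel-integrated sum stays at 1.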

This behavior also depends on the Sersic index and the half-light radius.

[Figure: dependence on Sersic index and half-light radius]

Freezing all shape parameters and thawing only the brightness still shows this bias.

[Figure: fit with only the brightness thawed]
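With the shape frozen, the model is F·m for a fixed unit-flux template m, and for uniform pixel noise the least-squares flux is F̂ = Σ d·m / Σ m². That estimate equals the true flux only if the template matches the true profile. The toy sketch below uses circular Gaussians in place of Sersic profiles, purely as an illustrative assumption; for Gaussians F̂/F ≈ 2σ_m²/(σ_d² + σ_m²), so a template that is too compact underestimates the flux. Function names are hypothetical.

```python
import numpy as np


def gaussian_image(n, center, sigma, flux=1.0):
    """Circular Gaussian evaluated at pixel centers (adequate for sigma >~ 1 pix)."""
    y, x = np.mgrid[0:n, 0:n]
    r2 = (x - center) ** 2 + (y - center) ** 2
    return flux * np.exp(-0.5 * r2 / sigma**2) / (2 * np.pi * sigma**2)


def linear_flux_fit(data, template):
    """Least-squares amplitude of a unit-flux template, assuming uniform noise."""
    return np.sum(data * template) / np.sum(template * template)


data = gaussian_image(64, 31.5, sigma=2.0, flux=1.0)
print(linear_flux_fit(data, gaussian_image(64, 31.5, 2.0)))  # matched template
print(linear_flux_fit(data, gaussian_image(64, 31.5, 1.8)))  # too compact
```

In this toy case the matched template returns 1 and the too-compact one returns about 0.895, even though no shape parameter is being fit.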

I checked the Legacy Surveys photometry and the COSMOS2020 The Farmer catalog, but in these catalogs this bias seems to be somehow corrected, or not to exist in the first place.

[Figures: Legacy Surveys and COSMOS2020 The Farmer photometry]

Are you aware of this flux bias? If so, how did you resolve it for the Legacy Surveys photometry? If not, did I make a mistake?

By the way, thank you for your amazing work on the Tractor.

dstndstn commented 1 month ago

Hi,

If your PSF is not well sampled, then the Tractor code is definitely not guaranteed to produce correct results.

I believe that if you make your model image large enough, then the total flux in the Tractor model should equal your model flux.
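For a centered, circular Gaussian mixture, the fraction of the model flux that lands inside a square window has a closed form, which gives one way to check how large the model image must be. A minimal sketch, assuming circular components with variances in pix²; the function name and the toy mixture are hypothetical.

```python
import numpy as np
from scipy.special import erf


def window_flux_fraction(amps, variances, half_width):
    """Fraction of a centered circular Gaussian mixture's flux inside a square
    window of half-width `half_width` (same length unit as sqrt(variances)).

    Each component is separable, so its enclosed fraction is erf(h / (sqrt(2) sigma))**2.
    """
    amps = np.asarray(amps, dtype=float)
    sigmas = np.sqrt(np.asarray(variances, dtype=float))
    per_component = erf(half_width / (np.sqrt(2.0) * sigmas)) ** 2
    return np.sum(amps * per_component) / np.sum(amps)


# Toy mixture with one broad component (hypothetical amplitudes; variances in pix^2)
amps, variances = [0.6, 0.3, 0.1], [1.0, 25.0, 400.0]
for h in (10, 25, 50, 100):
    print(h, window_flux_fraction(amps, variances, h))
```

The broadest component dominates the missing flux, so the window needed grows with the widest Gaussian in the mixture.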

hbahk commented 1 month ago

Thank you for your comment!

Based on your comment I checked the PSF sampling, but the bias seems to persist.

For the Galsim simulation, I changed the PSF sampling and found no significant difference. I also tried simulating on a 10× oversampled image and then downsampling it, but the bias was similar.
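Binning an oversampled image conserves flux as long as each output pixel is the sum, not the mean, of its block. The `downscale_local_mean(img.array, (10, 10)) * 100` step in the simulation code in this thread does exactly that when both image dimensions are divisible by 10. An equivalent NumPy-only sketch (the helper name is hypothetical):

```python
import numpy as np


def block_sum(img, factor):
    """Bin an image by summing factor x factor blocks, conserving total flux."""
    ny, nx = img.shape
    if ny % factor or nx % factor:
        raise ValueError("image dimensions must be divisible by factor")
    return img.reshape(ny // factor, factor, nx // factor, factor).sum(axis=(1, 3))


rng = np.random.default_rng(0)
fine = rng.random((320, 320))  # stand-in for a 10x oversampled image
coarse = block_sum(fine, 10)
print(coarse.shape, fine.sum(), coarse.sum())
```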

For the Tractor, skimming through the source code, my understanding is that NCircularGaussianPSF is convolved analytically with the Gaussian mixture model of the Sersic profile, so I don't think PSF sampling should affect the bias in my simulated sample.
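Analytic convolution of Gaussian mixtures works because convolving two Gaussians adds their variances while the amplitudes multiply, so a PSF whose amplitudes sum to 1 leaves the total flux unchanged. A minimal NumPy sketch for circular components only; it is not the Tractor's own implementation, and the function names and toy values are hypothetical.

```python
import numpy as np


def convolve_circular_mog(amps, variances, psf_amps, psf_variances):
    """Analytically convolve two circular Gaussian mixtures.

    Each (source, PSF) component pair yields one Gaussian whose amplitude is the
    product of the two amplitudes and whose variance is the sum of the two variances.
    """
    a = np.outer(amps, psf_amps).ravel()
    v = np.add.outer(variances, psf_variances).ravel()
    return a, v


def render_circular_mog(amps, variances, size, center):
    """Evaluate a circular mixture at pixel centers on a size x size grid."""
    y, x = np.mgrid[0:size, 0:size]
    r2 = (x - center) ** 2 + (y - center) ** 2
    img = np.zeros((size, size))
    for a, v in zip(amps, variances):
        img += a * np.exp(-0.5 * r2 / v) / (2 * np.pi * v)
    return img


# Toy source mixture (variances in pix^2) convolved with a sigma = 1.5 pix PSF
amps, variances = np.array([0.3, 0.7]), np.array([0.5, 4.0])
conv_amps, conv_vars = convolve_circular_mog(amps, variances, np.array([1.0]), np.array([1.5**2]))
img = render_circular_mog(conv_amps, conv_vars, 64, 31.5)
print(conv_amps.sum(), img.sum())
```

In this toy model the analytic step preserves total flux exactly, so any loss has to come from how the convolved mixture is rendered onto pixels or from the window size.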

Changing the window size doesn't help either.

[Figure: result with a different window size]

Maybe we need more Gaussian components to model high-index Sersic profiles: the summed flux in the image indicates that the Tractor model is missing some flux relative to the Galsim model (~5 percentage points in this image, although the Galsim image should also be missing some flux because it is truncated by the window). It is also hard to understand why the shear transformation causes a larger bias. Or maybe I just used Galsim or the Tractor incorrectly...
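How much light a Sersic profile carries at large radii can be computed directly: the fraction of a circular Sersic profile's total light inside radius r is P(2n, b_n (r/r_e)^{1/n}), where P is the regularized lower incomplete gamma function and b_n solves P(2n, b_n) = 1/2. A short SciPy sketch (function names are hypothetical):

```python
import numpy as np
from scipy.special import gammainc, gammaincinv


def sersic_b(n):
    """b_n such that half of the total light lies within r_e: P(2n, b_n) = 1/2."""
    return gammaincinv(2.0 * n, 0.5)


def sersic_flux_within(n, r_over_re):
    """Fraction of a circular Sersic profile's total light inside r (in units of r_e)."""
    b = sersic_b(n)
    return gammainc(2.0 * n, b * np.asarray(r_over_re, dtype=float) ** (1.0 / n))


for n in (1.0, 4.0):
    for r in (1.0, 3.0, 10.0):
        print(n, r, 1.0 - sersic_flux_within(n, r))  # fraction of light beyond r
```

For n = 4 a few percent of the light still lies beyond 10 r_e, while for n = 1 the fraction there is negligible, so truncation (by a finite stamp or by a mixture whose wings fall off too quickly) removes more flux at high Sersic index.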

For reference, below is my code to simulate and model Sersic profiles. Thank you!

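import numpy as np
import matplotlib.pyplot as plt
import galsim
from skimage.transform import downscale_local_mean
from tractor import Image, Tractor, PixPos, Flux, NullWCS, LinearPhotoCal, NCircularGaussianPSF
from tractor.ellipses import EllipseESoft
from tractor.sersic import SersicGalaxy, SersicIndex

# Note: the non-Gaussian-PSF branch relies on the author's own PSF objects and
# SPHERExTractorSersicGalaxy class, which are not shown here.

# Define the function to process each (r_e, n) combination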
def _process_r_e_n(
    r_e,
    n,
    sigma=1.5,
    draw_figure=False,
    use_gaussian_psf=False,
    fix_sersic=False,
    verbose=False,
    nx=32,
    ny=32,
    e1=0.42,
    e2=0.07,
    flux=100.0,
    dlnplim=1e-3,
    fix_all_but_flux=False,
    no_opti=False,
):
    pixel_scale = 0.2
    pixnoise = 0.002
    try:
        # Galsim simulation ================================
        # Create the galaxy profile
        bdgal = galsim.Sersic(
            half_light_radius=r_e, n=n, flux=flux, flux_untruncated=False
        )
        g = np.hypot(e1, e2)
        beta = -0.5 * (np.arctan2(e2, e1) - np.pi) * galsim.radians
        bdgal = bdgal.shear(g=g, beta=beta)

        # Convolve with PSF
        if use_gaussian_psf:
            psf_sigma = sigma * pixel_scale  # in arcsec
            sigma_x10 = 10 * sigma
            xx, yy = np.meshgrid(
                np.arange(-5 * sigma_x10, 5 * sigma_x10 + 1),
                np.arange(-5 * sigma_x10, 5 * sigma_x10 + 1),
            )
            psf_image = np.exp(-0.5 * (xx**2 + yy**2) / sigma_x10**2)
            _psf = galsim.InterpolatedImage(
                galsim.Image(psf_image, scale=pixel_scale / 10.0), flux=1.0
            )
            # _psf = galsim.Gaussian(flux=1.0, sigma=psf_sigma)

        bdfinal = galsim.Convolve([bdgal, _psf])
        if verbose:
            print(
                f"total flux = {bdfinal.flux}\ngalaxy flux = {bdgal.flux}\n" + 
                f"hlr = {bdgal.original.half_light_radius}\ninput hlr = {r_e}"
            )

        # Draw the image
        seed = int((r_e * 1000) + (n * 1000)) % 2**32
        rng = galsim.BaseDeviate(seed)
        gaussian_noise = galsim.GaussianNoise(rng, sigma=pixnoise)
        # img = galsim.Image(nx, ny, scale=pixel_scale)
        img = galsim.Image(nx*10, ny*10, scale=pixel_scale / 10.0)
        bdfinal.drawImage(image=img)
        dsimg = downscale_local_mean(img.array, (10, 10)) * 100
        img = galsim.Image(dsimg, scale=pixel_scale)

        # Add noise
        newImg = img.copy()
        newImg.addNoise(gaussian_noise)

        # Tractor modeling ================================
        # Prepare for Tractor fitting
        if use_gaussian_psf:
            tractor_psf = NCircularGaussianPSF([sigma], [1.0])

        tim = Image(
            data=newImg.array,
            inverr=np.ones_like(newImg.array) / pixnoise,
            photocal=LinearPhotoCal(1.0),
            wcs=NullWCS(pixscale=pixel_scale),
            psf=tractor_psf,
        )

        fluxinit = flux if no_opti else flux * 0.5

        if use_gaussian_psf:
            galaxy_class = SersicGalaxy
        else:
            galaxy_class = SPHERExTractorSersicGalaxy
        galaxy = galaxy_class(
            PixPos(nx / 2 - 0.5, ny / 2 - 0.5),
            Flux(fluxinit),
            # EllipseE(r_e, e1, e2),
            EllipseESoft(np.log(r_e), e1, e2),
            # EllipseESoft(0.0, 0.0, 0.0),
            SersicIndex(n),
        )
        if verbose:
            print(f"initial model flux = {galaxy.getBrightness().getValue()}")

        from tractor.constrained_optimizer import ConstrainedOptimizer

        tractor = Tractor([tim], [galaxy], optimizer=ConstrainedOptimizer())
        tractor.freezeParam("images")
        if fix_sersic:
            galaxy.freezeParam("sersicindex")
        if fix_all_but_flux:
            galaxy.freezeAllBut("brightness")

        # Optimize the model
        # for _ in range(20):
        #     dlnp, X, alpha, var = tractor.optimize(
        #         shared_params=False, variance=True
        #     )
        #     if dlnp < dlnplim:
        #         if verbose:
        #             print(f"Converged for r_e={r_e}, n={n} at dlnP={dlnp}, iter={_}")
        #         break

        if not no_opti:
            tractor.optimize_loop(shared_params=False)
            var = tractor.optimize(
                shared_params=False, variance=True, just_variance=True
            )

            fluxresult = galaxy.getBrightness().getValue()
            fluxerrid = np.array(galaxy.getParamNames()) == "brightness.Flux"
            fluxerror = np.sqrt(var[fluxerrid][0])
        else:
            fluxresult = galaxy.getBrightness().getValue()
            fluxerror = np.nan

        if verbose:
            print(galaxy.getStepSizes())

        if draw_figure:
            mod = tractor.getModelImage(0)
            fig = plt.figure(figsize=(8, 3))
            ax = fig.add_subplot(131)
            im = ax.imshow(newImg.array, cmap="gray", origin="lower")
            # ax.plot(nx/2-0.5, ny/2-0.5, "r+")
            vmin, vmax = im.get_clim()
            ax.set_title("Image")
            ax.text(
                0.05,
                0.05,
                r"$F_{\rm input}$" + f" = {flux:.2f}",
                transform=ax.transAxes,
                c="w",
                ha="left",
                va="bottom",
            )
            ax = fig.add_subplot(132)
            ax.imshow(mod, cmap="gray", origin="lower")
            ax.text(
                0.05,
                0.05,
                r"$n_{\rm mod}$"
                + f"={galaxy.sersicindex.getValue():.2f}\n"
                + f"$r_e$={galaxy.shape.re/pixel_scale:.2f} pix \n"
                + r"$F_{\rm mod}$"
                + f"={fluxresult:.2f}",
                transform=ax.transAxes,
                c="w",
                ha="left",
                va="bottom",
            )
            ax.text(
                0.95,
                0.05,
                f"e1={galaxy.shape.ee1:.2f}\ne2={galaxy.shape.ee2:.2f}",
                transform=ax.transAxes,
                c="w",
                ha="right",
                va="bottom",
            )
            ax.set_title("Model")
            ax = fig.add_subplot(133)
            ax.imshow(newImg.array - mod, cmap="gray", origin="lower")
            ax.set_title("Residual")
            fig.suptitle(f"Sersic Index: {n:.1f}, $r_e$: {r_e/pixel_scale:.1f} pixels")

            rr = ((np.arange(ny) - ny // 2) / r_e * pixel_scale) ** 0.25
            gs = newImg.array[ny // 2, :]
            md = mod[ny // 2, :]
            fig = plt.figure(figsize=(5, 3))
            ax = fig.add_subplot(111)
            ax.plot(rr, gs, label="Galsim")
            ax.plot(rr, md, label="Tractor")
            ax.legend()
            ax.text(
                0.05,
                0.05,
                f"Galsim sum: {newImg.array.sum():.2f}\n"
                + f"Tractor sum: {mod.sum():.2f}\n"
                + r"$F_{\rm input}$"
                + f"={flux:.2f}\n"
                + r"$F_{\rm mod}$"
                + f"={fluxresult:.2f}\n"
                + r"$r_{e,{\rm mod}}$"
                + f"={galaxy.shape.re/pixel_scale:.2f} pix \n"
                + r"$n_{\rm mod}$"
                + f"={galaxy.sersicindex.getValue():.2f}\n"
                + f"e1={galaxy.shape.ee1:.2f}\ne2={galaxy.shape.ee2:.2f}",
                transform=ax.transAxes,
                c="k",
                ha="left",
                va="bottom",
            )
            ax.set_title(r"$\sigma_{\rm PSF}$" + f"={sigma:.1f} pix")
            ax.set_xlabel("$(r/r_e)^{1/4}$")
            ax.set_ylabel("Flux")
            # ax.set_xscale('log')
            ax.set_yscale("log")
            ax.set_xlim(0.0, 0.5)
            ax.set_ylim(1e-3, 30)

        return (galaxy.shape.re, galaxy.sersicindex.getValue(), fluxresult, fluxerror)

    except Exception as e:
        if verbose:
            print(f"Error processing r_e={r_e}, n={n}: {e}")
        # fluxresult = np.nan
        # fluxerror = np.nan
        return (np.nan, np.nan, np.nan, np.nan)

# Run single combination
_process_r_e_n(
    1,            # half-light radius in arcsec (= 5 pix at 0.2 arcsec/pix)
    4,            # sersic index
    0.5,          # sigma of PSF in pixels
    e1=0.2,       # shear e1
    e2=0.40,      # shear e2
    flux=10000.0, # total flux
    nx=100,       # window size x
    ny=100,       # window size y
    draw_figure=True, 
    use_gaussian_psf=True,
    fix_sersic=False,
    verbose=True,
    dlnplim=1e-6,
    fix_all_but_flux=False,
    no_opti=False,
)
hbahk commented 2 weeks ago

Hi,

I found that subtracting the flux that lies beyond the MoG model ($L_{\rm out}$) when normalizing the MoG amplitudes alleviates the flux bias. I think this brings the flux bias down to the level of a few percent for galaxies in DECam-like images.

[Figure: flux bias with the modified MoG normalization]

$$L_{\rm out} = 0.5 - 2\pi\int_{r_e}^{\infty} {\rm MoG}(r)\,r\,dr \equiv \texttt{beyond}$$

$$L_{\rm in} = 0.5 - 2\pi\int_0^{r_e} {\rm MoG}(r)\,r\,dr \equiv \texttt{core}$$
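For circular Gaussian components both integrals have a closed form: a component with amplitude a and variance v contributes $2\pi\int_{R}^{\infty} \frac{a}{2\pi v} e^{-r^2/2v}\,r\,dr = a\,e^{-R^2/2v}$ outside radius R, and the remainder of a inside it. A sketch of computing `beyond` and `core` this way, assuming the mixture variances are expressed in units of r_e² (so r_e = 1); the function name is hypothetical.

```python
import numpy as np


def beyond_and_core(amps, variances, r_e=1.0):
    """L_out ('beyond') and L_in ('core') for a circular Gaussian mixture.

    Uses 2*pi * int_R^inf a/(2*pi*v) * exp(-r^2 / (2v)) r dr = a * exp(-R^2 / (2v)).
    """
    amps = np.asarray(amps, dtype=float)
    v = np.asarray(variances, dtype=float)
    outside = np.sum(amps * np.exp(-0.5 * r_e**2 / v))
    inside = np.sum(amps) - outside
    return 0.5 - outside, 0.5 - inside


# A single Gaussian with exactly half of its light inside r_e = 1
v_half = 1.0 / (2.0 * np.log(2.0))
print(beyond_and_core([1.0], [v_half]))
```

From the definitions, beyond + core = 1 − Σa, so for a mixture whose amplitudes sum to exactly 1 the two quantities are equal and opposite.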

But for $r_e/\sigma_{\rm PSF} \gg 1$ and $\sigma_{\rm PSF} \sim 1$ pix, there is still a significant flux bias for high Sersic indices. I think this is unavoidable for a MoG with a fixed number of Gaussian components, so in this case users would need to add their own MoG models with more Gaussian components...

Could I get some advice or references on determining the MoG coefficients when adding Gaussian components to the current MoG models?
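One generic way to get amplitudes for a chosen set of components, sketched here under stated assumptions (it is not necessarily how the Tractor's built-in mixtures were derived): fix a grid of variances and solve for non-negative amplitudes by weighted least squares against the target Sersic profile. The variance grid, radial range, weighting, and function names below are all illustrative choices.

```python
import numpy as np
from scipy.optimize import nnls
from scipy.special import gamma, gammaincinv


def sersic_profile(r, n):
    """Circular Sersic surface brightness with r_e = 1 and unit total flux."""
    b = gammaincinv(2.0 * n, 0.5)
    i0 = b ** (2.0 * n) / (2.0 * np.pi * n * gamma(2.0 * n))
    return i0 * np.exp(-b * r ** (1.0 / n))


def fit_mog_amplitudes(n, variances, r_min=1e-3, r_max=20.0, n_r=400):
    """Non-negative amplitudes for fixed-variance circular Gaussians (units of r_e^2).

    Residuals are multiplied by r on a log-spaced radial grid, so squared residuals
    are weighted by r^2, matching the area element 2*pi*r^2 dln r.
    """
    r = np.logspace(np.log10(r_min), np.log10(r_max), n_r)
    v = np.asarray(variances, dtype=float)
    # Unit-flux 2-D Gaussians evaluated at each radius: shape (n_r, n_components)
    g = np.exp(-0.5 * r[:, None] ** 2 / v[None, :]) / (2.0 * np.pi * v[None, :])
    amps, _ = nnls(r[:, None] * g, r * sersic_profile(r, n))
    return amps


variances = np.logspace(-4, 2, 12)  # hypothetical grid of component variances
amps = fit_mog_amplitudes(1.0, variances)
print(amps.sum())  # total flux of the fitted mixture (the target profile has flux 1)
```

With the variances fixed, the problem is linear in the amplitudes, which is why a non-negative least-squares solver suffices; letting the variances float as well turns it into a nonlinear fit.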

I used the code below.

import numpy as np
from tractor.sersic import SersicMixture, SersicGalaxy
from tractor import mixture_profiles as mp
from scipy.interpolate import InterpolatedUnivariateSpline

class SersicMixtureCorrected(SersicMixture):
    singleton = None

    @staticmethod
    def getProfile(sindex):
        if SersicMixtureCorrected.singleton is None:
            SersicMixtureCorrected.singleton = SersicMixtureCorrected()
        return SersicMixtureCorrected.singleton._getProfile(sindex)

    def __init__(self):
        super().__init__()

        self.beyonds = [
            (0.29, -0.00844581249119647),
            (0.3, -0.007543589678601026),
            (0.31, -0.006720049694663777),
            (0.32, -0.0059662304247363185),
            (0.33, -0.005277313728958344),
            (0.34, -0.0046521468417414225),
            (0.35, -0.004095346709762526),
            (0.36, -0.0035988688279965375),
            (0.37, -0.0031498118023494115),
            (0.38, -0.0027583846582809324),
            (0.39, -0.002435351519089801),
            (0.4, -0.002157335977535313),
            (0.41, -0.0019076077741795316),
            (0.42, -0.0016991406639236262),
            (0.43, -0.001425029716449977),
            (0.44, -0.0012650990437518272),
            (0.45, -0.0011347785247703968),
            (0.46, -0.0009879528471214982),
            (0.47, -0.0008100151316562387),
            (0.48, -0.0005870125324175524),
            (0.49, -0.0003190764835062643),
            (0.5, 5.425903805811316e-06),
            (0.51, 0.00034705152982811294),
            (0.515, 0.0002996544861531003),
            (0.52, 0.00022447215877752225),
            (0.53, 0.00010594007072395328),
            (0.54, 0.0002694537629742144),
            (0.55, 0.0004300827684946551),
            (0.56, 0.0005944090931850887),
            (0.57, 0.0015975318101883462),
            (0.575, 0.0007583228570509637),
            (0.58, 0.0006995684031628757),
            (0.6, 0.0007162051886123733),
            (0.62, 0.0010120879777582026),
            (0.63, 0.0011267421968184088),
            (0.64, 0.0012801213226446007),
            (0.65, 0.001451646001103979),
            (0.7, 0.002508762419644761),
            (0.71, 0.002766796388286308),
            (0.72, 0.003023229882584244),
            (0.73, 0.0032915776909705485),
            (0.74, 0.003561967457950399),
            (0.75, 0.0038673532336104266),
            (0.8, 0.005528522125680335),
            (0.85, 0.007468553253350441),
            (0.9, 0.00965566942039231),
            (0.95, 0.01205320678893651),
            (1.0, 0.014659549792626902),
            (1.1, 0.017005210259315284),
            (1.2, 0.01927687860355387),
            (1.3, 0.021484300887706198),
            (1.4, 0.02363493914344006),
            (1.5, 0.025774453598841895),
            (1.55, 0.02653204147973942),
            (1.6, 0.027731530277451344),
            (1.7, 0.02997295724425736),
            (1.8, 0.031100832053232386),
            (1.9, 0.03353732320270597),
            (2.0, 0.03415553715599429),
            (2.1, 0.035866666758461396),
            (2.3, 0.03925398489658294),
            (2.5, 0.04261982372191825),
            (2.7, 0.045970050960330966),
            (3.0, 0.05097239750541088),
            (3.1, 0.05244179457571285),
            (3.2, 0.05267681959158704),
            (3.3, 0.051008278225892156),
            (3.4, 0.051581162287590465),
            (3.5, 0.052756205813684454),
            (4.0, 0.05858779172595929),
            (4.5, 0.07168299370139825),
            (5.0, 0.0841544214197475),
            (5.5, 0.09593308483544921),
            (6.0, 0.10697422593718764),
            (6.1, 0.10906657469929443),
            (6.2, 0.11117858609339415),
            (6.3, 0.1132645336487178),
        ]

        self.cores = [
            (0.29, -0.0006145669842789747),
            (0.3, -0.00047558548715020965),
            (0.31, -0.000363217224507717),
            (0.32, -0.0002860576454500885),
            (0.33, -0.0002047657848697204),
            (0.34, -0.00014498433184317872),
            (0.35, -0.00010651925123739137),
            (0.36, -8.396446125114032e-05),
            (0.37, -4.507138441234293e-05),
            (0.38, -3.380973759070649e-05),
            (0.39, -2.060885174259841e-05),
            (0.4, -1.7333510825889853e-05),
            (0.41, -1.3655362836484386e-05),
            (0.42, -2.2053831621571263e-05),
            (0.43, 3.519140248531283e-05),
            (0.44, 5.8527269772845614e-05),
            (0.45, 6.389492410741049e-05),
            (0.46, 5.951319146363376e-05),
            (0.47, 4.41916821859456e-05),
            (0.48, 2.961055933942136e-05),
            (0.49, 1.2627912384211015e-05),
            (0.5, 1.212783262150019e-07),
            (0.51, -1.9430637892225988e-05),
            (0.515, -4.212502048273059e-05),
            (0.52, -3.228734261651045e-05),
            (0.53, 7.81361399720959e-06),
            (0.54, -4.824476926845733e-07),
            (0.55, -7.352795657555866e-06),
            (0.56, -1.4297605179125483e-05),
            (0.57, -0.0008618049018462859),
            (0.575, -3.926782323415701e-06),
            (0.58, 1.862366342647581e-05),
            (0.6, 2.3501590432239983e-05),
            (0.62, 3.091940991745146e-05),
            (0.63, 1.3872692724903324e-05),
            (0.64, 8.008823279725963e-06),
            (0.65, 5.0818015268627725e-06),
            (0.7, 1.0753942212227141e-05),
            (0.71, 1.3588086665461407e-05),
            (0.72, 1.233305641651361e-05),
            (0.73, 1.1090788702317056e-05),
            (0.74, 8.221140030351126e-06),
            (0.75, 8.823047828843134e-06),
            (0.8, 1.2454565203434687e-05),
            (0.85, 1.5852421938133965e-05),
            (0.9, 2.027803646037496e-05),
            (0.95, 2.3860516997376013e-05),
            (1.0, 3.39957344719366e-05),
            (1.1, 5.54545242191784e-05),
            (1.2, 8.690513039832926e-05),
            (1.3, 0.0001277037233881062),
            (1.4, 0.0001860448049312291),
            (1.5, 0.0002658230042170695),
            (1.55, 0.00021350097227901266),
            (1.6, 0.00031839967024277493),
            (1.7, 0.0004857433537968636),
            (1.8, 0.00037686920762880494),
            (1.9, 0.0006261207533730384),
            (2.0, 0.0005666093530467542),
            (2.1, 0.0007145367012874604),
            (2.3, 0.0010893594944085816),
            (2.5, 0.0015779745432438763),
            (2.7, 0.00218967829239175),
            (3.0, 0.0033474709444939466),
            (3.1, 0.0037814734649312953),
            (3.2, 0.004189334066699302),
            (3.3, 0.0048892121748450035),
            (3.4, 0.00551991679573588),
            (3.5, 0.0060998023127743495),
            (4.0, 0.009448723687390248),
            (4.5, 0.012080413858493733),
            (5.0, 0.01499344940562819),
            (5.5, 0.018145661429931625),
            (6.0, 0.021495074522919932),
            (6.1, 0.022184349060213604),
            (6.2, 0.02288044574498377),
            (6.3, 0.02358058762144588),
        ]

        self.core_func = InterpolatedUnivariateSpline(
            [s for s, c in self.cores], [c for s, c in self.cores], k=3
        )
        self.beyond_func = InterpolatedUnivariateSpline(
            [s for s, b in self.beyonds], [b for s, b in self.beyonds], k=3
        )

    def _getProfile(self, sindex):
        matches = []
        # clamp
        if sindex <= self.lowest:
            matches.append(self.fits[0])
            # (lo,hi,a,v) = self.fits[0]
            # amp_funcs = a
            # logvar_funcs = v
            sindex = self.lowest
        elif sindex >= self.highest:
            matches.append(self.fits[-1])
            # (lo,hi,a,v) = self.fits[-1]
            # amp_funcs = a
            # logvar_funcs = v
            sindex = self.highest
        else:
            for f in self.fits:
                lo, hi, a, v = f
                if sindex >= lo and sindex < hi:
                    matches.append(f)
                    # print('Sersic index', sindex, '-> range', lo, hi)
                    # amp_funcs = a
                    # logvar_funcs = v
                    # break

        if len(matches) == 2:
            # Two ranges overlap.  Ramp between them.
            # Assume self.fits is ordered in increasing Sersic index
            m0, m1 = matches
            lo0, hi0, a0, v0 = m0
            lo1, hi1, a1, v1 = m1
            assert lo0 < lo1
            assert lo1 < hi0  # overlap is in here
            ramp_lo = lo1
            ramp_hi = hi0
            assert ramp_lo < ramp_hi
            assert ramp_lo <= sindex
            assert sindex < ramp_hi
            ramp_frac = (sindex - ramp_lo) / (ramp_hi - ramp_lo)
            # print('Sersic index', sindex, ': ramping between ranges', (lo0,hi0), 'and', (lo1,hi1), '; ramp', (ramp_lo, ramp_hi), 'fraction', ramp_frac)
            amps0 = np.array([f(sindex) for f in a0])
            amps0 /= amps0.sum()
            amps1 = np.array([f(sindex) for f in a1])
            amps1 /= amps1.sum()
            amps = np.append((1.0 - ramp_frac) * amps0, ramp_frac * amps1)
            varr = np.exp(np.array([f(sindex) for f in v0 + v1]))
        else:
            assert len(matches) == 1
            lo, hi, amp_funcs, logvar_funcs = matches[0]
            amps = np.array([f(sindex) for f in amp_funcs])
            amps /= amps.sum()
            varr = np.exp(np.array([f(sindex) for f in logvar_funcs]))

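        # Rescale the Gaussian amplitudes to sum to (1 - core - beyond), then append
        # the 'core' flux as an extra zero-variance (point-like) component at the center.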
        core = self.core_func(sindex)
        beyond = self.beyond_func(sindex)
        amps *= (1.0 - core - beyond) / amps.sum()
        # amps *= (1.0 - beyond) / amps.sum()
        # amps *= (1.0) / amps.sum()
        amps = np.append(amps, core)
        varr = np.append(varr, 0.0)

        return mp.MixtureOfGaussians(amps, np.zeros((len(amps), 2)), varr)

class SersicGalaxyCorrected(SersicGalaxy):
    def getProfile(self):
        return SersicMixtureCorrected.getProfile(self.sersicindex.val)