telegraphic / pygdsm

Python interface to Global Diffuse Sky Models (GDSM) for the radio sky between 10 MHz and 5 THz
MIT License
30 stars 7 forks source link

Linear interpolation in GSM2016 #14

Closed telegraphic closed 1 year ago

telegraphic commented 1 year ago

Note: Issue based on discussions with R. Braun

There are discontinuities in the slope of the GSM2016 spectra, as seen here at 45 MHz:

image (Fig 25, Price et al, 2018)

This kind of behaviour is quite problematic for foreground removal attempts if people are trying to get down to signal levels much fainter than the Galactic foreground.

Currently, the GSM2016 code is hard-wired for a linear interpolation in log(frequency) for every frequency that one wishes to generate. R. Braun suggests quadratic (or N>=2) interpolation instead.

telegraphic commented 1 year ago

Proof-of-concept code:

import numpy as np
import h5py
from astropy import units
import healpy as hp
import ephem

from .component_data import GSM2016_FILEPATH
from .plot_utils import show_plt
from .base_observer import BaseObserver
from .base_skymodel import BaseSkyModel

kB = 1.38065e-23
C = 2.99792e8
h = 6.62607e-34
T = 2.725
hoverk = h / kB

def K_CMB2MJysr(K_CMB, nu):#in Kelvin and Hz
    B_nu = 2 * (h * nu)* (nu / C)**2 / (np.exp(hoverk * nu / T) - 1)
    conversion_factor = (B_nu * C / nu / T)**2 / 2 * np.exp(hoverk * nu / T) / kB
    return  K_CMB * conversion_factor * 1e20#1e-26 for Jy and 1e6 for MJy

def K_RJ2MJysr(K_RJ, nu):#in Kelvin and Hz
    conversion_factor = 2 * (nu / C)**2 * kB
    return  K_RJ * conversion_factor * 1e20#1e-26 for Jy and 1e6 for MJy

def rotate_map(hmap, rot_theta, rot_phi, nest=True):
    nside = hp.npix2nside(len(hmap))

    # Get theta, phi for non-rotated map
    t, p = hp.pix2ang(nside, np.arange(hp.nside2npix(nside)), nest= nest)  # theta, phi

    # Define a rotator
    r = hp.Rotator(deg=False, rot=[rot_phi, rot_theta])

    # Get theta, phi under rotated co-ordinates
    trot, prot = r(t, p)

    # Inerpolate map onto these co-ordinates
    rot_map = hp.get_interp_val(hmap, trot, prot, nest= nest)

    return rot_map

class GlobalSkyModel2016(BaseSkyModel):
    """ Global sky model (GSM) class for generating sky models.
    """

    def __init__(self, freq_unit='MHz', data_unit='TCMB', resolution='hi', theta_rot=0, phi_rot=0):
        """ Global sky model (GSM) class for generating sky models.

        Upon initialization, the map PCA data are loaded into memory and interpolation
        functions are pre-computed.

        Parameters
        ----------
        freq_unit: 'Hz', 'MHz', or 'GHz'
            Unit of frequency. Defaults to 'MHz'.
        data_unit: 'MJysr', 'TCMB', 'TRJ'
            Unit of output data. MJy/Steradian, T_CMB in Kelvin, or T_RJ.
        resolution: 'hi' or 'low'
            Resolution of output map. Either 300 arcmin (low) or 24 arcmin (hi).
            For frequencies under 10 GHz, output is 48 arcmin.

        Notes
        -----

        """

        if data_unit not in ['MJysr', 'TCMB', 'TRJ']:
            raise RuntimeError("UNIT ERROR: %s not supported. Only MJysr, TCMB, TRJ are allowed." % data_unit)

        if resolution.lower() in ('hi', 'high', 'h'):
            resolution = 'hi'
        elif resolution.lower() in ('low', 'lo', 'l'):
            resolution = 'low'
        else:
            raise RuntimeError("RESOLUTION ERROR: Must be either hi or low, not %s" % resolution)

        super(GlobalSkyModel2016, self).__init__('GSM2016', GSM2016_FILEPATH, freq_unit, data_unit, basemap='')

        self.resolution = resolution

        # Map data to load
        labels = ['Synchrotron', 'CMB', 'HI', 'Dust1', 'Dust2', 'Free-Free']
        self.n_comp = len(labels)

        if resolution=='hi':
            self.nside = 1024
            self.map_ni = np.array([self.h5['highres_%s_map'%lb][:] for lb in labels])
        else:
            self.nside = 64
            self.map_ni = np.array(self.h5['lowres_maps'])

        self.spec_nf = self.h5['spectra'][:]

        if theta_rot or phi_rot:
            for i,map in enumerate(self.map_ni):
                self.map_ni[i] = rotate_map(map, theta_rot, phi_rot, nest=True)

    def generate(self, freqs, freq_min, freq_max):
        """ Generate a global sky model at a given frequency or frequencies

        Parameters
        ----------
        freqs: float or np.array
            Frequency for which to return GSM model

        Returns
        -------
        gsm: np.array
            Global sky model in healpix format, with NSIDE=1024. Output map
            is in galactic coordinates, ring format.

        """

        # convert frequency values into Hz
        freqs = np.array(freqs) * units.Unit(self.freq_unit)
        freq_min = np.array(freq_min) * units.Unit(self.freq_unit)
        freq_max = np.array(freq_max) * units.Unit(self.freq_unit)
        freqs_ghz = freqs.to('GHz').value
        freq_min_ghz = freq_min.to('GHz').value
        freq_max_ghz = freq_max.to('GHz').value

        if isinstance(freqs_ghz, float):
            freqs_ghz = np.array([freqs_ghz])

        try:
            assert np.min(freqs_ghz) >= 0.01
            assert np.max(freqs_ghz) <= 5000
        except AssertionError:
            raise RuntimeError("Frequency values lie outside 10 MHz < f < 5 THz: %s")

        map_ni = self.map_ni
        # if self.resolution == 'hi':
        #     map_ni = self.map_ni_hr
        # else:
        #     map_ni = self.map_ni_lr

        spec_nf = self.spec_nf
        nfreq = spec_nf.shape[1]

        output = np.zeros((len(freqs_ghz), map_ni.shape[1]), dtype='float32')
        for ifreq, freq in enumerate(freqs_ghz):

            left_index = -1
            for i in range(nfreq - 1):
                if spec_nf[0, i] <= freq <= spec_nf[0, i + 1]:
                    left_index = i
                    break

            gleft_index = -1
            for i in range(nfreq - 1):
                if spec_nf[0, i] <= freq_min_ghz <= spec_nf[0, i + 1]:
                    left_index = i
                    break

            gright_index = -1
            for i in range(1, nfreq):
                if spec_nf[0, i - 1] <= freq_max_ghz <= spec_nf[0, i]:
                    right_index = i
                    break

            # Do the interpolation
            # First determine the order of interpolation to use (only quadratic or linear in log(nu) supported)
            if ((gright_index - gleft_index) == 2):
                interp_spec_nf = np.copy(spec_nf)
                interp_spec_nf[0:2] = np.log10(interp_spec_nf[0:2])
                x0 = interp_spec_nf[0, gleft_index]
                x1 = interp_spec_nf[0, gleft_index + 1]
                x2 = interp_spec_nf[0, gleft_index + 2]
                y0 = interp_spec_nf[1:, gleft_index]
                y1 = interp_spec_nf[1:, gleft_index + 1]
                y2 = interp_spec_nf[1:, gleft_index + 2]
                x = np.log10(freq)
                L0 = ((x-x1)*(x-x2))/((x0-x1)*(x0-x2))
                L1 = ((x-x0)*(x-x2))/((x1-x0)*(x1-x2))
                L2 = ((x-x0)*(x-x1))/((x2-x0)*(x2-x1))
                interpolated_vals = y0*L0 + y1*L1 + y2*L2
                output[ifreq] = np.sum(10.**interpolated_vals[0] * (interpolated_vals[1:, None] * map_ni), axis=0)

                output[ifreq] = hp.pixelfunc.reorder(output[ifreq], n2r=True)

            else:
                interp_spec_nf = np.copy(spec_nf)
                interp_spec_nf[0:2] = np.log10(interp_spec_nf[0:2])
                x0 = interp_spec_nf[0, left_index]
                x1 = interp_spec_nf[0, left_index + 1]
                y0 = interp_spec_nf[1:, left_index]
                y1 = interp_spec_nf[1:, left_index + 1]
                x = np.log10(freq)
                interpolated_vals = (x * (y1 - y0) + x1 * y0 - x0 * y1) / (x1 - x0)
                output[ifreq] = np.sum(10.**interpolated_vals[0] * (interpolated_vals[1:, None] * map_ni), axis=0)

                output[ifreq] = hp.pixelfunc.reorder(output[ifreq], n2r=True)

            if self.data_unit == 'TCMB':
                conversion = 1. / K_CMB2MJysr(1., 1e9 * freq)
            elif self.data_unit == 'TRJ':
                conversion = 1. / K_RJ2MJysr(1., 1e9 * freq)
            else:
                conversion = 1.
            output[ifreq] *= conversion

#            output.append(result)

        if len(output) == 1:
            output = output[0]
        #else:
        #    map_data = np.row_stack(output)

        self.generated_map_freqs = freqs
        self.generated_map_data = output

        return output

class GSMObserver2016(BaseObserver):
    def __init__(self):
        """ Initialize the Observer object.

        Calls ephem.Observer.__init__ function and adds on gsm
        """
        super(GSMObserver2016, self).__init__(gsm=GlobalSkyModel2016)
telegraphic commented 1 year ago

Updated proof of concept code from R. Braun:

import numpy as np
from scipy.interpolate import interp1d, pchip
import h5py
from astropy import units
import healpy as hp
import ephem

from .component_data import GSM2016_FILEPATH
from .plot_utils import show_plt
from .base_observer import BaseObserver
from .base_skymodel import BaseSkyModel

kB = 1.38065e-23
C = 2.99792e8
h = 6.62607e-34
T = 2.725
hoverk = h / kB

def K_CMB2MJysr(K_CMB, nu):#in Kelvin and Hz
    B_nu = 2 * (h * nu)* (nu / C)**2 / (np.exp(hoverk * nu / T) - 1)
    conversion_factor = (B_nu * C / nu / T)**2 / 2 * np.exp(hoverk * nu / T) / kB
    return  K_CMB * conversion_factor * 1e20#1e-26 for Jy and 1e6 for MJy

def K_RJ2MJysr(K_RJ, nu):#in Kelvin and Hz
    conversion_factor = 2 * (nu / C)**2 * kB
    return  K_RJ * conversion_factor * 1e20#1e-26 for Jy and 1e6 for MJy

def rotate_map(hmap, rot_theta, rot_phi, nest=True):
    nside = hp.npix2nside(len(hmap))

    # Get theta, phi for non-rotated map
    t, p = hp.pix2ang(nside, np.arange(hp.nside2npix(nside)), nest= nest)  # theta, phi

    # Define a rotator
    r = hp.Rotator(deg=False, rot=[rot_phi, rot_theta])

    # Get theta, phi under rotated co-ordinates
    trot, prot = r(t, p)

    # Inerpolate map onto these co-ordinates
    rot_map = hp.get_interp_val(hmap, trot, prot, nest= nest)

    return rot_map

class GlobalSkyModel16(BaseSkyModel):
    """ Global sky model (GSM) class for generating sky models.
    """

    def __init__(self, freq_unit='MHz', data_unit='TCMB', resolution='hi', theta_rot=0, phi_rot=0, interpolation='pchip'):
        """ Global sky model (GSM) class for generating sky models.

        Upon initialization, the map PCA data are loaded into memory and interpolation
        functions are pre-computed.

        Parameters
        ----------
        freq_unit: 'Hz', 'MHz', or 'GHz'
            Unit of frequency. Defaults to 'MHz'.
        data_unit: 'MJysr', 'TCMB', 'TRJ'
            Unit of output data. MJy/Steradian, T_CMB in Kelvin, or T_RJ.
        resolution: 'hi' or 'low'
            Resolution of output map. Either 300 arcmin (low) or 24 arcmin (hi).
            For frequencies under 10 GHz, output is 48 arcmin.
        interpolation: 'cubic' or 'pchip'
            Choose whether to use cubic spline interpolation or
            piecewise cubic hermitian interpolating polynomial (PCHIP).
            PCHIP is designed to never locally overshoot data, whereas
            splines are designed to have smooth first and second derivatives.

        Notes
        -----
        The scipy `interp1d` function does not allow one to explicitly
        set second derivatives to zero at the endpoints, as is done in
        the original GSM. As such, results will differ. Further, we default
        to use PCHIP interpolation.

        """

        if data_unit not in ['MJysr', 'TCMB', 'TRJ']:
            raise RuntimeError("UNIT ERROR: %s not supported. Only MJysr, TCMB, TRJ are allowed." % data_unit)

        if resolution.lower() in ('hi', 'high', 'h'):
            resolution = 'hi'
        elif resolution.lower() in ('low', 'lo', 'l'):
            resolution = 'low'
        else:
            raise RuntimeError("RESOLUTION ERROR: Must be either hi or low, not %s" % resolution)

        super(GlobalSkyModel16, self).__init__('GSM2016', GSM2016_FILEPATH, freq_unit, data_unit, basemap='')

        self.interpolation_method = interpolation
        self.resolution = resolution

        # Map data to load
        labels = ['Synchrotron', 'CMB', 'HI', 'Dust1', 'Dust2', 'Free-Free']
        self.n_comp = len(labels)

        if resolution=='hi':
            self.nside = 1024
            self.map_ni = np.array([self.h5['highres_%s_map'%lb][:] for lb in labels])
        else:
            self.nside = 64
            self.map_ni = np.array(self.h5['lowres_maps'])

        self.spec_nf = self.h5['spectra'][:]

        if theta_rot or phi_rot:
            for i,map in enumerate(self.map_ni):
                self.map_ni[i] = rotate_map(map, theta_rot, phi_rot, nest=True)

    def generate(self, freqs):
        """ Generate a global sky model at a given frequency or frequencies

        Parameters
        ----------
        freqs: float or np.array
            Frequency for which to return GSM model

        Returns
        -------
        gsm: np.array
            Global sky model in healpix format, with NSIDE=1024. Output map
            is in galactic coordinates, ring format.

        """

        # convert frequency values into Hz
        freqs = np.array(freqs) * units.Unit(self.freq_unit)
        freqs_ghz = freqs.to('GHz').value

        if isinstance(freqs_ghz, float):
            freqs_ghz = np.array([freqs_ghz])

        try:
            assert np.min(freqs_ghz) >= 0.01
            assert np.max(freqs_ghz) <= 5000
        except AssertionError:
            raise RuntimeError("Frequency values lie outside 10 MHz < f < 5 THz: %s")

        map_ni = self.map_ni
        # if self.resolution == 'hi':
        #     map_ni = self.map_ni_hr
        # else:
        #     map_ni = self.map_ni_lr

        spec_nf = self.spec_nf
        nfreq = spec_nf.shape[1]

        # Now borrow code from the orignal GSM2008 model to do a sensible interpolation

        pca_freqs_ghz = spec_nf[0]
        pca_scaling   = spec_nf[1]
        pca_comps     = spec_nf[2:]
         # Interpolate to the desired frequency values
        ln_pca_freqs = np.log(pca_freqs_ghz)
        if self.interpolation_method == 'cubic':
            spl_scaling = interp1d(ln_pca_freqs, np.log(pca_scaling), kind='cubic')
            spl1 = interp1d(ln_pca_freqs,   pca_comps[0],   kind='cubic')
            spl2 = interp1d(ln_pca_freqs,   pca_comps[1],   kind='cubic')
            spl3 = interp1d(ln_pca_freqs,   pca_comps[2],   kind='cubic')
            spl4 = interp1d(ln_pca_freqs,   pca_comps[3],   kind='cubic')
            spl5 = interp1d(ln_pca_freqs,   pca_comps[4],   kind='cubic')
            spl6 = interp1d(ln_pca_freqs,   pca_comps[5],   kind='cubic')

        else:
            spl_scaling = pchip(ln_pca_freqs, np.log(pca_scaling))
            spl1 = pchip(ln_pca_freqs,   pca_comps[0])
            spl2 = pchip(ln_pca_freqs,   pca_comps[1])
            spl3 = pchip(ln_pca_freqs,   pca_comps[2])
            spl4 = pchip(ln_pca_freqs,   pca_comps[3])
            spl5 = pchip(ln_pca_freqs,   pca_comps[4])
            spl6 = pchip(ln_pca_freqs,   pca_comps[5])

        self.interp_comps = (spl_scaling, spl1, spl2, spl3, spl4, spl5, spl6)

        ln_freqs = np.log(freqs_ghz)
        comps = np.row_stack((spl1(ln_freqs), spl2(ln_freqs), spl3(ln_freqs), spl4(ln_freqs), spl5(ln_freqs), spl6(ln_freqs)))
        scaling = np.exp(spl_scaling(ln_freqs))

        # Finally, compute the dot product via einsum (awesome function)
        # c=comp, f=freq, p=pixel. We want to dot product over c for each freq
        #print comps.shape, self.pca_map_data.shape, scaling.shape

        output = np.single(np.einsum('cf,pc,f->fp', comps, map_ni.T, scaling))

        for ifreq, freq in enumerate(freqs_ghz):

            output[ifreq] = hp.pixelfunc.reorder(output[ifreq], n2r=True)

            if self.data_unit == 'TCMB':
                conversion = 1. / K_CMB2MJysr(1., 1e9 * freq)
            elif self.data_unit == 'TRJ':
                conversion = 1. / K_RJ2MJysr(1., 1e9 * freq)
            else:
                conversion = 1.
            output[ifreq] *= conversion

#            output.append(result)

        if len(output) == 1:
            output = output[0]
        #else:
        #    map_data = np.row_stack(output)

        self.generated_map_freqs = freqs
        self.generated_map_data = output

        return output

class GSMObserver16(BaseObserver):
    def __init__(self):
        """ Initialize the Observer object.

        Calls ephem.Observer.__init__ function and adds on gsm
        """
        super(GSMObserver16, self).__init__(gsm=GlobalSkyModel16)
telegraphic commented 1 year ago

I've just tested this code, and it does get rid of the discontinuity:

def test_interp():
    f = np.arange(40, 80, 5)
    for interp in ('pchip', 'cubic'):
        for SkyModel in (GlobalSkyModel, GlobalSkyModel16):
            name = str(SkyModel).strip("<>").split('.')[-1].strip("' ")
            gsm = SkyModel(freq_unit='MHz', interpolation=interp)
            d = gsm.generate(f)

            sky_spec = d.mean(axis=1)
            fit = np.poly1d(np.polyfit(f, sky_spec, 5))(f)

            plt.plot(f, sky_spec - fit, label=f'{name}: {interp}')

    plt.xlabel("Frequency [MHz]")
    plt.ylabel("Residual [K]")
    plt.legend()
    plt.show()

image

Here, I am plotting the residuals after fitting a 5th order polynomial.

Curious that the cubic interpolation agrees between sky models, but the PCHIP gives larger residuals. I don't recall why I went out of my way to use PCHIP interpolation in the first place 🤔 .