darafferty / LSMTool

LOFAR Local Sky Model Tool
https://lsmtool.readthedocs.io
GNU General Public License v3.0

Performance bottleneck (also for Rapthor) #39

Open AlexKurek opened 2 weeks ago

AlexKurek commented 2 weeks ago

At the beginning of a Rapthor run, when the catalogue is set to LoTSS, the profile shows:

    ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     96215  110.803    0.001  110.803    0.001 {method 'copy' of 'numpy.ndarray' objects}
     13599   32.979    0.002  146.084    0.011 skymodel.py:1020(getRowIndex)
        68    3.832    0.056    3.832    0.056 table.py:1011(getcol)
     16759    2.844    0.000    3.590    0.000 wcs.py:410(__init__)
     81616    2.306    0.000    2.317    0.000 table.py:2226(keys)

Line-by-line timings in skymodel.py:

    Line #      Hits         Time  Per Hit   % Time  Line Contents
      1056      3037    6509104.2   2143.3     29.3%  if self.hasPatches and rowName in self.getPatchNames():
      1057      3037   15701311.1   5170.0     70.7%      return np.where(self.getColValues('Patch') == rowName)[0].tolist()

If this could be optimised, Rapthor would get a little faster. The calls to the copy method of numpy.ndarray are probably also triggered by skymodel.py:1020 (getRowIndex).

Consider returning early from the function:

if not self.hasPatches:
  return []
if rowName not in self.getPatchNames():
  return []
return np.where(self.getColValues('Patch') == rowName)[0].tolist()

and adding @lru_cache(maxsize=None) to getPatchNames().
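A minimal sketch of the early-exit idea, assuming a toy model that keeps the Name and Patch columns as NumPy arrays. The class ToySkyModel, its cached dicts and its invalidate() method are hypothetical, not LSMTool's API. Unlike the snippet above, it still falls back to a source-name lookup when the name is not a patch, as the current getRowIndex() does. A dict built in one pass replaces the per-call membership test and np.where scan (and, if the profile is right, the ndarray copies behind them). A plain @lru_cache on a method would not notice when the model is edited, so here the cache is cleared explicitly instead.

```python
from collections import defaultdict

import numpy as np


class ToySkyModel:
    """Hypothetical stand-in for a sky model with Name and Patch columns."""

    def __init__(self, names, patches=None):
        self.names = np.asarray(names)
        self.patches = None if patches is None else np.asarray(patches)
        self._patch_index = None  # lazily built {patch name: [row indices]}
        self._name_index = None   # lazily built {source name: [row indices]}

    @property
    def hasPatches(self):
        return self.patches is not None

    def invalidate(self):
        """Drop the cached indices; call after any edit to the columns."""
        self._patch_index = None
        self._name_index = None

    @staticmethod
    def _build_index(values):
        # One pass over the column instead of one np.where scan per lookup
        index = defaultdict(list)
        for i, value in enumerate(values.tolist()):
            index[value].append(i)
        return dict(index)

    def getRowIndex(self, rowName):
        # Patch names first, as in the original method; a dict lookup
        # replaces the membership test plus np.where over the Patch column
        if self.hasPatches:
            if self._patch_index is None:
                self._patch_index = self._build_index(self.patches)
            if rowName in self._patch_index:
                return list(self._patch_index[rowName])
        # Fall back to source names instead of returning []
        if self._name_index is None:
            self._name_index = self._build_index(self.names)
        if rowName in self._name_index:
            return list(self._name_index[rowName])
        raise ValueError("Row name '{0}' not recognized.".format(rowName))
```

For example, with names ['src1', 'src2', 'src3', 'src4'] and patches ['bin1', 'bin1', 'bin2', 'bin1'], getRowIndex('bin1') gives [0, 1, 3] and getRowIndex('src3') gives [2]. Returning a copy of each cached list keeps callers from mutating the cache.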

AlexKurek commented 1 week ago

I see that this function has been changed (optimised?) in the current master (https://github.com/darafferty/LSMTool/blob/master/lsmtool/skymodel.py#L1015). Strangely, though, after installing Rapthor and upgrading LSMTool to master, Rapthor still calls a version of skymodel.py that looks like this:

    def getRowIndex(self, rowName):
        """
        Returns index or indices for specified source or patch as a list.

        Parameters
        ----------
        rowName : str
            Name of the source or patch

        Returns
        -------
        indices : list
            List of indices. ValueError is raised if the source is not found.

        Examples
        --------
        Get row index for the source 'src1'::

            >>> s.getRowIndex('src1')
            [0]

        Get row indices for the patch 'bin1' and verify the patch name::

            >>> ind = s.getRowIndex('bin1')
            >>> print(s.getColValues('patch')[ind])
            ['bin1', 'bin1', 'bin1']

        """
        import numpy as np

        # Check first for the rowName as a patch name. If no patch matches (or
        # the model is not grouped into patches), try it as a source name. This
        # logic should work even if a row has the same name for the source and
        # its patch, as in this case the patch and source row index are
        # identical (since such a patch can have only one member source).
        if self.hasPatches and rowName in self.getPatchNames():
            return np.where(self.getColValues('Patch') == rowName)[0].tolist()
        elif rowName in self.getColValues('Name'):
            return self._getNameIndx(rowName)
        else:
            raise ValueError("Row name '{0}' not recognized.".format(rowName))

AlexKurek commented 1 week ago

Also, for radec2xy() in operations_lib.py, consider:

import numpy as np

# Assuming RA and Dec are NumPy arrays of equal length; stack them into an (N, 2) array
ra_dec = np.vstack((RA, Dec)).T

# Perform the WCS transformation on the entire array at once
pix_coords = w.wcs_world2pix(ra_dec, 0)

# Extract x and y coordinates
x = pix_coords[:, 0]
y = pix_coords[:, 1]

instead of:

    for ra_deg, dec_deg in zip(RA, Dec):
        ra_dec = np.array([[ra_deg, dec_deg]])
        x.append(w.wcs_world2pix(ra_dec, 0)[0][0])
        y.append(w.wcs_world2pix(ra_dec, 0)[0][1])

This function is heavily used by Rapthor.
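The two forms can be compared with a sketch like the one below. To keep it self-contained, it swaps astropy's WCS for a hypothetical world2pix() helper that implements a simple gnomonic (TAN) projection in NumPy; like wcs_world2pix, it maps an (N, 2) array of RA, Dec in degrees to an (N, 2) array of pixel coordinates. The projection centre, reference pixel and pixel scale are made-up values. Note that the looped version also runs the transform twice per source, once for x and once for y, which the vectorised version avoids.

```python
import numpy as np

# Hypothetical projection parameters (not taken from Rapthor or LSMTool)
REF_RA, REF_DEC = 180.0, 45.0   # projection centre in degrees
REF_PIX = (500.0, 500.0)        # pixel coordinates of the projection centre
CDELT = (-0.01, 0.01)           # degrees per pixel; RA increases to the left


def world2pix(radec):
    """Gnomonic (TAN) projection of an (N, 2) array of RA, Dec in degrees."""
    ra = np.radians(radec[:, 0])
    dec = np.radians(radec[:, 1])
    ra0, dec0 = np.radians(REF_RA), np.radians(REF_DEC)
    dra = ra - ra0
    cos_c = np.sin(dec0) * np.sin(dec) + np.cos(dec0) * np.cos(dec) * np.cos(dra)
    xi = np.degrees(np.cos(dec) * np.sin(dra) / cos_c)
    eta = np.degrees((np.cos(dec0) * np.sin(dec)
                      - np.sin(dec0) * np.cos(dec) * np.cos(dra)) / cos_c)
    return np.column_stack((REF_PIX[0] + xi / CDELT[0],
                            REF_PIX[1] + eta / CDELT[1]))


def radec2xy_loop(RA, Dec):
    # Looped form: two transform calls per source
    x, y = [], []
    for ra_deg, dec_deg in zip(RA, Dec):
        ra_dec = np.array([[ra_deg, dec_deg]])
        x.append(world2pix(ra_dec)[0][0])
        y.append(world2pix(ra_dec)[0][1])
    return x, y


def radec2xy_vectorised(RA, Dec):
    # Vectorised form: a single transform call on the whole (N, 2) array
    pix_coords = world2pix(np.column_stack((RA, Dec)))
    return pix_coords[:, 0].tolist(), pix_coords[:, 1].tolist()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    RA = REF_RA + rng.uniform(-2.0, 2.0, 1000)
    Dec = REF_DEC + rng.uniform(-2.0, 2.0, 1000)
    x1, y1 = radec2xy_loop(RA, Dec)
    x2, y2 = radec2xy_vectorised(RA, Dec)
    # Both forms should agree to floating-point precision
    print(np.allclose(x1, x2) and np.allclose(y1, y2))
```

The vectorised version returns lists only to match the looped one; whether callers in Rapthor need lists or could take arrays directly is an assumption worth checking. np.allclose rather than exact equality is used for the comparison, since vectorised and scalar NumPy code paths may round slightly differently.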

gmloose commented 1 week ago

Thanks @AlexKurek. I had already identified this as a bottleneck, but hadn't had time to look into how to improve it.

AlexKurek commented 1 week ago

It seems to be much faster. In the profile below, the first row is the looped version and the second is the vectorised one:

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
30727  249.860    0.002 1558.249    0.012 /home/akurek/.local/lib/python3.10/site-packages/lsmtool/operations_lib.py:338(radec2xy)

130735    0.729    0.000   32.301    0.000 /home/akurek/.local/lib/python3.10/site-packages/lsmtool/operations_lib.py:338(radec2xy)

I hope I have not made a mistake, as this was extracted from a Rapthor run on 8 h of data. In any case, it should also be tested whether the two versions give the same results.

AlexKurek commented 2 days ago

Another bottleneck is here:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  2425662  260.723    0.000  794.842    0.000 /home/akurek/.local/lib/python3.10/site-packages/lsmtool/operations/_meanshift.py:56(euclid_distance)

Given how often it is called, the way euclid_distance is invoked may not be optimal: [Screenshot from 2024-06-29 08-51-05]. Perhaps not passing self as an argument would help; it does not seem to be needed.

Vectorising the loop in run(self) would also help.
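A minimal sketch of both suggestions, under stated assumptions: it does not reproduce _meanshift.py, and the Gaussian kernel, the bandwidth parameter and the names shift_loop and shift_vectorised are hypothetical. The looped version calls a module-level euclid_distance (no unused self) once per pair of points; the vectorised version computes all pairwise distances in one broadcast operation and then takes the kernel-weighted means with a single matrix product.

```python
import numpy as np


def euclid_distance(a, b):
    # Module-level helper: no unused self argument
    return np.sqrt(np.sum((a - b) ** 2))


def shift_loop(points, data, bandwidth):
    """One mean-shift iteration with a Python loop over every pair."""
    shifted = np.empty_like(points)
    for i, p in enumerate(points):
        # Gaussian kernel weight of every data point relative to p (assumed kernel)
        weights = np.array([np.exp(-0.5 * (euclid_distance(p, d) / bandwidth) ** 2)
                            for d in data])
        shifted[i] = weights @ data / weights.sum()
    return shifted


def shift_vectorised(points, data, bandwidth):
    """The same iteration with all pairwise distances computed at once."""
    # (M, 1, D) - (1, N, D) broadcasts to (M, N, D); summing over D gives (M, N)
    dist = np.sqrt(np.sum((points[:, None, :] - data[None, :, :]) ** 2, axis=-1))
    weights = np.exp(-0.5 * (dist / bandwidth) ** 2)
    # (M, N) @ (N, D) -> (M, D), normalised by each row's total weight
    return weights @ data / weights.sum(axis=1, keepdims=True)


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    data = rng.normal(size=(200, 2))
    print(np.allclose(shift_loop(data, data, 0.5),
                      shift_vectorised(data, data, 0.5)))
```

The broadcast builds an M x N x D intermediate array, so for large catalogues it may be worth processing the points in chunks, or computing the (M, N) distance matrix with scipy.spatial.distance.cdist instead.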