casangi / xradio

Xarray Radio Astronomy Data IO

casacore getcol() and getcolnp() incorrectly load data #152

Closed: sstansill closed this issue 2 months ago

sstansill commented 2 months ago

xradio is affected by the open but stale python-casacore issue https://github.com/casacore/python-casacore/issues/130, in which large reads with getcol() or getcolnp() can return incorrect data. See also https://github.com/casangi/xradio/issues/151#issuecomment-2061246766. Team Pando at SKA came across this issue a while ago but never identified the cause.

sstansill commented 2 months ago

Two strategies to mitigate https://github.com/casacore/python-casacore/issues/130 within xradio have been tested and benchmarked. The functions below are drop-in replacements for the line data = tb_tool.getcol(col) in the function read_col_conversion().

The first strategy is based on this CubiCal function: https://github.com/ratt-ru/CubiCal/blob/282257714d2d414c254014bdd40e635919cb2064/cubical/data_handler/ms_data_handler.py#L831. The replacement function is:

import numpy as np
from math import prod


def get_data_max_rows(
    MeasurementSet,
    col: str,
):
    """
    Function to perform delayed reads from table columns when converting
    (no need for didxs)
    """

    # Maximum number of 64 bit elements to be loaded in a single read
    # Workaround for https://github.com/casacore/python-casacore/issues/130
    max_elements = 2**29

    # Get the total number of rows in the base measurement set
    nrows_total = MeasurementSet.nrows()

    # Use casacore to get the shape of a row for this column
    # WARNING: Assumes MeasurementSet is a single measurement set not an MMS
    #################################################################################

    try:
        # WARNING: getcolshapestring() only works on columns where a row element is an
        # array (i.e. it raises RuntimeError for scalar columns such as TIME)
        shape_string = MeasurementSet.getcolshapestring(col)[0]
        extra_dimensions = tuple(int(idx) for idx in shape_string.strip("[]").split(", "))
    except RuntimeError:
        extra_dimensions = ()
    full_shape = (nrows_total,) + extra_dimensions

    #################################################################################

    # Get dtype of the column. Only read first row from disk
    col_dtype = np.array(MeasurementSet.col(col)[0]).dtype

    # Construct the numpy array to populate with data
    data = np.empty(full_shape, dtype=col_dtype)

    # Limit the number of elements per read to the maximum allowed from:
    # https://github.com/casacore/python-casacore/issues/130#issuecomment-748150854
    # Count elements per row (prod(()) == 1 for scalar columns), not for the whole column
    row_elements = prod(extra_dimensions)
    maxrows = max(1, max_elements // row_elements)

    # If the whole column is below the limit, we can safely use a single getcolnp()
    # Note don't use getcol() because it's less safe. See:
    # https://github.com/casacore/python-casacore/issues/130#issuecomment-463202373
    if maxrows >= nrows_total:
        MeasurementSet.getcolnp(col, data)

    # Else there are more elements than is safe, so read the column in chunks of rows
    else:
        for start_row in range(0, nrows_total, maxrows):
            # The final chunk may hold fewer than maxrows rows
            nrow = min(maxrows, nrows_total - start_row)
            MeasurementSet.getcolnp(col, data[start_row:start_row + nrow], start_row, nrow)

    # Return the data from the column
    return data
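
The row-chunking arithmetic can be checked without a Measurement Set. Below is a minimal pure-NumPy sketch under stated assumptions: the row shape arrives as a shape string such as "[64, 4]", and a hypothetical read_rows(out, startrow, nrow) callback stands in for getcolnp(col, out, startrow, nrow). The helper names parse_shape_string, chunk_plan, read_chunked and fake_getcolnp are illustrative only, not part of xradio or python-casacore. The number of rows per read is 2**29 divided by the number of elements in one row.

```python
from math import prod
from typing import Callable, List, Tuple

import numpy as np

MAX_ELEMENTS = 2**29  # per-read element limit used in the workaround above


def parse_shape_string(shape_string: str) -> Tuple[int, ...]:
    """Turn a shape string such as "[64, 4]" into a tuple of ints."""
    return tuple(int(dim) for dim in shape_string.strip("[]").split(","))


def chunk_plan(
    nrows_total: int, row_shape: Tuple[int, ...], max_elements: int = MAX_ELEMENTS
) -> List[Tuple[int, int]]:
    """Return (startrow, nrow) pairs; reads stay within max_elements unless one row exceeds it."""
    row_elements = prod(row_shape)  # prod(()) == 1, so scalar columns count 1 per row
    maxrows = max(1, max_elements // row_elements)
    return [
        (start, min(maxrows, nrows_total - start))
        for start in range(0, nrows_total, maxrows)
    ]


def read_chunked(
    read_rows: Callable[[np.ndarray, int, int], None],
    nrows_total: int,
    row_shape: Tuple[int, ...],
    dtype,
    max_elements: int = MAX_ELEMENTS,
) -> np.ndarray:
    """Fill a preallocated array one chunk of rows at a time."""
    data = np.empty((nrows_total,) + tuple(row_shape), dtype=dtype)
    for start, nrow in chunk_plan(nrows_total, row_shape, max_elements):
        # A slice along axis 0 of a C-contiguous array is itself contiguous
        read_rows(data[start:start + nrow], start, nrow)
    return data


# Usage: a fake 10-row column with row shape (2, 3), i.e. 6 elements per row
column = np.arange(60, dtype=np.complex64).reshape(10, 2, 3)


def fake_getcolnp(out: np.ndarray, startrow: int, nrow: int) -> None:
    # Hypothetical stand-in for table.getcolnp(col, out, startrow, nrow)
    out[...] = column[startrow:startrow + nrow]


print(parse_shape_string("[64, 4]"))  # (64, 4)
print(chunk_plan(10, (2, 3), max_elements=25))  # [(0, 4), (4, 4), (8, 2)]
assert np.array_equal(read_chunked(fake_getcolnp, 10, (2, 3), np.complex64, 25), column)
```

For a column with row shape (64, 4), each row holds 256 elements, so a single read may cover up to 2**29 // 256 = 2,097,152 rows.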

The second is based on casacore tableiter objects (https://casacore.github.io/python-casacore/casacore_tables.html#casacore.tables.tableiter). In a conversation with @tammojan, this was identified as the expected, best-practice way to iterate over a measurement set. The replacement function is:

import numpy as np


def get_data_iterator(
    MeasurementSet,
    col: str,
):
    """
    Function to perform delayed reads from table columns when converting
    (no need for didxs)
    """

    # Workaround for https://github.com/casacore/python-casacore/issues/130

    # Use casacore to get the shape of a row for this column
    # WARNING: Assumes MeasurementSet is a single measurement set not an MMS
    #################################################################################

    # Get the total number of rows in the base measurement set
    nrows_total = MeasurementSet.nrows()

    try:
        # WARNING: getcolshapestring() only works on columns where a row element is an
        # array (i.e. it raises RuntimeError for scalar columns such as TIME)
        shape_string = MeasurementSet.getcolshapestring(col)[0]
        extra_dimensions = tuple(int(idx) for idx in shape_string.strip("[]").split(", "))
    except RuntimeError:
        extra_dimensions = ()
    full_shape = (nrows_total,) + extra_dimensions

    #################################################################################

    # Get dtype of the column. Only read first row from disk
    col_dtype = np.array(MeasurementSet.col(col)[0]).dtype

    # Construct the numpy array to populate with data
    data = np.empty(full_shape, dtype=col_dtype)

    # Use built-in casacore table iterator to populate the data column by unique times.
    # WARNING: Assumes each TIME step reads no more than 2**29 elements
    # (rows per time step x frequencies x polarisations). If false,
    # https://github.com/casacore/python-casacore/issues/130 isn't mitigated
    start_row = 0
    for ts in MeasurementSet.iter("TIME", sort=False):
        num_rows = ts.nrows()
        # Note don't use getcol() because it's less safe. See:
        # https://github.com/casacore/python-casacore/issues/130#issuecomment-463202373
        ts.getcolnp(col, data[start_row:start_row+num_rows])
        start_row += num_rows

    return data
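
The loop above advances start_row by each group's row count, so it implicitly assumes that iter("TIME", sort=False) returns the TIME groups as contiguous runs of rows in table order. The pure-NumPy sketch below models that assumption; it is not casacore's iterator. It splits a TIME column into runs of equal consecutive values and reports any run whose single getcolnp() read would exceed the 2**29-element limit, which is the situation the WARNING comment in the loop is about. time_runs and oversized_runs are hypothetical helper names.

```python
from math import prod
from typing import List, Tuple

import numpy as np


def time_runs(times: np.ndarray) -> List[Tuple[int, int]]:
    """Split a TIME column into (startrow, nrow) runs of consecutive equal values."""
    if times.size == 0:
        return []
    # A change of value between neighbouring rows starts a new run
    starts = np.concatenate(([0], np.flatnonzero(np.diff(times) != 0) + 1))
    ends = np.append(starts[1:], times.size)
    return [(int(s), int(e - s)) for s, e in zip(starts, ends)]


def oversized_runs(
    times: np.ndarray, row_shape: Tuple[int, ...], max_elements: int = 2**29
) -> List[Tuple[int, int]]:
    """Return the runs whose single read would exceed max_elements elements."""
    row_elements = prod(row_shape)  # prod(()) == 1 for scalar columns
    return [
        (start, nrow)
        for start, nrow in time_runs(times)
        if nrow * row_elements > max_elements
    ]


# Usage: time steps with 3, 3 and 1 rows; row shape (64, 4) = 256 elements per row
times = np.array([1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0])
print(time_runs(times))  # [(0, 3), (3, 3), (6, 1)]
print(oversized_runs(times, (64, 4), max_elements=512))  # [(0, 3), (3, 3)]
```

A non-empty result from oversized_runs marks time steps that would still need splitting into smaller reads.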

The second, iterator-based function is faster (see the attached notebook). I'll commit a patch shortly.

xradio_152_fix_performance.ipynb.zip

sstansill commented 2 months ago

Manual regression testing was performed against the getcol() method. See the attached manual_regression_testing.ipynb.zip.
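
A check like this could also be scripted. The sketch below is an assumption about what such a check might look like, not the contents of the attached notebook: compare_reads (a hypothetical name) verifies that two reads of the same column agree in shape, dtype and values. With python-casacore, the reference array would come from getcol() on a table small enough that the issue should not be triggered, and the candidate from one of the replacement functions above.

```python
import numpy as np


def compare_reads(reference: np.ndarray, candidate: np.ndarray) -> None:
    """Raise AssertionError if two column reads differ in shape, dtype or values."""
    if reference.shape != candidate.shape:
        raise AssertionError(f"shape mismatch: {reference.shape} != {candidate.shape}")
    if reference.dtype != candidate.dtype:
        raise AssertionError(f"dtype mismatch: {reference.dtype} != {candidate.dtype}")
    # assert_array_equal treats NaNs in the same positions as equal
    np.testing.assert_array_equal(reference, candidate)


# Usage with small in-memory arrays standing in for the two reads
reference = np.array([[1 + 1j, np.nan], [3 - 2j, 4]], dtype=np.complex64)
compare_reads(reference, reference.copy())  # passes

corrupted = reference.copy()
corrupted[1, 1] = 0
try:
    compare_reads(reference, corrupted)
except AssertionError:
    print("mismatch detected")
```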