HERA-Team / hera_cal

Library for HERA data reduction, including redundant calibration, absolute calibration, and LST-binning.
MIT License
13 stars 8 forks source link

BUG: test_vis_clean_dpss #955

Open steven-murray opened 1 month ago

steven-murray commented 1 month ago

The test `test_vis_clean_dpss` currently fails intermittently on CI, on roughly one run in four. Here is the error:

self = <hera_cal.tests.test_vis_clean.Test_VisClean object at 0x13572d7c0>

    @pytest.mark.filterwarnings("ignore:.*dspec.vis_filter will soon be deprecated")
    def test_vis_clean_dpss(self):
        # Relax atol=1e-6 for clean_data and data equalities. there may be some numerical
        # issues going on. Notebook tests show that distributing minus signs has
        # consequences.
        fname = os.path.join(DATA_PATH, "zen.2458043.40141.xx.HH.XRAA.uvh5")
        V = VisClean(fname, filetype='uvh5')
        V.read()

        # most coverage is in dspec. Check that args go through here.
        # similar situation for test_vis_clean.
        V.vis_clean(keys=[(24, 25, 'ee'), (24, 25, 'ee')], ax='freq', overwrite=True, mode='dpss_leastsq')
        # check that clean resid is equal to zero in flagged channels
        assert np.all(V.clean_resid[(24, 25, 'ee')][V.clean_flags[(24, 25, 'ee')] | V.flags[(24, 25, 'ee')]] == 0.)
        assert np.any(V.clean_resid[(24, 25, 'ee')][~(V.clean_flags[(24, 25, 'ee')] | V.flags[(24, 25, 'ee')])] != 0.)
        assert np.all(V.clean_model[(24, 25, 'ee')][V.clean_flags[(24, 25, 'ee')]] == 0.)
        assert np.any(V.clean_model[(24, 25, 'ee')][~V.clean_flags[(24, 25, 'ee')]] != 0.)
        # check that filtered_data is the same in channels that were not flagged
        atol = 1e-6 * np.mean(np.abs(V.data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')]]) ** 2.) ** .5
        assert np.all(np.isclose(V.clean_data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')] & ~V.clean_flags[(24, 25, 'ee')]],
                                 V.data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')] & ~V.clean_flags[(24, 25, 'ee')]], atol=atol, rtol=0.))
        assert np.all([V.clean_info[(24, 25, 'ee')][(0, V.Nfreqs)]['status']['axis_1'][i] == 'success' for i in V.clean_info[(24, 25, 'ee')][(0, V.Nfreqs)]['status']['axis_1']])

        assert pytest.raises(AssertionError, V.vis_clean, keys=[(24, 25, 'ee')], ax='time', mode='dpss_leastsq')
        assert pytest.raises(ValueError, V.vis_clean, keys=[(24, 25, 'ee')], ax='time', max_frate='arglebargle', mode='dpss_leastsq')

        # cover no overwrite = False skip lines.
        V.vis_clean(keys=[(24, 25, 'ee'), (24, 25, 'ee')], ax='freq', overwrite=False, mode='dpss_leastsq')

        V.vis_clean(keys=[(24, 25, 'ee'), (24, 25, 'ee')], ax='time', overwrite=True, max_frate=1.0, mode='dpss_leastsq')
        assert V.clean_info[(24, 25, 'ee')][(0, V.Nfreqs)]['status']['axis_0'][0] == 'skipped'
        assert V.clean_info[(24, 25, 'ee')][(0, V.Nfreqs)]['status']['axis_0'][3] == 'success'
        # check that clean resid is equal to zero in flagged channels
        assert np.all(V.clean_resid[(24, 25, 'ee')][V.clean_flags[(24, 25, 'ee')] | V.flags[(24, 25, 'ee')]] == 0.)
        assert np.any(V.clean_resid[(24, 25, 'ee')][~(V.clean_flags[(24, 25, 'ee')] | V.flags[(24, 25, 'ee')])] != 0.)
        assert np.all(V.clean_model[(24, 25, 'ee')][V.clean_flags[(24, 25, 'ee')]] == 0.)
        assert np.any(V.clean_model[(24, 25, 'ee')][~V.clean_flags[(24, 25, 'ee')]] != 0.)
        # check that filtered_data is the same in channels that were not flagged
        atol = 1e-6 * np.mean(np.abs(V.data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')]]) ** 2.) ** .5
        assert np.all(np.isclose(V.clean_data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')] & ~V.clean_flags[(24, 25, 'ee')]],
                                 V.data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')] & ~V.clean_flags[(24, 25, 'ee')]], atol=atol, rtol=0.))
        V.vis_clean(keys=[(24, 25, 'ee'), (24, 25, 'ee')], ax='both', overwrite=True, max_frate=1.0, mode='dpss_leastsq')
        assert np.all(['success' == V.clean_info[(24, 25, 'ee')][(0, V.Nfreqs)]['status']['axis_1'][i] for i in V.clean_info[(24, 25, 'ee')][(0, V.Nfreqs)]['status']['axis_1']])
        # check that clean resid is equal to zero in flagged channels
        assert np.all(V.clean_resid[(24, 25, 'ee')][V.clean_flags[(24, 25, 'ee')] | V.flags[(24, 25, 'ee')]] == 0.)
        assert np.any(V.clean_resid[(24, 25, 'ee')][~(V.clean_flags[(24, 25, 'ee')] | V.flags[(24, 25, 'ee')])] != 0.)
        assert np.all(V.clean_model[(24, 25, 'ee')][V.clean_flags[(24, 25, 'ee')]] == 0.)
        assert np.any(V.clean_model[(24, 25, 'ee')][~V.clean_flags[(24, 25, 'ee')]] != 0.)
        # check that filtered_data is the same in channels that were not flagged
        atol = 1e-6 * np.mean(np.abs(V.data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')]]) ** 2.) ** .5
        assert np.all(np.isclose(V.clean_data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')] & ~V.clean_flags[(24, 25, 'ee')]],
                                 V.data[(24, 25, 'ee')][~V.flags[(24, 25, 'ee')] & ~V.clean_flags[(24, 25, 'ee')]], atol=atol, rtol=0.))
        # run with flag_model_rms_outliers
        for ax in ['freq', 'time', 'both']:
            for k in V.flags:
                V.flags[k][:] = False
                V.data[k][:] = np.random.randn(*V.data[k].shape) + 1j * np.random.randn(*V.data[k].shape)
            # run with rms threshold < 1 which should lead to everything being flagged.
            V.vis_clean(keys=[(24, 25, 'ee'), (24, 25, 'ee')], ax=ax, overwrite=True,
                        max_frate=1.0, mode='dpss_leastsq', flag_model_rms_outliers=True, model_rms_threshold=0.1)
            for k in [(24, 25, 'ee'), (24, 25, 'ee')]:
>               assert np.all(V.clean_flags[k])
E               assert False
E                +  where False = <function all at 0x1043a21b0>(array([[ True,  True,  True, ...,  True,  True,  True],\n       [ True,  True,  True, ...,  True,  True,  True],\n      ...True],\n       [ True,  True,  True, ...,  True,  True,  True],\n       [ True,  True,  True, ...,  True,  True,  True]]))
E                +    where <function all at 0x1043a21b0> = np.all

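This should be fixed: tests that fail at random are distracting and make it hard to tell real regressions from noise. Note that the failing assertion is inside the `flag_model_rms_outliers` loop, where the data are overwritten with unseeded `np.random.randn` draws, so the failure presumably depends on the particular random realization.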