rsagroup / rsatoolbox

Python library for Representational Similarity Analysis
MIT License

Error in _check_noise function within .rdm.calc.py when running crossnobis with residuals #337

Closed alexeperon closed 10 months ago

alexeperon commented 1 year ago

[rsatoolbox 0.1.3, python 3.9.16]

When calculating the crossnobis distance, it is possible to pass noise as a dictionary of residuals. However, this currently results in a recursion and timeout caused by the _check_noise function.

This is because the elif clause checks first to see if the input type is iterable, then if it is a dictionary. As dictionaries are iterable, this then selects the wrong if clause.
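To illustrate the point about branch order, here is a minimal sketch (not the rsatoolbox code itself). The helper `first_matching_branch` and the placeholder dict value are made up for this example; only the order of the `isinstance` checks mirrors `_check_noise`.

```python
from collections.abc import Iterable

# Placeholder dict standing in for the residuals input (not real data).
noise = {0: "precision matrix for run 0"}

# A dict passes both checks, so whichever branch is tested first wins.
print(isinstance(noise, Iterable))  # True
print(isinstance(noise, dict))      # True


def first_matching_branch(x):
    """Hypothetical helper: report which branch an if/elif chain selects."""
    if isinstance(x, Iterable):
        return "Iterable"
    elif isinstance(x, dict):
        return "dict"
    return "other"


print(first_matching_branch(noise))  # Iterable
```

With this ordering the dict branch can never be reached.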

current code (lines 465 to 489):

def _check_noise(noise, n_channel):
    """
    checks that a noise pattern is a matrix with correct dimension
    n_channel x n_channel

    Args:
        noise: noise input to be checked

    Returns:
        noise(np.ndarray): n_channel x n_channel noise precision matrix

    """
    if noise is None:
        pass
    elif isinstance(noise, np.ndarray) and noise.ndim == 2:
        assert np.all(noise.shape == (n_channel, n_channel))
    elif isinstance(noise, Iterable):
        for idx, noise_i in enumerate(noise):
            noise[idx] = _check_noise(noise_i, n_channel)
    elif isinstance(noise, dict):
        for key in noise.keys():
            noise[key] = _check_noise(noise[key], n_channel)
    else:
        raise ValueError('noise(s) must have shape n_channel x n_channel')
    return noise

To solve this, a simple option would be the following, which checks that the input is not a dictionary before treating it as a generic Iterable. Another solution would be to simply put the dict elif clause before the Iterable elif clause.

def _check_noise(noise, n_channel):
    """
    checks that a noise pattern is a matrix with correct dimension
    n_channel x n_channel

    Args:
        noise: noise input to be checked

    Returns:
        noise(np.ndarray): n_channel x n_channel noise precision matrix

    """
    if noise is None:
        pass
    elif isinstance(noise, np.ndarray) and noise.ndim == 2:
        assert np.all(noise.shape == (n_channel, n_channel))
    elif isinstance(noise, Iterable) and not isinstance(noise, dict):
        for idx, noise_i in enumerate(noise):
            noise[idx] = _check_noise(noise_i, n_channel)
    elif isinstance(noise, dict):
        for key in noise.keys():
            noise[key] = _check_noise(noise[key], n_channel)
    else:
        raise ValueError('noise(s) must have shape n_channel x n_channel')
    return noise

[I'm not sure why it results in a recursion instead of a simple error. As I read it, this sends input of the form (int, n_channel) back to the _check_noise function. This should result in a ValueError, but maybe I am missing something about how nested function calls work.]

Code to reproduce the error (assuming data already calculated):


import nibabel as nib
import rsatoolbox
from nilearn.masking import apply_mask  # assumed source of apply_mask

# runs, sub, glm_results_dir, mask and data are assumed to be defined already

residuals = {}  # dict of 8 noise precision matrices, one for each run

ind_noise = 0
for run in runs:

    run_drive = glm_results_dir + sub + '/run-' + run + '/'
    res_file = f'{sub}_run-{run}_residuals.nii.gz'
    res = nib.load(run_drive + res_file)

    # apply the mask to the residuals image
    masked_res = apply_mask(res, mask)

    # estimate the noise precision matrix for the run, with shrinkage
    noise_pres_res = rsatoolbox.data.noise.prec_from_residuals(masked_res, method='shrinkage_diag')

    # store the precision matrix for this run in the dict
    residuals[ind_noise] = noise_pres_res
    ind_noise += 1

rdm_cv = rsatoolbox.rdm.calc_rdm(data, method='crossnobis', descriptor='conds', cv_descriptor='sessions', noise=residuals)

Hope this is useful!

Alex

HeikoSchuett commented 1 year ago

Thanks for picking this error up @alexeperon! I am in favour of changing the order to fix this bug.

I just checked what enumerate actually does when you apply it to a dict: the resulting iterator returns an integer index together with the corresponding key, not the value. Thus the noise input passed to the recursive call will be a key rather than a noise matrix. When that key is a string, which is itself an Iterable, this leads to infinite recursion.
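A quick way to see this, as a sketch with made-up string keys and placeholder values (the actual keys depend on how the dict was built):

```python
from collections.abc import Iterable

# Hypothetical dict with string keys; the values are placeholders.
noise = {"run-1": "matrix 1", "run-2": "matrix 2"}

# enumerate() over a dict pairs a running index with each key.
pairs = list(enumerate(noise))
print(pairs)  # [(0, 'run-1'), (1, 'run-2')]

# The key is a string, and so is each of its characters.
key = pairs[0][1]
first_char = next(iter(key))
print(isinstance(key, Iterable))         # True
print(isinstance(first_char, Iterable))  # True

# Iterating a one-character string yields that same string again,
# so a recursive descent over Iterables never reaches a base case.
print(next(iter(first_char)) == first_char)  # True
```

This only covers string keys. Integer keys are not Iterable, so as far as I can tell they would fall through to the final else branch and raise the ValueError instead.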