rsagroup / rsatoolbox

Python library for Representational Similarity Analysis

Add noise input to get_searchlight_RDMs function for crossnobis and mahalanobis distance measures #356

Open alexeperon opened 1 year ago

alexeperon commented 1 year ago

Hi all,

Currently there is no simple way to pass residuals (or a noise estimate in general) into the get_searchlight_RDMs function. This means there is no way to use the crossnobis or mahalanobis distance with the current searchlight implementation.

I would suggest adding two arguments to the function: "method_cov"=None and "residuals"=None.

The function should first select the relevant channels from both the data and the residuals (if present) for a given searchlight centre. If residuals are provided, the function should use rsatoolbox.data.noise.prec_from_residuals; if only "method_cov" is provided, it should use rsatoolbox.data.noise.prec_from_measurements to estimate the noise covariance, using the value of "method_cov" as the estimation method.

The resulting matrix can then be passed to the calc_rdm function via the "noise" argument; a rough sketch is below.
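
In rough pseudocode, the per-centre noise selection could look something like this (the helper name _noise_for_center is just for illustration, and I'm assuming prec_from_measurements takes the Dataset plus an obs_descriptor name):

from rsatoolbox.data.noise import prec_from_residuals, prec_from_measurements

def _noise_for_center(ds, nb, residuals=None, method_cov='shrinkage_diag'):
    """Illustrative helper: noise precision for one searchlight's channels.

    ds is a Dataset already restricted to the neighbour voxels nb;
    residuals, if given, is n_residual_observations x n_channels (all voxels).
    """
    if residuals is not None:
        # estimate the precision matrix from this searchlight's residuals
        return prec_from_residuals(residuals[:, nb], method=method_cov)
    if method_cov is not None:
        # no residuals: estimate the precision from the measurements themselves
        return prec_from_measurements(ds, 'conds', method=method_cov)
    # noise=None makes calc_rdm fall back to plain (cross-validated) Euclidean distances
    return None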

Let me know how this sounds - I'm happy to have a crack at implementing it if it's useful.

Best wishes, Alex

JasperVanDenBosch commented 1 year ago

Hi Alex,

Thanks for filing this; it makes a lot of sense. A rework of searchlight, and of how we pass noise around, has indeed been on our list for a while, but we probably won't get around to it before the end of the year.

A PR would definitely be welcome, let me know if you have any questions / want to discuss anything.

alexeperon commented 1 year ago

Thanks Jasper. I've written a pretty hacky function which handles crossnobis for searchlight, which I'll share here. It's not hugely efficient as there's no chunking, and I'm not confident enough in my coding to submit it as a PR, but maybe it will come in handy for someone else in the same situation. As a point of reference for timing: on a compute server it runs in around 2 hours for a single subject with 10 GB of memory (4 conditions, 8 runs, about 250,000 voxels in a grey matter mask).

It's messy code, as I'm currently midway through analysis, but I'll come back when I have time and clean it up a little. The n_conds and n_sessions inputs could be avoided by simply counting the unique elements of 'conds' and 'sessions' in obs_desc. As with normal crossnobis, obs_desc should be ordered so that the session and condition tags match the order of the observations in the data itself (see the small example below).
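
For example, with 4 conditions and 8 runs and the rows of data_2d ordered condition-major (all repetitions of cond 0, then cond 1, ...), the descriptors could be built like this (the integer labels are just placeholders):

import numpy as np

n_conds, n_sessions = 4, 8
obs_descriptor = {
    # matches the row order of data_2d: cond 0 in runs 0..7, then cond 1 in runs 0..7, ...
    'conds': np.repeat(np.arange(n_conds), n_sessions),
    'sessions': np.tile(np.arange(n_sessions), n_conds),
}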

While I'm here, thank you for all the hard work you and the team have put into rsatoolbox - it's a superb tool, and has made my life a lot easier (and my analyses more reproducible!).

import numpy as np
from rsatoolbox.data.dataset import Dataset
from rsatoolbox.data.noise import prec_from_residuals, prec_from_measurements
from rsatoolbox.rdm import RDMs
from rsatoolbox.rdm.calc import calc_rdm


def get_searchlight_RDMs_crossvalidated(data_2d, centers, neighbors,
                                        obs_descriptor,
                                        method='crossnobis',
                                        n_conds=4,
                                        n_sessions=8,
                                        verbose=True,
                                        method_cov='shrinkage_eye',
                                        residuals=None):
    """Iterates over all the searchlight centers and calculates the RDM

    Args:

        data_2d (2D numpy array): brain data,
        shape n_observations x n_channels (i.e. voxels/vertices)
        arranged as [cond1, cond1, cond1, ..., cond2, cond2, cond2, ...]

        centers (1D numpy array): center indices for all searchlights as provided
        by rsatoolbox.util.searchlight.get_volume_searchlight

        neighbors (list): list of lists with neighbor voxel indices for all searchlights
        as provided by rsatoolbox.util.searchlight.get_volume_searchlight

        obs_descriptor (dictionary): a dictionary with {'conds': [cond1, cond1, ...cond4, cond4]}
        and {'sessions': [session1, session2, ...]}

        method (str, optional): distance metric,
        see rsatoolbox.rdm.calc for options. Defaults to 'crossnobis'.

        n_conds: number of conditions in the design, used to calculate size of RDM.

        n_sessions: number of sessions in the design, used to calculate number of folds for 
        cross-validation and to index noise covariance matrices.

        verbose (bool, optional): Defaults to True.

        method_cov (str, optional): method to calculate the noise covariance matrix,
        see rsatoolbox.data.noise for details. Options are 'diag', 'shrinkage_eye',
        'shrinkage_diag'. Defaults to 'shrinkage_eye'.

        residuals (2D numpy array, optional): residuals of the data,
        shape n_observations x n_channels. If present, used to calculate the
        covariance matrix. Defaults to None.

    Returns:
        RDM [rsatoolbox.rdm.RDMs]: RDMs object with the RDM for each searchlight
                              the RDM.rdm_descriptors['voxel_index']
                              describes the center voxel index each RDM is associated with
    """

    data_2d, centers = np.array(data_2d), np.array(centers)  # in my case, data is 32 observations x voxels
    n_centers = centers.shape[0]  # number of searchlight centers
    # iterate over all searchlight centers
    # this is slow, and there's definitely a better way to do this
    RDM = np.zeros((n_centers, n_conds * (n_conds - 1) // 2))  # one row of pairwise dissimilarities per center
    for c in range(n_centers):
        if verbose:
            print('Calculating RDM for searchlight center %d/%d' % (c + 1, n_centers))
        # grab this center and neighbors
        center = centers[c]
        nb = neighbors[c]
        # create a database object with this data
        ds = Dataset(data_2d[:, nb],
                        descriptors={'center': c},
                        obs_descriptors=obs_descriptor,
                        channel_descriptors={'voxels': nb})
        # estimate a noise precision matrix for this searchlight's channels
        if residuals is not None:
            # residuals are n_sessions (8) x n_observations x n_channels;
            # grab the residuals for this center's neighbours and estimate
            # one precision matrix per run
            cov_dict = {}
            residuals_c = residuals[:, :, nb]
            for run in range(n_sessions):
                cov_dict[run] = prec_from_residuals(residuals_c[run, :, :],
                                                    method=method_cov)
                if verbose:
                    print('covariance matrix for run', run,
                          'shape:', cov_dict[run].shape)
        else:
            # no residuals: estimate a single precision matrix from the measurements
            cov_dict = prec_from_measurements(ds, 'conds', method=method_cov)

        # compute the cross-validated RDM for this searchlight, using the noise estimate
        RDM_corr = calc_rdm(ds, method=method,
                                descriptor='conds', 
                                cv_descriptor='sessions', 
                                noise=cov_dict).dissimilarities
        RDM[c] = RDM_corr

    SL_rdms = RDMs(np.array(RDM),
                   rdm_descriptors={'voxel_index': centers},
                   dissimilarity_measure=method)

    return SL_rdms
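
For reference, a call would look roughly like this; mask, data_2d and residuals are assumed to be loaded already, obs_descriptor is built as in the example above, and the radius/threshold values are just placeholders:

from rsatoolbox.util.searchlight import get_volume_searchlight

# mask: 3D boolean brain mask; data_2d: n_observations x n_voxels betas;
# residuals: n_sessions x n_residual_observations x n_voxels (all preloaded)
centers, neighbors = get_volume_searchlight(mask, radius=5, threshold=0.5)

sl_rdms = get_searchlight_RDMs_crossvalidated(
    data_2d, centers, neighbors, obs_descriptor,
    method='crossnobis', n_conds=4, n_sessions=8,
    method_cov='shrinkage_eye', residuals=residuals)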

JasperVanDenBosch commented 1 year ago

That's excellent - the performance sounds good actually - thanks for sharing! I'll leave this here for future reference. Do let us know if there's anything else we can improve on.