int-brain-lab / iblatlas

IBL atlas module
MIT License
1 stars 1 forks source link

Annotate plot_swanson_vector such that Beryl region labels are printed over all of their lower-level Allen subregions #11

Open mschart opened 4 months ago

mschart commented 4 months ago

Some regions are not annotated in plot_swanson_vector, in iblatlas/plots.py.

plot_swanson_vector(np.array(['MRN', 'IRN']),np.array([0.5,0.5]), annotate=True) works nicely, both regions get an acronym printed on top.

plot_swanson_vector(np.array(['SCig', 'SCiw']),np.array([0.5,0.5]), annotate=True) doesn't, although these regions show up when doing plot_swanson_vector(annotate=True).

Beryl regions only sometimes show up; for example, they do not in this case: plot_swanson_vector(np.array(['SCm', 'SCs']),np.array([0.5,0.5]), annotate=True)

image

It would be nice if, for example in the latter case — at the Beryl hierarchy level, which most of us use — the SCm and SCs labels were printed on top of their multiple subregions (e.g. 'SCig' and 'SCiw' are both part of 'SCm').

mschart commented 2 months ago

Another oddity with plot_swanson_vector: for each analysis I print the 5 top values as text next to the Swanson and also use its annotate argument. The two mostly match, but not always. For example, the third Swanson shows SOC instead of GRN, which it should show given an argsort of the scores used to index the acronym strings:

image

mschart commented 2 months ago
def _swanson_inputs(res, ana, dt, lat):
    '''
    Select the regions, scores and grey mask to plot for one result type.

    Parameters
    ----------
    res : pandas.DataFrame
        Results table with a 'region' column and, per analysis ``ana``,
        '{ana}_significant' and '{ana}_{dt}' columns ('glm_effect' for glm).
    ana : str
        Analysis name ('decoding', 'mannwhitney', 'euclidean', 'glm').
    dt : str
        'effect' or 'latency'.
    lat : bool
        True when plotting latencies.

    Returns
    -------
    (acronyms, scores, mask)
        Regions whose score is NaN are dropped from acronyms/scores. They
        cannot be coloured, and they would corrupt the colorbar limits and
        the top-5 ranking: np.argsort sorts NaN last, so they would be
        reported as the largest effects. In the latency case they are
        already listed in the mask, which previously made the
        "n significant / n total" counter double-count them.
    '''
    if ana == 'glm':
        # GLM results have no significance flag: every region is plotted
        acronyms = res['region'].values
        scores = res[f'{ana}_effect'].values
        mask = []
    else:
        col = f'{ana}_{dt}'
        # explicit comparisons: a missing flag (NaN) is neither True nor False
        sig = res[f'{ana}_significant'] == True  # noqa: E712
        not_sig = res[f'{ana}_significant'] == False  # noqa: E712
        nan_score = np.isnan(res[col])

        acronyms = res[sig]['region'].values
        scores = res[sig][col].values

        if lat:
            # grey out non-significant regions and those without a latency
            mask = res[not_sig | nan_score]['region'].values
        else:
            # grey out non-significant regions; NaN amplitudes mean the
            # region was not analysed, so it is left out of the mask
            mask = res[not_sig & ~nan_score]['region'].values

    keep = ~np.isnan(np.asarray(scores, dtype=float))
    return acronyms[keep], scores[keep], mask


def plot_swansons(variable, fig=None, axs=None):
    '''
    For a single variable, plot 5 results Swansons:
    4 effects for the 4 analyses and latencies for manifold.

    Parameters
    ----------
    variable : str
        Variable name. Results are read from ``meta_pth / f"{variable}.pkl"``.
        When the figure is created here, it is saved to
        ``imgs_pth / variable / 'swansons.svg'``.
    fig : matplotlib Figure, optional
        Figure to draw into. If omitted, a new one-row figure is created
        (one column per result type) and saved at the end.
    axs : sequence of Axes, optional
        One axis per result type; required when ``fig`` is given.

    Raises
    ------
    ValueError
        If ``fig`` is given without ``axs``.
    '''
    res = pd.read_pickle(meta_pth / f"{variable}.pkl")

    lw = 0.1  # .01

    # result column -> [colorbar label, (unused), [row title, row subtitle]]
    res_types = {'decoding_effect': ['Decoding. $R^2$ over null', [],
                    ['Decoding', 'Regularized logistic regression']],
                 'mannwhitney_effect': ['Frac. sig. cells', [],
                    ['Single cell statistics', 'C.C Mann-Whitney test']],
                 'euclidean_effect': ['Nrml. Eucl. dist. (log)', [],
                    ['Manifold', 'Distance between trajectories']],
                 'euclidean_latency': ['Latency of dist. (sec)', [],
                    ['Manifold', 'Time near peak']],
                 'glm_effect': ['Abs. diff. $\\Delta R^2$ (log)', [],
                    ['Encoding', 'General linear model']]}

    cmap = get_cmap_(variable)

    alone = fig is None
    if alone:
        fig = plt.figure(figsize=(8, 3.34), layout='constrained')
        gs = gridspec.GridSpec(1, len(res_types),
                               figure=fig, hspace=.75)
        axs = []
    elif axs is None:
        raise ValueError('axs must be given when fig is given')

    for k, (res_type, labels) in enumerate(res_types.items()):
        if alone:
            axs.append(fig.add_subplot(gs[0, k]))
        ax = axs[k]
        cbar_label, _, (title, subtitle) = labels

        ana = res_type.split('_')[0]
        lat = 'latency' in res_type
        dt = 'latency' if lat else 'effect'
        # latencies use the reversed colormap and highlight the smallest values
        res_cmap = cmap.reversed() if lat else cmap

        acronyms, scores, mask = _swanson_inputs(res, ana, dt, lat)

        plot_swanson_vector(acronyms,
                            scores,
                            hemisphere=None,
                            orientation='portrait',
                            cmap=res_cmap,
                            br=br,
                            ax=ax,
                            empty_color="white",
                            linewidth=lw,
                            mask=mask,
                            mask_color='silver',
                            annotate=True,
                            annotate_n=5,
                            annotate_order='bottom' if lat else 'top',
                            fontsize=f_size_s)

        # a colorbar needs at least one finite score to set its limits
        if len(scores):
            norm = mpl.colors.Normalize(vmin=scores.min(), vmax=scores.max())
            cbar = fig.colorbar(
                mpl.cm.ScalarMappable(norm=norm, cmap=res_cmap),
                ax=ax, shrink=0.4, aspect=12, pad=.025,
                orientation="horizontal")

            ticks = np.round(np.linspace(cbar.vmin, cbar.vmax, num=3), 2)
            cbar.set_ticks(ticks)
            cbar.outline.set_visible(False)
            cbar.ax.xaxis.set_tick_params(pad=5)
            cbar.set_label(cbar_label, fontsize=f_size_s)
            cbar.ax.tick_params(labelsize=f_size_s)

        ax.text(-0.25, 0.5, title,
                fontsize=f_size, ha='center', va='center',
                rotation='vertical',
                transform=ax.transAxes)
        ax.text(-0.1, .5, subtitle,
                fontsize=f_size_s, ha='center', va='center',
                rotation='vertical',
                transform=ax.transAxes)
        ax.text(0.85, 0.95, f' {len(scores)}/{len(scores) + len(mask)}',
                fontsize=f_size_s, ha='center',
                transform=ax.transAxes)

        # print regions with largest (smallest) amp (lat) scores;
        # stable sort so tied scores keep the results-table order
        order = np.argsort(scores, kind='stable')
        exxregs = acronyms[order][:5] if lat else acronyms[order][-5:]

        print('highlight regs')
        print(exxregs)
        for i, text in enumerate(exxregs[:3]):
            ax.text(-0.2, 0.2 - i * 0.05,
                    text, fontsize=f_size_s,
                    ha='left', va='top',
                    transform=ax.transAxes)

        for i, text in enumerate(exxregs[3:]):
            ax.text(0.85, 0.1 - i * 0.05,
                    text, fontsize=f_size_s,
                    ha='left', va='top',
                    transform=ax.transAxes)

        ax.axis("off")
        ax.invert_xaxis()

    if alone:
        fig.savefig(Path(imgs_pth, variable, 'swansons.svg'))