angelolab / ark-analysis

Integrated pipeline for multiplexed image analysis
https://ark-analysis.readthedocs.io/en/latest/
MIT License
69 stars 25 forks source link

Visualize all markers and cluster names in cell cluster heatmap #1146

Open janinemelsen opened 2 weeks ago

janinemelsen commented 2 weeks ago

Hi! In the cell clustering notebook I am trying to make a heatmap, but not all columns and rows are assigned a label:

image

Is there a way to adjust for instance the width and height of the heatmap, in order to visualize all labels?

Thanks!

Janine

cliu72 commented 1 week ago

Hi @janinemelsen! Thanks for bringing this to our attention. It seems like the default behavior of seaborn is to only show a subset of labels when it thinks there are too many labels. To fix this, you can add the yticklabels=True argument to the seaborn function.

Here is a temporary fix (all I did was take our existing functions at https://github.com/angelolab/ark-analysis/blob/main/src/ark/phenotyping/weighted_channel_comp.py#L414-L498 and https://github.com/angelolab/ark-analysis/blob/main/src/ark/analysis/visualize.py#L72-L153, and add the yticklabels=True argument to sns.clustermap). You should be able to copy and paste this code into your Jupyter notebook and run it:

import scipy.stats as stats
import numpy as np
import seaborn as sns
import matplotlib.patches as patches

def temp_draw_heatmap(data, x_labels, y_labels, dpi=None, center_val=None, min_val=None,
                      max_val=None, cbar_ticks=None, colormap="vlag", row_colors=None,
                      row_cluster=True, col_colors=None, col_cluster=True, left_start=None,
                      right_start=None, w_spacing=None, h_spacing=None, save_dir=None,
                      save_file=None):
    """Plots the given values as a seaborn clustermap with every row and column labeled.

    The input array is never modified: NaN and infinite entries are replaced with 0
    on an internal copy before plotting.

    Args:
        data (numpy.ndarray):
            The 2D data array to visualize, one row per entry of `x_labels` and
            one column per entry of `y_labels`
        x_labels (list):
            Row labels: used as the index of the plotted DataFrame, so they are
            displayed along the vertical axis of the heatmap
        y_labels (list):
            Column labels: used as the columns of the plotted DataFrame, so they are
            displayed along the horizontal axis of the heatmap
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None
        center_val (float):
            value at which to center the heatmap
        min_val (float):
            minimum value the heatmap should take
        max_val (float):
            maximum value the heatmap should take
        cbar_ticks (list):
            list of values containing tick labels for the heatmap colorbar
        colormap (str):
            color scheme for visualization
        row_colors (list):
            Include these values as an additional color-coded cluster bar for row values
        row_cluster (bool):
            Whether to include dendrogram clustering for the rows
        col_colors (list):
            Include these values as an additional color-coded cluster bar for column values
        col_cluster (bool):
            Whether to include dendrogram clustering for the columns
        left_start (float):
            The position to set the left edge of the figure to (from 0-1)
        right_start (float):
            The position to set the right edge of the figure to (from 0-1)
        w_spacing (float):
            The amount of spacing to put between the subplots width-wise (from 0-1)
        h_spacing (float):
            The amount of spacing to put between the subplots height-wise (from 0-1)
        save_dir (str):
            If specified, a directory where we will save the plot
        save_file (str):
            If save_dir specified, specify a file name you wish to save to.
            Ignored if save_dir is None
    """

    # Copy first so the caller's array is left untouched, then replace the
    # NaN and +/-inf values with 0s
    data = np.array(data)
    data[~np.isfinite(data)] = 0

    # Assign numpy values respective phenotype labels
    data_df = pd.DataFrame(data, index=x_labels, columns=y_labels)
    sns.set(font_scale=.7)

    # xticklabels/yticklabels=True force seaborn to draw every label instead of
    # thinning them out when it decides there are too many to fit
    heatmap = sns.clustermap(
        data_df, cmap=colormap, center=center_val,
        vmin=min_val, vmax=max_val, row_colors=row_colors, row_cluster=row_cluster,
        col_colors=col_colors, col_cluster=col_cluster,
        cbar_kws={'ticks': cbar_ticks},
        xticklabels=True,
        yticklabels=True
    )

    # ensure the row/column color axes don't have a label attached to them
    if row_colors is not None:
        heatmap.ax_row_colors.xaxis.set_visible(False)

    if col_colors is not None:
        heatmap.ax_col_colors.yaxis.set_visible(False)

    # update the figure dimensions to accommodate Jupyter widget backend
    heatmap.gs.update(
        left=left_start, right=right_start, wspace=w_spacing, hspace=h_spacing
    )

    # ensure the y-axis labels are horizontal, will be misaligned if vertical
    plt.setp(heatmap.ax_heatmap.get_yticklabels(), rotation=0)

    plt.tight_layout()

    if save_dir is not None:
        misc_utils.save_figure(save_dir, save_file, dpi=dpi)

def temp_generate_weighted_channel_avg_heatmap(cell_cluster_channel_avg_path, cell_cluster_col,
                                          channels, raw_cmap, renamed_cmap,
                                          center_val=0, min_val=-3, max_val=3):
    """Draws a z-scored heatmap of average weighted channel expression per cell cluster.

    The rows are ordered by the 'cell_meta_cluster_rename' column, a side color bar is
    built from `renamed_cmap`, and a 'Meta cluster' legend is added to the current figure.

    Args:
        cell_cluster_channel_avg_path (str):
            Path to the CSV holding the average weighted channel expression per cell cluster
        cell_cluster_col (str):
            Name of the cell cluster column to label the rows with,
            must be either 'cell_som_cluster' or 'cell_meta_cluster_rename'
        channels (list):
            The channel columns to visualize
        raw_cmap (dict):
            Maps the raw meta cluster labels to their respective colors,
            created by `generate_meta_cluster_colormap_dict`; supplies the legend patch colors
        renamed_cmap (dict):
            Maps the renamed meta cluster labels to their respective colors,
            created by `generate_meta_cluster_colormap_dict`; supplies the row colors
            and the legend labels
        center_val (float):
            value at which to center the heatmap
        min_val (float):
            minimum value the heatmap should take
        max_val (float):
            maximum value the heatmap should take
    """

    # the averages file has to exist before anything else is attempted
    io_utils.validate_paths([cell_cluster_channel_avg_path])

    # only the SOM or renamed meta cluster columns are accepted
    allowed_cluster_cols = ['cell_som_cluster', 'cell_meta_cluster_rename']
    misc_utils.verify_in_list(
        provided_cluster_col=[cell_cluster_col],
        valid_cluster_cols=allowed_cluster_cols
    )

    # load the per-cluster averages and check every requested channel is a column
    avg_table = pd.read_csv(cell_cluster_channel_avg_path)
    misc_utils.verify_in_list(
        provided_channels=channels,
        channel_avg_cols=avg_table.columns.values
    )

    # order rows by renamed meta cluster so equal clusters sit next to each other
    # in the side color bar
    avg_table = avg_table.sort_values(by='cell_meta_cluster_rename')

    # per-row colors, re-indexed by the displayed cluster labels so they line up
    # with the heatmap rows
    side_colors = avg_table['cell_meta_cluster_rename'].map(renamed_cmap)
    side_colors.index = avg_table[cell_cluster_col].values

    # z-score the selected channel columns and plot them
    zscored_expression = stats.zscore(avg_table[channels].values)
    temp_draw_heatmap(
        data=zscored_expression,
        x_labels=avg_table[cell_cluster_col],
        y_labels=channels,
        center_val=center_val,
        min_val=min_val,
        max_val=max_val,
        cbar_ticks=np.arange(-3, 4),
        row_colors=side_colors,
        row_cluster=False,
        left_start=0.0,
        right_start=0.85,
        w_spacing=0.2,
        colormap='vlag'
    )

    # one legend patch per raw meta cluster, labeled with the renamed clusters
    legend_patches = []
    for cluster_key in raw_cmap:
        legend_patches.append(patches.Patch(facecolor=raw_cmap[cluster_key]))

    current_fig = plt.gcf()
    _ = plt.legend(
        legend_patches,
        renamed_cmap,
        title='Meta cluster',
        bbox_to_anchor=(1, 1),
        bbox_transform=current_fig.transFigure,
        loc='upper right'
    )
# Draw the heatmap for the renamed meta clusters. `base_dir`,
# `cell_meta_cluster_channel_avg_name`, `channels`, `raw_cmap` and `renamed_cmap`
# are expected to be defined earlier in the notebook session.
temp_generate_weighted_channel_avg_heatmap(
    cell_cluster_channel_avg_path=os.path.join(base_dir, cell_meta_cluster_channel_avg_name),
    cell_cluster_col='cell_meta_cluster_rename',
    channels=channels,
    raw_cmap=raw_cmap,
    renamed_cmap=renamed_cmap
)

Let me know if this works!

cliu72 commented 1 week ago

@alex-l-kong can we add yticklabels=True as a permanent fix to sns.clustermap in visualize.draw_heatmap? I think it makes sense for that to be the default behavior.