YosefLab / Cassiopeia

A Package for Cas9-Enabled Single Cell Lineage Tracing Tree Reconstruction
https://cassiopeia-lineage.readthedocs.io/en/latest/
MIT License

Show legend for colorstrip when using cassiopeia.pl.plot_matplotlib #200

Closed: YushaLiu closed this issue 1 year ago

YushaLiu commented 1 year ago

Hi Matt, I have a quick question about showing a legend for the colorstrip when using cassiopeia.pl.plot_matplotlib. For example, when the colorstrip represents the organ locations where the cells in the tree were collected, I would like the plot produced by cassiopeia.pl.plot_matplotlib to include a legend showing which color corresponds to which location. Is it possible to do so? I tried something like: cas.pl.plot_matplotlib(cas_tree, add_root=True, meta_data=['sampleID'], colorstrip_kwargs=dict(showlegend=True)) but it doesn't work and returns an error. Thanks very much!

mattjones315 commented 1 year ago

Hi @YushaLiu --

Thanks for raising this issue! Currently we do not have functionality to plot a legend alongside the colorstrips, because you might pass in several colorstrips, in which case a single legend would be ambiguous. I'm not sure what the best fix for this would be. However, if you only have one colorstrip, you should be able to add a legend to the existing plot yourself. Here are two examples, one with a colorbar (continuous data) and one with a legend (categorical data):
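On the ambiguity point: with more than one colorstrip, one option is to draw a separate, titled legend per strip. In plain matplotlib a second ax.legend call replaces the first, so the first legend has to be re-attached with ax.add_artist. Below is a minimal sketch of that pattern, not a Cassiopeia feature. The helper strip_legend and the 'organ' and 'mouse' values are hypothetical, and colors are assigned in first-appearance order, which is an assumption you would need to align with how the plot actually colors each strip.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Patch


def strip_legend(ax, values, title, cmap_name, loc):
    """Hypothetical helper: one titled legend for one colorstrip.

    Categories are indexed in first-appearance order (an assumption).
    """
    cmap = mpl.colormaps[cmap_name]
    handles = [
        Patch(facecolor=cmap(i), label=str(v))
        for i, v in enumerate(dict.fromkeys(values))
    ]
    return ax.legend(handles=handles, title=title, loc=loc)


fig, ax = plt.subplots()

# Hypothetical metadata for two colorstrips
organ_legend = strip_legend(ax, ["liver", "lung", "liver"], "organ", "tab10", "upper left")
# Re-attach the first legend; otherwise the next ax.legend call replaces it
ax.add_artist(organ_legend)
mouse_legend = strip_legend(ax, ["m1", "m2"], "mouse", "Set2", "upper right")
```

Titling each legend after its metadata column is what removes the ambiguity: the reader can tell which strip each color key belongs to.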

Continuous data colorbar

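import cassiopeia as cas
import matplotlib as mpl
import matplotlib.pyplot as plt

# `tree` (a CassiopeiaTree with a numeric 'TS-UMI' column in cell_meta)
# and `priors` (indel priors) are assumed to be defined already.
min_value = tree.cell_meta['TS-UMI'].min()
max_value = tree.cell_meta['TS-UMI'].max()

# Draw the tree with a single continuous colorstrip for 'TS-UMI'
cas.pl.plot_matplotlib(
    tree,
    add_root=True,
    indel_priors=priors,
    continuous_cmap='Reds',
    meta_data=['TS-UMI'],
    colorstrip_width=2
)
ax = plt.gca()
ax.set_aspect(1.0)

# Colorbar from a ScalarMappable with the strip's colormap and the data range
norm_range = mpl.colors.Normalize(vmin=min_value, vmax=max_value, clip=False)
plt.colorbar(mpl.cm.ScalarMappable(norm=norm_range, cmap='Reds'), ax=ax, label='Target site UMIs')
plt.show()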
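To check the colorbar pattern without a Cassiopeia tree, here is a minimal plain-matplotlib sketch on hypothetical values: a row of bars stands in for the colorstrip, and the colorbar is built from a ScalarMappable that shares the bars' Normalize and colormap. The values array is made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical per-leaf values standing in for tree.cell_meta['TS-UMI']
values = np.array([3, 10, 25, 7, 18])

norm = mpl.colors.Normalize(vmin=values.min(), vmax=values.max())
cmap = mpl.colormaps["Reds"]

fig, ax = plt.subplots()
# Stand-in for the colorstrip: one bar per leaf, colored through the same norm/cmap
ax.bar(range(len(values)), 1, color=cmap(norm(values)))

# The colorbar only needs a ScalarMappable sharing that norm and cmap
mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
cbar = fig.colorbar(mappable, ax=ax, label="Target site UMIs")
```

The colorbar never inspects the bars; it is correct only because both use the same Normalize and colormap, which is why the snippet above computes min_value and max_value from the same column it plots.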


Categorical data legend

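import cassiopeia as cas
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# `tree` (a CassiopeiaTree with a categorical 'cluster' column in cell_meta)
# and `priors` (indel priors) are assumed to be defined already.
cm = mpl.colormaps['tab10']
unique_values = set(tree.cell_meta['cluster'].values)
# Map each category to an integer index into the colormap.
# Assumes the plot assigns colors in the same set() order; if your version of
# plot_matplotlib accepts a value_mapping argument, passing it makes this explicit.
value_mapping = {val: i for i, val in enumerate(unique_values)}

cas.pl.plot_matplotlib(
    tree,
    add_root=True,
    indel_priors=priors,
    categorical_cmap='tab10',
    meta_data=['cluster'],
    colorstrip_width=2
)
ax = plt.gca()
ax.set_aspect(1.0)

# One square marker per category, colored with that category's colormap index
handles = []
for group, index in value_mapping.items():
    color = cm(index)[:-1]  # drop the alpha channel
    point = Line2D([0], [0], label=group, marker='s', markersize=10,
                   markeredgecolor='b', markerfacecolor=color, linestyle='')
    handles.append(point)

# Legend from the manual handles; each handle carries its own label
plt.legend(handles=handles)
plt.show()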
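The legend-building loop can also be factored into a small reusable function and tested without a tree. This is a minimal sketch: categorical_legend_handles is a hypothetical helper, the cluster labels are made up, and it builds the mapping in first-appearance order with dict.fromkeys instead of set() so the legend order is reproducible. When used with cas.pl.plot_matplotlib, the mapping must still match the color assignment the plot itself uses.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def categorical_legend_handles(value_mapping, cmap_name="tab10"):
    """Hypothetical helper: one square marker per category, colored cmap(index)."""
    cmap = mpl.colormaps[cmap_name]
    handles = []
    for value, index in value_mapping.items():
        color = cmap(index)[:-1]  # drop alpha, as in the example above
        handles.append(
            Line2D([0], [0], label=str(value), marker="s", markersize=10,
                   markerfacecolor=color, markeredgecolor="none", linestyle="")
        )
    return handles


# Hypothetical cluster labels standing in for tree.cell_meta['cluster']
clusters = ["liver", "lung", "liver", "brain"]
# First-appearance order: liver -> 0, lung -> 1, brain -> 2
value_mapping = {val: i for i, val in enumerate(dict.fromkeys(clusters))}

fig, ax = plt.subplots()
legend = ax.legend(handles=categorical_legend_handles(value_mapping), title="cluster")
```

Line2D with linestyle='' draws only the marker, so each legend entry shows a colored square rather than a line segment.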


Hope this helps!

Best, Matt