scverse / scanpy

Single-cell analysis in Python. Scales to >1M cells.
https://scanpy.readthedocs.io
BSD 3-Clause "New" or "Revised" License

How can I delete or rotate the legend of `sc.pl.stacked_violin`? #2530

Open danli349 opened 1 year ago

danli349 commented 1 year ago

Hello:

How can I delete, rotate, or change the position of the legend of `sc.pl.stacked_violin`?

[image]

Thanks

flying-sheep commented 1 year ago

We have plot classes that are shared by all of these plots. The `stacked_violin` function is little more than this:

https://github.com/scverse/scanpy/blob/ed3b277b2f498e3cab04c9416aaddf97eec8c3e2/scanpy/plotting/_stacked_violin.py#L679-L723

And the colorbar plotting is defined here:

https://github.com/scverse/scanpy/blob/ed3b277b2f498e3cab04c9416aaddf97eec8c3e2/scanpy/plotting/_baseplot_class.py#L493-L520
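The structure described here, where a thin plotting function builds a plot object whose drawing steps are overridable methods, can be illustrated with toy classes. `ToyBasePlot`, `ToyViolin`, `MyToyViolin`, and `toy_violin` are hypothetical names, not scanpy's code; the sketch only shows why overriding one hook method in a subclass changes what the shared `make_figure` draws.

```python
class ToyBasePlot:
    """Hypothetical stand-in for a base class shared by several plot types."""

    def make_figure(self):
        drawn = ["main plot"]
        # the legend is delegated to overridable hook methods
        drawn.extend(self._plot_legend())
        return drawn

    def _plot_legend(self):
        return self._plot_colorbar()

    def _plot_colorbar(self):
        return ["horizontal colorbar"]


class ToyViolin(ToyBasePlot):
    """Stands in for a concrete plot class that inherits the legend hooks."""


class MyToyViolin(ToyViolin):
    # overriding a single hook changes the legend of every figure it makes
    def _plot_colorbar(self):
        return ["vertical colorbar"]


def toy_violin(plot_class=ToyViolin):
    # the plotting *function* only builds the object and asks it to draw itself
    plot = plot_class()
    return plot.make_figure()
```

Calling `toy_violin()` draws the default legend, while `toy_violin(MyToyViolin)` swaps in the overridden hook.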

So currently you'd have to do something like this, but I agree, this should be easier:

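import scanpy as sc
from matplotlib import colormaps
from matplotlib.colorbar import Colorbar
from matplotlib.cm import ScalarMappable

class MyStackedViolin(sc.pl.StackedViolin):
    # draw the colorbar vertically instead of the default horizontal orientation
    def _plot_colorbar(self, color_legend_ax, normalize):
        mappable = ScalarMappable(norm=normalize, cmap=colormaps[self.cmap])
        Colorbar(color_legend_ax, mappable=mappable, orientation='vertical')

# instantiate the subclass, not StackedViolin itself; pass further options as keyword arguments
vp = MyStackedViolin(adata, var_names, groupby)
vp.make_figure()
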
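To see what the `orientation` argument changes in isolation, here is a matplotlib-only sketch built from the same `ScalarMappable`/`Colorbar` ingredients as the override. `draw_colorbar` and its default values are illustrative assumptions; in the real override the norm comes from scanpy and the colormap from `self.cmap`.

```python
from matplotlib import colormaps
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import Colorbar
from matplotlib.colors import Normalize
from matplotlib.figure import Figure


def draw_colorbar(orientation, cmap_name="coolwarm", vmin=0.0, vmax=1.0):
    """Draw a standalone colorbar into its own axes (hypothetical helper)."""
    fig = Figure(figsize=(2, 3))
    legend_ax = fig.add_subplot()
    # same ingredients as _plot_colorbar: a norm plus a colormap looked up by name
    mappable = ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap=colormaps[cmap_name])
    colorbar = Colorbar(legend_ax, mappable=mappable, orientation=orientation)
    return fig, colorbar
```

`draw_colorbar("vertical")` and `draw_colorbar("horizontal")` differ only in which side of the axes carries the color gradient and tick labels.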
danli349 commented 1 year ago

@flying-sheep Thanks for helping!
I tried your code, but it does not work as expected.

import scanpy as sc
from matplotlib import colormaps
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import Colorbar

class MyStackedViolin(sc.pl.StackedViolin):
    def _plot_colorbar(self, color_legend_ax, normalize):
        mappable = ScalarMappable(norm=normalize, cmap=colormaps[self.cmap])
        Colorbar(color_legend_ax, mappable=mappable, orientation='vertical')

vp = MyStackedViolin(adata_luminals, L1_signature,
                     groupby='leiden_r1',
                     cmap='coolwarm',
                     swap_axes=True)
vp.make_figure()

[image]

flying-sheep commented 1 year ago

I see, there's also code that produces that exact shape. It seems you need to override this method as well:

https://github.com/scverse/scanpy/blob/ed3b277b2f498e3cab04c9416aaddf97eec8c3e2/scanpy/plotting/_baseplot_class.py#L522-L542

Maybe simply:

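# method to add to MyStackedViolin: draw the colorbar into the whole legend area
def _plot_legend(self, legend_ax, return_ax_dict, normalize):
    self._plot_colorbar(legend_ax, normalize)
    # register the axes holding the colorbar in the returned axes dict
    return_ax_dict['color_legend_ax'] = legend_ax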

But as I said, we will start working on a more flexible and less fiddly plotting API.

danli349 commented 1 year ago

@flying-sheep Thanks! Now it works better, but the size is out of control.

[image]

When I combine several plots using `plt.subplots`:

import matplotlib.pyplot as plt
import scanpy as sc
from matplotlib.pyplot import rc_context

samples = adata_luminals.obs['sample'].cat.categories
with rc_context({'figure.figsize': (3, 3)}):
    fig, axs = plt.subplots(1, 4, figsize=(12, 4))
    # one stacked violin per sample (the first four), each drawn into its own subplot
    for sample, ax in zip(samples, axs):
        sc.pl.stacked_violin(adata_luminals[adata_luminals.obs['sample'] == sample],
                             var_names=gene, use_raw=True, ax=ax,
                             groupby='leiden_r1',
                             cmap='coolwarm', dendrogram=False,
                             swap_axes=True, stripplot=False,
                             title=sample, show=False)
#fig.delaxes(fig.axes[11])
#fig.delaxes(fig.axes[18])
#fig.delaxes(fig.axes[25])
#fig.delaxes(fig.axes[32])
plt.draw()

[image]

I would like to omit the colorbars between the plots and keep only the last one. Could you please show me how to delete them? Thanks
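Deleting an axes is plain matplotlib: `Axes.remove()`. With `show=False`, `sc.pl.stacked_violin` returns a dict of axes; assuming the colorbar sits under the key `'color_legend_ax'` (the key used in the `_plot_legend` code earlier in this thread), calling `.remove()` on that entry for every panel except the last should work, though that key is an assumption rather than documented API. The sketch below only mimics the layout with matplotlib; `panels_with_one_colorbar` is a hypothetical helper.

```python
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from matplotlib.figure import Figure


def panels_with_one_colorbar(n_panels=4):
    """Build n side-by-side panels, each with a colorbar, then keep only the last one."""
    fig = Figure(figsize=(3 * n_panels, 4))
    # each panel gets a wide main column and a narrow legend column
    gs = fig.add_gridspec(1, 2 * n_panels, width_ratios=[4, 1] * n_panels)
    legend_axes = []
    for i in range(n_panels):
        main_ax = fig.add_subplot(gs[0, 2 * i])
        main_ax.set_title(f"panel {i}")
        legend_ax = fig.add_subplot(gs[0, 2 * i + 1])
        fig.colorbar(ScalarMappable(norm=Normalize(0, 1), cmap="coolwarm"), cax=legend_ax)
        legend_axes.append(legend_ax)
    # remove every colorbar axes except the right-most one
    for legend_ax in legend_axes[:-1]:
        legend_ax.remove()
    return fig
```

Removing the axes leaves its grid column empty; the neighboring panels do not grow into the freed space.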

flying-sheep commented 1 year ago

That should simply be `groupby='sample'` instead of multiple plot calls.

danli349 commented 1 year ago

Yes, but I would like to separate by both `sample` and `leiden_r1`, so I should create a new variable `adata.obs['leiden+sample']`.

flying-sheep commented 1 year ago

You can just use `groupby=['leiden_r1', 'sample']`.
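If you prefer an explicit column, as proposed with `adata.obs['leiden+sample']`, a minimal pandas sketch follows. `combine_keys` is a hypothetical helper; the `_` separator and the resulting alphabetical category order are illustrative choices and need not match the labels scanpy produces for `groupby=['leiden_r1', 'sample']`.

```python
import pandas as pd


def combine_keys(obs: pd.DataFrame, keys, sep="_") -> pd.Series:
    """Concatenate several obs columns into one categorical column."""
    combined = obs[keys[0]].astype(str)
    for key in keys[1:]:
        combined = combined + sep + obs[key].astype(str)
    # categories are sorted alphabetically by astype("category")
    return combined.astype("category")
```

For example, `adata_luminals.obs['leiden+sample'] = combine_keys(adata_luminals.obs, ['leiden_r1', 'sample'])`, then pass `groupby='leiden+sample'`.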