scverse / scanpy

Single-cell analysis in Python. Scales to >1M cells.
https://scanpy.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Legends not accessible via provided axis and misplaced (scatterplot, subplots) #1258

Open picciama opened 4 years ago

picciama commented 4 years ago

When using sc.pl.scatter() and providing an existing axis object, the legend doesn't always appear correctly and cannot be accessed. This doesn't seem to happen with a categorical coloring, however, only with a continuous colormap.

This code works as expected:

sc.pp.calculate_qc_metrics(adata_raw, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
sc_fig, (sc_ax1, sc_ax2) = plt.subplots(1,2, figsize=(12,5))
sc.pl.scatter(adata_raw, 'total_counts','n_genes_by_counts', color='batch', size = 10, ax=sc_ax1, show=False, title="all counts")
sc_ax1.get_legend().remove()
sc.pl.scatter(adata_raw[adata_raw.obs['total_counts']<1000],'total_counts','n_genes_by_counts', color='batch', size = 10, ax=sc_ax2, show=False, title="< 1000 counts")
plt.show()

It creates some metrics and stores them in adata_raw.obs, then plots these metrics for all counts and for counts < 1000 on the two axes created by plt.subplots(). The legend from the first axis is then removed. This is an example of this output: image

Now the code that doesn't work:

sc_fig, (sc_ax1, sc_ax2) = plt.subplots(1,2, figsize=(12,5))
sc.pl.scatter(adata_raw, 'total_counts','n_genes_by_counts', color='pct_counts_mt', size = 10, ax=sc_ax1, show=False, title="all counts")
#sc_ax1.get_legend().remove()
sc.pl.scatter(adata_raw[adata_raw.obs['total_counts']<1000],'total_counts','n_genes_by_counts', color='pct_counts_mt', size = 10, ax=sc_ax2, show=False, title="< 1000 counts")
plt.show()

This is essentially the same code, but colored by the percentage of mitochondrial counts. Only one legend seems to be drawn, and it does not look as expected. In addition, I cannot remove the legend from the first plot. This is how it looks: image

Why doesn't it behave the same way as in the example above? Is there a way to share a single legend with a scale from 0 to 1 (0%-100%) between both plots in this case? As you can see, the line removing the legend from sc_ax1 is commented out because get_legend() returns None in this case, which would lead to the error below:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-154-702da93b63cb> in <module>
      2     sc_fig, (sc_ax1, sc_ax2) = plt.subplots(1,2, figsize=(12,5))
      3     sc.pl.scatter(adata_raw, 'total_counts','n_genes_by_counts', color='pct_counts_mt', size = 10, ax=sc_ax1, show=False, title="all counts")
----> 4     sc_ax1.get_legend().remove()
      5     sc.pl.scatter(adata_raw[adata_raw.obs['total_counts']<1000],'total_counts','n_genes_by_counts', color='pct_counts_mt', size = 10, ax=sc_ax2, show=False, title="< 1000 counts")
      6     plt.show()

AttributeError: 'NoneType' object has no attribute 'remove'

Shouldn't the legends be attached to the individual axes objects? I cannot access them and I wonder where they are stored in this case.

Versions:

scanpy==1.5.2.dev5+ge5d246aa anndata==0.7.1 umap==0.4.3 numpy==1.18.4 scipy==1.4.1 pandas==0.25.3 scikit-learn==0.23.0 statsmodels==0.11.1 python-igraph==0.8.2 louvain==0.7.0 matplotlib==3.1.2

picciama commented 4 years ago

Additional information: Changing legend_loc doesn't have any effect here either.

diazdc commented 3 years ago

> Additional information: Changing legend_loc doesn't have any effect here either.

I'm having the same issue. Is there a workaround for this? Changing rcParams["legend.loc"] has no effect either.

picciama commented 3 years ago

Totally forgot about this one sorry :(

The problem

In scatter_base: https://github.com/theislab/scanpy/blob/040e61ff50836d4a6cdd7da7482dcb4ee50d05ae/scanpy/plotting/_utils.py#L736-L740

For non-categorical variables, this code gets the current figure and adds a separate axis on which the colorbar is plotted. The axes objects on which the data is plotted therefore do not contain a legend object. Instead, the figure should contain the colorbar axis, and we may be able to manipulate it there as a workaround.
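To illustrate the mechanism, here is a minimal sketch in plain matplotlib (not scanpy's actual code) that mimics the pattern: the colorbar is drawn on its own axes added to the figure. The synthetic data, the rectangle coordinates and the label "colorbar_ax1" are assumptions made up for this example.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x, y, c = rng.random((3, 50))  # synthetic stand-in data, not real QC metrics

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
points = ax1.scatter(x, y, c=c)

# Mimic the pattern described above: the colorbar gets its own axes,
# added to the figure rather than attached to ax1 (rectangle is arbitrary).
ax_cb = fig.add_axes([0.92, 0.1, 0.02, 0.8], label="colorbar_ax1")
fig.colorbar(points, cax=ax_cb)

print(ax1.get_legend())        # None: no Legend object was attached to ax1
print(len(fig.axes))           # 3: the two subplots plus the colorbar axes
print(fig.axes[-1] is ax_cb)   # True: the colorbar axes is reachable via the figure
```

This is why get_legend() returns None on the data axes: the colorbar has to be looked up on the figure instead. Passing a unique label to add_axes, as done here, is also what the matplotlib warning quoted below recommends.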

There is also this DeprecationWarning popping up.

MatplotlibDeprecationWarning: Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.
  ax_cb = fig.add_axes(rectangle)

Current workaround (also works for all other sorts of plots)

The real problem is that the two plots are not separate figures and the axes are not handled correctly. The good news is that there is a way to avoid plt.subplots and passing Axes objects altogether, even if we want to place two plots side by side in a Jupyter notebook (original post here: https://stackoverflow.com/questions/21754976/ipython-notebook-arrange-plots-horizontally). sc.pl.scatter does not expose the figure object, only the axis. However, if we specify show=False, it returns the axis, and we can then obtain the figure object using matplotlib.pyplot.gcf(). Store these figures in a list and pass them to the plot_nice() function, which places the figures side by side until it runs out of horizontal space, then wraps to a new line and continues. You can therefore control how many figures fit on one line by adjusting the figsize of the individual figures.

For my example it would look like this:

from flow_layout import plot_nice  # import the required plotting function
plots = []
sc.pl.scatter(adata_all, 'total_counts','n_genes_by_counts', color='pct_counts_mt', size = 10, show=False, title="all counts")
plots.append(plt.gcf())
sc.pl.scatter(adata_all[adata_all.obs['total_counts']<1000],'total_counts','n_genes_by_counts', color='pct_counts_mt', size = 10, show=False, title="< 1000 counts") 
plots.append(plt.gcf())
plot_nice(plots)  # plot all figures

Result: Screenshot from 2020-10-21 12-52-22

In order to import the function, place the following code in a file called flow_layout.py in the same folder as your notebooks:

import base64
import io

import matplotlib.figure
import matplotlib.pyplot
from IPython.display import HTML, display


class FlowLayout(object):
    ''' A class / object to display plots in a horizontal / flow layout below a cell '''

    def __init__(self):
        # string buffer for the HTML: initially some CSS; images to be appended
        self.sHtml = """
        <style>
        .floating-box {
        display: inline-block;
        margin: 10px;
        /* border: 3px solid #888888; */
        }
        </style>
        """

    def add_plot(self, oAxes):
        ''' Appends a PNG of a Matplotlib Figure (or of the Figure owning an Axes) and returns that Figure '''
        Bio = io.BytesIO()  # bytes buffer for the plot
        if not isinstance(oAxes, matplotlib.figure.Figure):
            fig = oAxes.figure  # Axes objects expose their parent as .figure, not .fig
        else:
            fig = oAxes
        fig.canvas.print_png(Bio)  # make a png of the plot in the buffer

        # encode the bytes as string using base 64
        sB64Img = base64.b64encode(Bio.getvalue()).decode()
        self.sHtml += (
            '<div class="floating-box">'
            + '<img src="data:image/png;base64,{}">'.format(sB64Img)
            + '</div>')
        return fig

    def PassHtmlToCell(self):
        ''' Final step - display the accumulated HTML '''
        display(HTML(self.sHtml))


def plot_nice(plots: list):
    ''' Displays the given figures side by side, wrapping to a new line when the row is full '''
    oPlot = FlowLayout()
    for plot in plots:
        fig = oPlot.add_plot(plot)
        matplotlib.pyplot.close(fig)  # close this figure so the notebook does not render it again
    oPlot.PassHtmlToCell()

Hope this is useful :)

diazdc commented 3 years ago

Thanks!

SeppeDeWinter commented 3 years ago

The legend can also be removed by removing the last axis from the figure.

sc_fig.axes[-1].remove()
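As a rough sketch of how that removal could be combined with a single shared colorbar (the question in the original post), here is a plain-matplotlib example. It does not call scanpy at all: the data is synthetic, the fixed 0-100 range and the "pct_counts_mt" label are assumptions for illustration, and the colorbar axes created by hand only stands in for the one added by the plotting code.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.random((2, 50))
c = rng.random(50) * 100  # synthetic percentage-like values (assumption)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
first = ax1.scatter(x, y, c=c, vmin=0, vmax=100)

# Stand-in for the extra colorbar axes that ends up last in fig.axes.
cb_ax = fig.add_axes([0.92, 0.1, 0.02, 0.8], label="cb1")
fig.colorbar(first, cax=cb_ax)

fig.axes[-1].remove()  # same call as in the comment above
n_axes_after_removal = len(fig.axes)

# Second panel on the same color scale, then one colorbar for both panels.
mask = c < 50
ax2.scatter(x[mask], y[mask], c=c[mask], vmin=0, vmax=100)
shared = fig.colorbar(first, ax=[ax1, ax2], label="pct_counts_mt")
```

Fixing vmin and vmax on both scatter calls is what makes a single colorbar meaningful for both panels; whether the same range can be enforced through sc.pl.scatter is not covered by this sketch.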

mjseignon commented 2 years ago

Is this still the only fix for the colorbar overlapping with the scatter plots?

JackieMium commented 1 year ago

I am surprised: it is 2023, I am on scanpy==1.9.3, and I am facing exactly the same issue that was posted in 2020. I really hope it will be dealt with in a future version. Thanks!

flying-sheep commented 1 year ago

We will (hopefully) redesign how plotting works before the end of the year. Until then, you can use the workaround in https://github.com/scverse/scanpy/issues/1258#issuecomment-713492283.

ramadatta commented 8 months ago

I had some issue with io.BytesIO() from the fix proposed above.

So, I used R to generate scatter plots as below:

import anndata2ri
import logging

import rpy2.rinterface_lib.callbacks as rcb
import rpy2.robjects as ro

rcb.logger.setLevel(logging.ERROR)
ro.pandas2ri.activate()
anndata2ri.activate()

%load_ext rpy2.ipython

# Convert adata_p and adata_g to R objects

ro.globalenv['r_adata_p'] = adata_p
ro.globalenv['r_adata_g'] = adata_g
%%R -w 800 -h 400 -u px

library(Seurat)
library(viridis)
library(viridisLite)
library(ggplot2)
library(cowplot)

df_poor= data.frame(
  total_counts = colData(r_adata_p)$total_counts,
  n_genes_by_counts = colData(r_adata_p)$n_genes_by_counts,
  pct_counts_mt = colData(r_adata_p)$pct_counts_mt
)

df_good= data.frame(
  total_counts = colData(r_adata_g)$total_counts,
  n_genes_by_counts = colData(r_adata_g)$n_genes_by_counts,
  pct_counts_mt = colData(r_adata_g)$pct_counts_mt
)

#head(df)
# Create a scatter plot using ggplot2
p2 <- ggplot(data = df_poor, aes(x = total_counts, y = n_genes_by_counts, color = pct_counts_mt)) +
  geom_point() +
  scale_color_viridis() +
  labs(title = "poor (after outlier and mitochondrial gene removal)") +
  theme_minimal()

g2 <- ggplot(data = df_good, aes(x = total_counts, y = n_genes_by_counts, color = pct_counts_mt)) +
  geom_point() +
  scale_color_viridis() +
  labs(title = "good (after outlier and mitochondrial gene removal)") +
  theme_minimal()

p2 + g2

Screenshot from 2023-12-13 11-25-03

FernandoDuarteF commented 4 months ago

Has this been fixed yet, other than through the workaround in https://github.com/scverse/scanpy/issues/1258#issuecomment-713492283?