mwaskom / seaborn

Statistical data visualization in Python
https://seaborn.pydata.org
BSD 3-Clause "New" or "Revised" License

inconsistent ability to use palette to set alpha #3705

Status: open. Gabriel-Kissin opened this issue 1 month ago.

Gabriel-Kissin commented 1 month ago

Some plotting functions allow using the palette to set the alpha value for each hue level, for example sns.scatterplot.
Other plotting functions, however, do not allow this, for example sns.violinplot and sns.barplot.

Demo:

import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import seaborn as sns

iris = sns.load_dataset("iris")

colour_dict = {'setosa':         to_rgba('blue', 0.1),
               'versicolor':     to_rgba('blue', 1.0),
               'virginica':      to_rgba('blue', 0.5)}

fig, axs = plt.subplots(ncols=3, figsize=(15,4), sharey=True)
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', hue='species', palette=colour_dict, ax=axs[0])
sns.barplot(    data=iris,                  y='sepal_length', hue='species', palette=colour_dict, ax=axs[1])
sns.violinplot( data=iris,                  y='sepal_length', hue='species', palette=colour_dict, ax=axs[2])
plt.show()

[Image: output of the demo above, with scatterplot, barplot and violinplot panels]

I suggest that all plotting functions allow this, both for consistency and because it can be useful (see https://stackoverflow.com/questions/66667334/python-seaborn-alpha-by-hue).

This may be tricky with e.g. sns.lineplot, which, apart from the main line, also plots a low-alpha band to show the confidence interval. In this case, the specified alpha could be applied to the line while the band stays at its default alpha; the line will still stand out, because translucent layers drawn on top of each other compound their opacity. Alternatively, this could at least be implemented for those plotting functions that don't already use alpha internally.

https://github.com/mwaskom/seaborn/issues/1966 may be related.

Many thanks for the terrific library :)

thuiop commented 1 month ago

Well, there is an easy way to deal with that, which is to use the objects API (although this will not help you with the violins). Otherwise, the issue you cited sums up the problems with handling alpha this way.

mwaskom commented 1 month ago

I think scatterplot is probably the odd one out here and I would be surprised if that really worked almost anywhere else in the library.

Handling alpha within the color specification is a huge huge huge headache. It works a little bit better in the objects interface but I have very little interest in trying to make it retroactively work in the classic interface, unfortunately.

jhncls commented 1 month ago

When elements don't overlap, alpha simply mixes (linearly interpolates) the given color with the background color. With a white background, you can do this interpolation manually, blending with white (RGB (1, 1, 1)):

import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb
import seaborn as sns
import numpy as np

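def whiten(color, factor):
    # Blend `color` with white: factor=1 keeps the color, factor=0 gives pure white.
    # On a white background this looks the same as drawing `color` with alpha=factor.
    return np.array(to_rgb(color)) * factor + (1 - factor)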

iris = sns.load_dataset("iris")

colour_dict = {'setosa':         whiten('blue', 0.1),
               'versicolor':     whiten('blue', 1.0),
               'virginica':      whiten('blue', 0.5)}

fig, axs = plt.subplots(ncols=3, figsize=(15,4), sharey=True)
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', hue='species', palette=colour_dict, ax=axs[0])
sns.barplot(    data=iris,                  y='sepal_length', hue='species', palette=colour_dict, ax=axs[1])
sns.violinplot( data=iris,                  y='sepal_length', hue='species', palette=colour_dict, ax=axs[2])
plt.show()

[Image: output of the code above, with scatterplot, barplot and violinplot panels using whitened colors]
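whiten is a special case of flattening alpha against a known, opaque background: blended = alpha * color + (1 - alpha) * background. A small generalization might help when the axes background is not white, e.g. seaborn's "darkgrid" style, whose axes facecolor is #EAEAF2. flatten_alpha below is a hypothetical helper, not part of seaborn or matplotlib:

```python
import numpy as np
from matplotlib.colors import to_rgb

def flatten_alpha(color, alpha, background="white"):
    """Opaque RGB that `color` drawn at `alpha` would show over a uniform,
    opaque `background` (hypothetical helper)."""
    fg = np.array(to_rgb(color))
    bg = np.array(to_rgb(background))  # to_rgb drops any alpha in `background`
    return alpha * fg + (1 - alpha) * bg

# On white this matches whiten('blue', 0.5)
print(flatten_alpha("blue", 0.5))
# Against the darkgrid axes background
print(flatten_alpha("blue", 0.5, background="#EAEAF2"))
```

To match an existing plot, `ax.get_facecolor()` can be passed as `background`. Like whiten, this only looks right where elements don't overlap each other or gridlines.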

When elements do overlap, as in some scatter plots, alpha makes overlapping regions progressively more opaque, i.e. darker on a light background. Often, you'll have to fine-tune the amount of alpha (and the element size) to get a balanced plot.

import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import seaborn as sns
import numpy as np
import pandas as pd

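np.random.seed(42)
# Group 'a': 9 random walks of 1000 steps each, flattened into one column
dfa = pd.DataFrame({'x': np.random.randn(1000, 9).cumsum(axis=0).ravel(),
                    'y': np.random.randn(1000, 9).cumsum(axis=0).ravel(),
                    'h': 'a'})
# Group 'b': a single random walk, shifted 10 units to the right
dfb = pd.DataFrame({'x': 10 + np.random.randn(1000, 1).cumsum(axis=0).ravel(),
                    'y': np.random.randn(1000, 1).cumsum(axis=0).ravel(),
                    'h': 'b'})
df = pd.concat([dfa, dfb], ignore_index=True)

# The denser group 'a' gets the lower alpha so its overlaps don't saturate
palette = {'a': to_rgba('red', 0.1), 'b': to_rgba('blue', 0.3), }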
sns.scatterplot(data=df, x='x', y='y', hue='h', palette=palette)

plt.show()

[Image: scatterplot output of the code above]
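The darkening follows from how alpha compositing stacks: each layer with opacity a lets (1 - a) of what lies beneath show through, so n identical markers stacked exactly on top of each other reach an effective opacity of 1 - (1 - a)^n. A small sketch for reasoning about the fine-tuning (both helpers are hypothetical, and real scatter plots only partially overlap):

```python
def stacked_opacity(alpha, n):
    """Effective opacity of n overlapping layers, each with opacity `alpha`."""
    return 1 - (1 - alpha) ** n

def alpha_for_target(target, n):
    """Per-marker alpha so that n overlapping markers reach `target` opacity."""
    return 1 - (1 - target) ** (1 / n)

# alpha=0.1, as for group 'a' above: 10 overlapping points
print(round(stacked_opacity(0.1, 10), 3))   # 0.651
# alpha needed so that 20 overlaps look 90% opaque
print(round(alpha_for_target(0.9, 20), 3))  # 0.109
```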

Gabriel-Kissin commented 4 weeks ago

Many thanks @jhncls for the ingenious solution.

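Here are plots made with different seaborn functions. From left to right, the columns use plain colors, colors with alpha 0.5 set via to_rgba, and colors whitened by a factor of 0.5.

The plots make it clear where setting alpha in the palette works well, and where only whiten works.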

[Image: grid of plots, one row per seaborn function, one column per palette]

Here is the code for generating the plots:

import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import seaborn as sns
import numpy as np
import pandas as pd

n = 50

np.random.seed(0)

df = pd.DataFrame(np.random.randn(n*2, 2)+1, columns=['x','y'])
df['x_bin'] = pd.cut(df['x'],4)
df['h'] = (['a']*n) + (['b']*n)

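def whiten(color, factor):
    # Blend with white. The alpha channel becomes 1*factor + (1 - factor) = 1,
    # so the result stays fully opaque.
    return np.array(to_rgba(color)) * factor + (1 - factor)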

palette_a = {'a':         'red'      , 'b':         'blue',       }
palette_b = {'a': to_rgba('red', 0.5), 'b': to_rgba('blue', 0.5), }
palette_c = {'a': whiten( 'red', 0.5), 'b': whiten( 'blue', 0.5), }
palettes = [palette_a, palette_b, palette_c]

ncols = len(palettes)
nrows = 17
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(3*ncols, 1.5*nrows))
legend = False

for col_j, palette in enumerate(palettes):
    row_i = 0

    sns.kdeplot(data=df, x='x', y='y', hue='h', palette=palette, linewidths=5, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('2d kdeplot')

    row_i += 1
    sns.ecdfplot(data=df, y='y', hue='h', palette=palette, lw=10, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('ecdfplot')

    row_i += 1
    sns.scatterplot(data=df, x='x', y='y', hue='h', palette=palette, s=200, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('scatterplot')

    row_i += 1
    sns.lineplot(data=df, x='x', y='y', hue='h', palette=palette, lw=5, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('lineplot')

    row_i += 1
    sns.pointplot(data=df, x='x_bin', y='y', hue='h', palette=palette, lw=5, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('pointplot')

    row_i += 1
    sns.stripplot(data=df, x='x_bin', y='y', hue='h', palette=palette, s=15, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('stripplot')

    row_i += 1
    # the following don't work with alpha but do work with whiten

    row_i += 1
    sns.kdeplot(data=df, x='x', hue='h', palette=palette, lw=5, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('1d kdeplot')

    row_i += 1
    sns.barplot(data=df, x='x_bin', y='y', hue='h', palette=palette, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('barplot')

    row_i += 1
    sns.histplot(data=df, x='x', hue='h', palette=palette, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('1d histplot')

    row_i += 1
    sns.countplot(data=df, x='x_bin',  hue='h', palette=palette, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('countplot')

    row_i += 1
    sns.boxplot(data=df, x='x_bin', y='y', hue='h', palette=palette, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('boxplot')

    row_i += 1
    sns.boxenplot(data=df, x='x_bin', y='y', hue='h', palette=palette, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('boxenplot')

    row_i += 1
    sns.violinplot(data=df, x='x_bin', y='y', hue='h', palette=palette, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('violinplot')

    row_i += 1
    sns.rugplot(data=df, x='x', y='y', hue='h', palette=palette, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('rugplot')

    row_i += 1
    # the following don't work with whiten either

    row_i += 1
    sns.histplot(data=df, x='x', y='y', hue='h', palette=palette, ax=axs[row_i, col_j], legend=legend)
    axs[row_i, col_j].set_title('2d histplot')

for ax in axs.flatten():
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_ylabel('')
    ax.set_xlabel('')

plt.tight_layout()
plt.show()
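To check the outcome for each function without eyeballing the grid, one option is to read back the alpha channel of what was actually drawn. A rough sketch, assuming seaborn 0.13; face_alphas is a hypothetical helper that only inspects ax.collections and ax.patches, so it misses elements drawn as Line2D (lineplot, pointplot, ecdfplot):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import to_rgba
import seaborn as sns

def face_alphas(ax):
    """Distinct face alpha values among the collections and patches on `ax`."""
    alphas = set()
    for coll in ax.collections:
        for a in np.atleast_2d(coll.get_facecolor())[:, 3]:
            alphas.add(round(float(a), 3))
    for patch in ax.patches:
        alphas.add(round(float(patch.get_facecolor()[3]), 3))
    return alphas

np.random.seed(0)
df = pd.DataFrame(np.random.randn(100, 2) + 1, columns=["x", "y"])
df["h"] = ["a"] * 50 + ["b"] * 50
palette = {"a": to_rgba("red", 0.5), "b": to_rgba("blue", 0.5)}

fig, ax = plt.subplots()
sns.scatterplot(data=df, x="x", y="y", hue="h", palette=palette, legend=False, ax=ax)
print(face_alphas(ax))  # the palette alpha survives for scatterplot
```

Calling face_alphas on each axes of the grid above would give the same classification as a table instead of a picture.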

Going through all the seaborn plotting functions helped me to discover, for the first time, the hidden gem which is sns.dogplot :-)

mwaskom commented 4 weeks ago

woof woof

jhncls commented 4 weeks ago

If you just want to use the same alpha for all elements, many functions accept a fixed alpha parameter:

import matplotlib.pyplot as plt
import seaborn as sns

iris = sns.load_dataset("iris")

colour_dict = {'setosa': 'dodgerblue',
               'versicolor': 'crimson',
               'virginica': 'limegreen'}

fig, axs = plt.subplots(ncols=3, figsize=(15, 4), sharey=True)
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', hue='species', palette=colour_dict, alpha=0.5, ax=axs[0])
sns.barplot(data=iris, y='sepal_length', hue='species', palette=colour_dict, alpha=0.5, ax=axs[1])
sns.violinplot(data=iris, y='sepal_length', hue='species', palette=colour_dict, alpha=0.5, ax=axs[2])
plt.show()

[Image: scatterplot, barplot and violinplot panels drawn with alpha=0.5]