mwaskom / seaborn

Statistical data visualization in Python
https://seaborn.pydata.org
BSD 3-Clause "New" or "Revised" License

scatterplot crashes or produces unexpected results when {`hue`,`style`}`_order` do not contain all `hue`/`style` values #3601

Open mmore500 opened 9 months ago

mmore500 commented 9 months ago

bug description: scatterplot plots hollow points when hue_order is a strict subset of the hue values in the dataframe, and crashes when style_order is a strict subset of the style values in the dataframe (i.e., the dataframe contains hue or style values not present in hue_order/style_order). Was able to reproduce in seaborn versions 0.13.0, 0.12.0, and 0.11.0.

expected behavior: scatterplot should plot only the subset of data whose values are specified in hue_order and style_order, matching the current behavior of lineplot, kdeplot, etc.

related issues: none obvious; #3575 has a different stack trace and does not occur with 0.12.x versions of seaborn

if wanted, I'd be happy to look into contributing a fix

scatterplot: plots hollow points when set(hue_order) < set(df[hue])

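import matplotlib.pyplot as plt
import seaborn as sns

sns.scatterplot(
    data=sns.load_dataset("diamonds").dropna(),
    x="price",
    y="carat",
    hue="clarity",
    hue_order=["I1", "IF"],  # strict subset of the clarity values in the data
)
plt.gca().set_facecolor("gray")  # gray background makes the hollow points visible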


scatterplot: crashes when set(style_order) < set(df[style])

sns.scatterplot(
    data=sns.load_dataset("diamonds").dropna(),
    x="price",
    y="carat",
    style="clarity",
    style_order=["I1", "IF"],
)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mmore500/2023-12-09/outset/env210/lib64/python3.10/site-packages/seaborn/relational.py", line 624, in scatterplot
    p.plot(ax, kwargs)
  File "/home/mmore500/2023-12-09/outset/env210/lib64/python3.10/site-packages/seaborn/relational.py", line 443, in plot
    p = [self._style_map(val, "path") for val in data["style"]]
  File "/home/mmore500/2023-12-09/outset/env210/lib64/python3.10/site-packages/seaborn/relational.py", line 443, in <listcomp>
    p = [self._style_map(val, "path") for val in data["style"]]
  File "/home/mmore500/2023-12-09/outset/env210/lib64/python3.10/site-packages/seaborn/_base.py", line 85, in __call__
    return self._lookup_single(key, *args, **kwargs)
  File "/home/mmore500/2023-12-09/outset/env210/lib64/python3.10/site-packages/seaborn/_base.py", line 588, in _lookup_single
    value = self.lookup_table[key][attr]
KeyError: 'SI2'

for comparison, lineplot, kdeplot, and lmplot behave as expected

lineplot: works as expected with set(style_order) < set(df[style]) and set(hue_order) < set(df[hue])

sns.lineplot(
    data=sns.load_dataset("diamonds").dropna(),
    x="price",
    y="carat",
    hue="clarity",
    hue_order=["I1", "IF"],
    style="clarity",
    style_order=["I1", "IF"],
)
plt.gca().set_facecolor("gray")


kdeplot: works as expected when set(hue_order) < set(df[hue])

sns.kdeplot(
    data=sns.load_dataset("diamonds").dropna(),
    x="price",
    y="carat",
    hue="clarity",
    hue_order=["I1", "IF"],
)
plt.gca().set_facecolor("gray")


lmplot: works as expected when set(hue_order) < set(df[hue])

sns.lmplot(
    data=sns.load_dataset("diamonds").dropna(),
    x="price",
    y="carat",
    hue="clarity",
    hue_order=["I1", "IF"],
)
plt.gca().set_facecolor("gray")


system information

seaborn v0.13.0; I was also able to reproduce on v0.12.0 and v0.11.0

Python 3.10.13 (main, Aug 28 2023, 00:00:00) [GCC 13.2.1 20230728 (Red Hat 13.2.1-1)] on linux

pip freeze:

contourpy==1.2.0
cycler==0.12.1
fonttools==4.47.0
kiwisolver==1.4.5
matplotlib==3.8.2
numpy==1.26.2
packaging==23.2
pandas==2.1.4
Pillow==10.1.0
pyparsing==3.1.1
python-dateutil==2.8.2
pytz==2023.3.post1
seaborn==0.13.0
six==1.16.0
tzdata==2023.3

mmore500@fedora
OS: Fedora
Kernel: x86_64 Linux 6.6.6-200.fc39.x86_64
Uptime: 19h 41m
Packages: 6979
Shell: pkcommand-not-found
Resolution: 1920x1080
DE: GNOME 45.2
WM: Mutter
WM Theme: Nordic
GTK Theme: Nordic [GTK2/3]
Icon Theme: Adwaita
Font: Cantarell 11
Disk: 178G / 1.9T (10%)
CPU: Intel Core i7-10510U @ 8x 4.9GHz [64.0°C]
GPU: Mesa Intel(R) UHD Graphics (CML GT2)
RAM: 12554MiB / 15649MiB

mmore500 commented 9 months ago

In case anyone else also needs it, I've put the patch I'm using in the meantime up at https://github.com/mmore500/outset/blob/1dd0d036f90bc7b27f20c5fca3f6eb257e70770c/src/outset/patched/_scatterplot.py

thuiop commented 9 months ago

This is not an obvious one. For line plots (for instance), the dataframe is grouped by each color/linestyle/... combination, and one line is drawn per subset and given the correct appearance. That is not the case with scatter, where all the points are plotted at the same time and the colors and styles are then mapped to each point. Fixing this probably requires rewriting the function for scatter plots in the same fashion as for line plots.
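To make the difference concrete, here is a rough plain-Python sketch of the two strategies. It does not use seaborn's real classes: `lookup_table`, `per_point_markers`, and `grouped_rows` are hypothetical stand-ins, and the marker values are arbitrary. It only shows why a per-point lookup fails on a value missing from `style_order` while iterating over the known levels silently skips it.

```python
# Illustrative stand-ins only; these names are not part of seaborn.
style_order = ["I1", "IF"]
markers = ["o", "X"]  # arbitrary marker choices for the sketch
lookup_table = dict(zip(style_order, markers))  # built from style_order only

# Toy (clarity, carat) rows; "SI2" is in the data but not in style_order.
rows = [("I1", 1.0), ("SI2", 0.5), ("IF", 0.3)]


def per_point_markers(rows, table):
    """Scatter-like: look up an attribute for every row in a single pass."""
    return [table[level] for level, _ in rows]


def grouped_rows(rows, table):
    """Line-like: iterate over the known levels and collect matching rows."""
    return {
        level: [value for lvl, value in rows if lvl == level]
        for level in table
    }


failed_key = None
try:
    per_point_markers(rows, lookup_table)
except KeyError as err:
    failed_key = err.args[0]  # the first level with no entry in the table

groups = grouped_rows(rows, lookup_table)  # rows with "SI2" are never visited
```

The per-point version fails in the same way as the `KeyError: 'SI2'` in the traceback above, while the grouped version only ever touches the levels it was told about.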

mwaskom commented 9 months ago

The reason that scatterplot works differently from lineplot and others is that if we grouped over the hue/size/style variables the resulting scatterplot would be "layered" in a way that could be misleading. Additionally, the number of individual collection artists generated may be very large (imagine a dense scatterplot with a continuous hue variable where most hue observations are distinct).
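Both concerns can be sketched with plain Python on toy data (no plotting involved; all names here are made up for illustration). Artists drawn later sit on top of earlier ones at the same z-order, so drawing one group per hue level stacks whole levels on top of each other, and the number of artists scales with the number of distinct hue values.

```python
# Toy rows: (hue level, row index), in the order they appear in the data.
rows = [("a", 0), ("b", 1), ("a", 2), ("b", 3), ("a", 4)]

# Single scatter call: points are drawn in row order, so levels interleave.
single_call_draw_order = [idx for _, idx in rows]

# One call per hue level: every "b" point is drawn after every "a" point,
# so "b" would systematically cover "a" wherever they overlap.
levels = ["a", "b"]
grouped_draw_order = [
    idx for level in levels for lvl, idx in rows if lvl == level
]

# With a continuous hue where most values are distinct, grouping would
# create roughly one artist per observation instead of a single collection.
continuous_hue = [0.11, 0.25, 0.31, 0.48, 0.52]
n_artists_grouped = len(set(continuous_hue))
n_artists_single = 1
```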

It also looks like a similar issue occurs in stripplot (although stripplot dots do not have edges by default):

diamonds = sns.load_dataset("diamonds")
sns.stripplot(diamonds, x="cut", y="price", hue="clarity", linewidth=.3, hue_order=["I1", "IF"])

Probably the solution is to reduce the dataframe to just the rows with the relevant values for hue/size/style at some point, either centrally, or in the plotting method for scatter-type plots. Given that it occurs in more than one place, centrally makes sense, but there may be some complications relating to choosing default scale domains and/or computing statistics that we'd need to be mindful of.
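As a minimal sketch of the "reduce the dataframe" idea, done on the user side rather than inside seaborn: the helper below is hypothetical (`filter_to_orders` is not a seaborn or pandas function) and assumes each order is a list of allowed values for the column of the same name. It runs on a small toy frame instead of the diamonds dataset.

```python
import pandas as pd


def filter_to_orders(df, **orders):
    """Keep rows whose value in each named column appears in its order list.

    Hypothetical helper for illustration; an order of None means
    "no restriction on this column".
    """
    mask = pd.Series(True, index=df.index)
    for column, order in orders.items():
        if order is not None:
            mask &= df[column].isin(order)
    return df[mask]


# Toy stand-in for the diamonds data.
toy = pd.DataFrame(
    {
        "clarity": ["I1", "SI2", "IF", "VS1"],
        "carat": [1.0, 0.5, 0.3, 0.7],
        "price": [3000, 1500, 900, 2100],
    }
)

subset = filter_to_orders(toy, clarity=["I1", "IF"])
```

Passing `subset` to a scatter-type plot together with the same `hue_order`/`style_order` should leave no unmapped values. Filtering up front also changes what the axis limits and any computed statistics are based on, which is the complication mentioned above.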

Note that the objects interface does not have this issue:

import seaborn.objects as so
(
    so.Plot(diamonds, x="carat", y="price", color="clarity")
    .add(so.Dots())
    .scale(color=so.Nominal(order=["IF", "I1"]))
)