mwaskom / seaborn

Statistical data visualization in Python
https://seaborn.pydata.org
BSD 3-Clause "New" or "Revised" License
12.57k stars 1.93k forks source link

`size` parameter of `scatterplot` does not accept Float64 type #3519

Open · mofojed opened this issue 1 year ago

mofojed commented 1 year ago

Create a DataFrame whose columns have the pandas `Float64` dtype (here produced by `convert_dtypes()`), and use one of those columns as the `size` parameter of `scatterplot`:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

df_sb_multi = pd.DataFrame([
    {"X": 0, "Y": 0.0, "Z": 1.0, "R": 0.498653, "S": 2.582756 },
    {"X": 1, "Y": 0.841471, "Z": 0.540302, "R": 0.663367, "S": 3.193578 },
    {"X": 2, "Y": 0.909297, "Z": -0.416147, "R": 0.326006, "S": 0.241508 },
    {"X": 3, "Y": 0.14112, "Z": -0.989992, "R": 0.298382, "S": 40.054015 },
    {"X": 4, "Y": -0.756802, "Z": -0.653644, "R": 0.410429, "S": 33.189659 },
    {"X": 5, "Y": -0.958924, "Z": 0.283662, "R": 0.756501, "S": 41.980234 },
    {"X": 6, "Y": -0.279415, "Z": 0.96017, "R": 0.412779, "S": 0.837251 },
    {"X": 7, "Y": 0.656987, "Z": 0.753902, "R": 0.33618, "S": 30.597325 },
    {"X": 8, "Y": 0.989358, "Z": -0.1455, "R": 0.312757, "S": 2.10432 },
    {"X": 9, "Y": 0.412118, "Z": -0.91113, "R": 0.88594, "S": 33.09462 }
])
df_sb_multi = df_sb_multi.convert_dtypes()

fig_sb_multi, sb_multi_ax = plt.subplots()
sb_multi_ax.clear()
sns.scatterplot(df_sb_multi, x="X", y="R", size="S", ax=sb_multi_ax)

Running that code produces the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 21
     19 fig_sb_multi, sb_multi_ax = plt.subplots()
     20 sb_multi_ax.clear()
---> 21 sns.scatterplot(df_sb_multi, x="X", y="R", size="S", ax=sb_multi_ax)

File /lib/python3.11/site-packages/seaborn/relational.py:624, in scatterplot(data, x, y, hue, size, style, palette, hue_order, hue_norm, sizes, size_order, size_norm, markers, style_order, legend, ax, **kwargs)
    621 color = kwargs.pop("color", None)
    622 kwargs["color"] = _default_color(ax.scatter, hue, color, kwargs)
--> 624 p.plot(ax, kwargs)
    626 return ax

File /lib/python3.11/site-packages/seaborn/relational.py:458, in _ScatterPlotter.plot(self, ax, kws)
    456 if self.legend:
    457     attrs = {"hue": "color", "size": "s", "style": None}
--> 458     self.add_legend_data(ax, _scatter_legend_artist, kws, attrs)
    459     handles, _ = ax.get_legend_handles_labels()
    460     if handles:

File /lib/python3.11/site-packages/seaborn/_base.py:1246, in VectorPlotter.add_legend_data(self, ax, func, common_kws, attrs, semantic_kws)
   1244     attrs = {"hue": "color", "size": ["linewidth", "s"], "style": None}
   1245 for var, names in attrs.items():
-> 1246     self._update_legend_data(
   1247         update, var, verbosity, title, title_kws, names, semantic_kws.get(var),
   1248     )
   1250 legend_data = {}
   1251 legend_order = []

File /lib/python3.11/site-packages/seaborn/_base.py:1313, in VectorPlotter._update_legend_data(self, update, var, verbosity, title, title_kws, attr_names, other_props)
   1311         locator = mpl.ticker.MaxNLocator(nbins=brief_ticks)
   1312     limits = min(mapper.levels), max(mapper.levels)
-> 1313     levels, formatted_levels = locator_to_legend_entries(
   1314         locator, limits, self.plot_data[var].infer_objects().dtype
   1315     )
   1316 elif mapper.levels is None:
   1317     levels = formatted_levels = []

File /lib/python3.11/site-packages/seaborn/utils.py:698, in locator_to_legend_entries(locator, limits, dtype)
    696 def locator_to_legend_entries(locator, limits, dtype):
    697     """Return levels and formatted levels for brief numeric legends."""
--> 698     raw_levels = locator.tick_values(*limits).astype(dtype)
    700     # The locator can return ticks outside the limits, clip them here
    701     raw_levels = [l for l in raw_levels if l >= limits[0] and l <= limits[1]]

TypeError: Cannot interpret 'Float64Dtype()' as a data type

If you omit the `size` parameter, or explicitly convert the columns to `float32`, it works. For example (this snippet additionally requires `import numpy as np`):

df_sb_multi = df_sb_multi.astype({
    column: np.float32
    for column in df_sb_multi.drop(["X"], axis=1).columns
})

Does seaborn not support the pandas `Float64` dtype?

mwaskom commented 1 year ago

To clarify, the important distinction is not between 32- and 64-bit precision but between `float` and `Float`. The latter (`Float64`) is a pandas-specific extension dtype, whereas `float32`/`float64` are dtypes that also exist in numpy.

alexpeters1208 commented 1 year ago

Some additional detail that we just found... This script does produce the above error:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

df_sb_multi = pd.DataFrame([
    {"X": 0, "Y": 0.0, "Z": 1.0, "R": 0.498653, "S": 2.582756 },
    {"X": 1, "Y": 0.841471, "Z": 0.540302, "R": 0.663367, "S": 3.193578 },
    {"X": 2, "Y": 0.909297, "Z": -0.416147, "R": 0.326006, "S": 0.241508 },
    {"X": 3, "Y": 0.14112, "Z": -0.989992, "R": 0.298382, "S": 40.054015 },
    {"X": 4, "Y": -0.756802, "Z": -0.653644, "R": 0.410429, "S": 33.189659 },
    {"X": 5, "Y": -0.958924, "Z": 0.283662, "R": 0.756501, "S": 41.980234 },
    {"X": 6, "Y": -0.279415, "Z": 0.96017, "R": 0.412779, "S": 0.837251 }
])
df_sb_multi = df_sb_multi.convert_dtypes()

fig_sb_multi, sb_multi_ax = plt.subplots()
sb_multi_ax.clear()
sns.scatterplot(df_sb_multi, x="X", y="R", size="S", ax=sb_multi_ax)

but this one does not:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

df_sb_multi = pd.DataFrame([
    {"X": 0, "Y": 0.0, "Z": 1.0, "R": 0.498653, "S": 2.582756 },
    {"X": 1, "Y": 0.841471, "Z": 0.540302, "R": 0.663367, "S": 3.193578 },
    {"X": 2, "Y": 0.909297, "Z": -0.416147, "R": 0.326006, "S": 0.241508 },
    {"X": 3, "Y": 0.14112, "Z": -0.989992, "R": 0.298382, "S": 40.054015 },
    {"X": 4, "Y": -0.756802, "Z": -0.653644, "R": 0.410429, "S": 33.189659 },
    {"X": 5, "Y": -0.958924, "Z": 0.283662, "R": 0.756501, "S": 41.980234 }
])
df_sb_multi = df_sb_multi.convert_dtypes()

fig_sb_multi, sb_multi_ax = plt.subplots()
sb_multi_ax.clear()
sns.scatterplot(df_sb_multi, x="X", y="R", size="S", ax=sb_multi_ax)

There is a hard threshold between plotting 6 and 7 points where this error starts.

mwaskom commented 1 year ago

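That makes sense. The error is raised while building the “brief” legend values: tick values produced by a matplotlib locator are cast to the column's dtype (the `.astype(dtype)` call in the traceback), and numpy cannot interpret the pandas `Float64Dtype`. The brief legend is only used once there are enough distinct size levels, which is why the error appears with 7 points but not with 6.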

Ultimately the pandas dtypes are an ongoing annoyance for seaborn. They’re still “experimental” in pandas and often cause issues in code written expecting numpy dtypes, which is most of matplotlib. Seaborn can try to cast data types and handle specific cases where they arise, but there’s not a great general solution.

mofojed commented 1 year ago

@mwaskom would the recommendation be to use numpy data types in the meantime? Are you aware of any notes in matplotlib with a similar recommendation?

mwaskom commented 1 year ago

Yes, your issue is with the Float64 types produced by convert_dtypes. Using "regular" float64 should work fine.

lavanya-naresh commented 1 year ago

(Quoting @alexpeters1208's comment above:) Some additional detail that we just found... This script does produce the above error:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

df_sb_multi = pd.DataFrame([
    {"X": 0, "Y": 0.0, "Z": 1.0, "R": 0.498653, "S": 2.582756 },
    {"X": 1, "Y": 0.841471, "Z": 0.540302, "R": 0.663367, "S": 3.193578 },
    {"X": 2, "Y": 0.909297, "Z": -0.416147, "R": 0.326006, "S": 0.241508 },
    {"X": 3, "Y": 0.14112, "Z": -0.989992, "R": 0.298382, "S": 40.054015 },
    {"X": 4, "Y": -0.756802, "Z": -0.653644, "R": 0.410429, "S": 33.189659 },
    {"X": 5, "Y": -0.958924, "Z": 0.283662, "R": 0.756501, "S": 41.980234 },
    {"X": 6, "Y": -0.279415, "Z": 0.96017, "R": 0.412779, "S": 0.837251 }
])
df_sb_multi = df_sb_multi.convert_dtypes()

fig_sb_multi, sb_multi_ax = plt.subplots()
sb_multi_ax.clear()
sns.scatterplot(df_sb_multi, x="X", y="R", size="S", ax=sb_multi_ax)

but this one does not:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

df_sb_multi = pd.DataFrame([
    {"X": 0, "Y": 0.0, "Z": 1.0, "R": 0.498653, "S": 2.582756 },
    {"X": 1, "Y": 0.841471, "Z": 0.540302, "R": 0.663367, "S": 3.193578 },
    {"X": 2, "Y": 0.909297, "Z": -0.416147, "R": 0.326006, "S": 0.241508 },
    {"X": 3, "Y": 0.14112, "Z": -0.989992, "R": 0.298382, "S": 40.054015 },
    {"X": 4, "Y": -0.756802, "Z": -0.653644, "R": 0.410429, "S": 33.189659 },
    {"X": 5, "Y": -0.958924, "Z": 0.283662, "R": 0.756501, "S": 41.980234 }
])
df_sb_multi = df_sb_multi.convert_dtypes()

fig_sb_multi, sb_multi_ax = plt.subplots()
sb_multi_ax.clear()
sns.scatterplot(df_sb_multi, x="X", y="R", size="S", ax=sb_multi_ax)

There is a hard threshold between plotting 6 and 7 points where this error starts.

I tried using the `astype` function to convert the DataFrame dtypes to `float64`, and it works even at or above the 7-row threshold. Version info: pandas 2.1.2, seaborn 0.13.0, matplotlib 3.7.1.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

df_sb_multi = pd.DataFrame([
    {"X": 0, "Y": 0.0, "Z": 1.0, "R": 0.498653, "S": 2.582756 },
    {"X": 1, "Y": 0.841471, "Z": 0.540302, "R": 0.663367, "S": 3.193578 },
    {"X": 2, "Y": 0.909297, "Z": -0.416147, "R": 0.326006, "S": 0.241508 },
    {"X": 3, "Y": 0.14112, "Z": -0.989992, "R": 0.298382, "S": 40.054015 },
    {"X": 4, "Y": -0.756802, "Z": -0.653644, "R": 0.410429, "S": 33.189659 },
    {"X": 5, "Y": -0.958924, "Z": 0.283662, "R": 0.756501, "S": 41.980234 },
    {"X": 6, "Y": -0.279415, "Z": 0.96017, "R": 0.412779, "S": 0.837251 }
])
df_sb_multi = df_sb_multi.astype("float64")

fig_sb_multi, sb_multi_ax = plt.subplots()
sb_multi_ax.clear()
sns.scatterplot(df_sb_multi, x="X", y="R", size="S", ax=sb_multi_ax)

However, convert_dtypes() usage results in the error.

mwaskom commented 1 year ago

> However, convert_dtypes() usage results in the error.

Right: `convert_dtypes` produces pandas extension dtypes (such as `Float64`) by default. You could also pass `convert_floating=False` to keep numpy floats.

WillAyd commented 4 months ago

> Ultimately the pandas dtypes are an ongoing annoyance for seaborn

I'm sorry that you have been experiencing this downstream, but I'm also not surprised. I have opened a PDEP (pandas enhancement proposal) to align on a logical type system that I think could help, and I would love any feedback on it to improve this experience for the ecosystem:

https://github.com/pandas-dev/pandas/pull/58455

> They’re still “experimental” in pandas and often cause issues in code written expecting numpy dtypes, which is most of matplotlib

This is another area I am hoping we can address via the PDEP process:

https://github.com/pandas-dev/pandas/pull/59125#discussion_r1657797729

These "experimental" types have existed since 2019 with few updates on the pandas side, so the experimental label is disingenuous at this point.