mwaskom / seaborn

Statistical data visualization in Python
https://seaborn.pydata.org
BSD 3-Clause "New" or "Revised" License

[Bug] Plotting categorical columns includes empty categories #3704

Closed Yazan-Sharaya closed 1 month ago

Yazan-Sharaya commented 1 month ago

A reproducible code example that demonstrates the problem

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

countries = ['US', 'Canada', 'Spain', 'US', 'Canada', 'Sweden', 'Jordan', 'Netherlands', 'US', 'Spain']
df = pd.DataFrame(countries, columns=['Countries'])
df['Countries'] = df['Countries'].astype('category')

filtered_df = df[df['Countries'] == 'US'].copy()

sns.countplot(filtered_df, x='Countries')
plt.show()

The output that you are seeing (an image of a plot, or the error message)

[image: myplot]

A clear explanation of why you think something is wrong

When plotting a categorical column, the resulting plot contains every category defined on the column, even those that no longer have any rows. I couldn't find anything in the documentation that addresses this directly. However, I found the following example at https://seaborn.pydata.org/tutorial/categorical.html#categorical-scatterplots, specifically the part that contains

sns.catplot(data=tips.query("size != 3"), x="size", y="total_bill", native_scale=True)

There, the result has an empty column at size=3. Nonetheless, I'm not sure this should be the case when creating a new dataframe that lacks certain categories from the original one. I understand that this could be more of a pandas issue than a seaborn one, but I felt it should be mentioned or documented more clearly. There are currently a couple of easy workarounds:

filtered_df['Countries'] = filtered_df['Countries'].astype('string')
# Or
filtered_df['Countries'] = filtered_df['Countries'].cat.remove_unused_categories()
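
For readers who want to see where the empty categories come from, here is a minimal pandas-only sketch reusing the data from the example above. It assumes the behaviour described in the pandas categorical documentation: filtering rows leaves the column's dtype, and therefore its list of categories, unchanged. The names `counts` and `trimmed` are illustrative only.

```python
import pandas as pd

countries = ['US', 'Canada', 'Spain', 'US', 'Canada', 'Sweden', 'Jordan', 'Netherlands', 'US', 'Spain']
df = pd.DataFrame(countries, columns=['Countries'])
df['Countries'] = df['Countries'].astype('category')

filtered_df = df[df['Countries'] == 'US'].copy()

# Filtering removes rows, not categories: the dtype still lists all six.
print(list(filtered_df['Countries'].cat.categories))

# value_counts reports the unused categories with a count of zero.
counts = filtered_df['Countries'].value_counts()
print(counts)

# Dropping unused categories leaves only the ones that still occur.
trimmed = filtered_df['Countries'].cat.remove_unused_categories()
print(list(trimmed.cat.categories))
```

Any plotting code that reads the category list from the dtype would therefore still see the unused entries until they are removed.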

The specific versions of seaborn and matplotlib that you are working with

mwaskom commented 1 month ago

This is fully intentional; you can dig up the original thread on the introduction of categorical support, where it was discussed at length.

jhncls commented 1 month ago

@Yazan-Sharaya Sometimes, for consistency between plots, you want to see the unused categories. And sometimes, as in your case, you don't want to see them.

One way to show only the used categories changes the dataframe:

filtered_df['Countries'] = filtered_df['Countries'].cat.remove_unused_categories()

An alternative way uses order= to restrict the plot to the desired categories:

sns.countplot(filtered_df, x='Countries', order=filtered_df['Countries'].unique())
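
As a possible variation on the order= approach, the sketch below builds the list in pandas alone. `unique()` returns values in order of first appearance, so if the categorical's own ordering should be kept, the list can be derived from `cat.categories` instead. The two-country filter and the names `col`, `present`, and `order` are assumptions made for illustration, not part of the original report.

```python
import pandas as pd

countries = ['US', 'Canada', 'Spain', 'US', 'Canada', 'Sweden', 'Jordan', 'Netherlands', 'US', 'Spain']
df = pd.DataFrame(countries, columns=['Countries'])
df['Countries'] = df['Countries'].astype('category')

# Hypothetical filter that keeps two countries, so the ordering difference is visible.
filtered_df = df[df['Countries'].isin(['US', 'Spain'])].copy()
col = filtered_df['Countries']

# Order of first appearance in the filtered rows.
print(list(col.unique()))

# Categories that still occur, listed in the categorical's own order.
present = set(col.dropna().unique())
order = [cat for cat in col.cat.categories if cat in present]
print(order)
```

Either list could then be passed as `order=` to `sns.countplot`, as in the call above.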

mwaskom commented 1 month ago

Thanks @jhncls