mwaskom / seaborn

Statistical data visualization in Python
https://seaborn.pydata.org
BSD 3-Clause "New" or "Revised" License

sns.regplot regression line fails with large values #3700

Open Gabriel-Kissin opened 1 month ago

Gabriel-Kissin commented 1 month ago

With large values, the regression line can be computed incorrectly:

[Figure: regplot fits for six progressively shifted datasets; the last three regression lines and their shaded confidence bands are visibly wrong, while the dotted black sklearn fits are correct.]

The first three lines are good, the last three aren't. (Note also the strange shape of the green shaded region, which stems from the same issue.)

This is not strictly a seaborn issue; see https://github.com/statsmodels/statsmodels/issues/9258 for further background. However, since other OLS methods implemented by statsmodels (e.g. QR decomposition) produce an accurate fit in these cases, as does sklearn (dotted black line), I thought it worth posting here. It may be worth considering whether a more robust algorithm that avoids these issues could be used for the line-of-best-fit visualisation.

The code used to generate the above plot is

import numpy as np
import statsmodels.api
import sklearn.linear_model
import matplotlib.pyplot as plt
import seaborn as sns

x_base = np.linspace(4e13, 10e13, 10)
y = np.linspace(1, 0, 10)

for i in range(6):

    x = x_base + (i*3e13)

    # solve using statsmodels
    stats_ols = statsmodels.api.OLS(
        endog=y, exog=statsmodels.api.add_constant(x))
    stats_ols_fitted = stats_ols.fit()                      # uses method="pinv" by default
    # stats_ols_fitted = stats_ols.fit(method="qr")         # fits correctly

    # solve & predict using sklearn
    sklearn_ols = sklearn.linear_model.LinearRegression()
    sklearn_ols.fit(x.reshape((-1,1)), y)
    x_sklearn = np.linspace(x.min(), x.max())
    y_sklearn = sklearn_ols.predict(x_sklearn.reshape((-1,1)))

    # compose informative legend label for each set of data/LR model
    label = f'statsmodels OLS: $r^2={np.round(stats_ols_fitted.rsquared, 3)}$'
    label += '\nstatsmodels params: ' + ', '.join(f'{param:0.3}' for param in stats_ols_fitted.params)
    label += '\nsklearn params: ' + ', '.join(f'{param:0.3}' for param in [sklearn_ols.intercept_, *sklearn_ols.coef_])

    # plot using seaborn
    sns.regplot(x=x, y=y, label=label, ax=plt.gca())

    # plot the LR fits (sklearn)
    plt.plot(x_sklearn, y_sklearn,
             label=('sklearn LinearRegression' if i in [2,5] else ''),
             ls=':', lw=1.5, c='k')

plt.legend(fontsize='small', loc='center left', bbox_to_anchor=(1, 0.5), ncols=2)
plt.show()
Gabriel-Kissin commented 1 month ago

An easy fix would be to scale the data before fitting the regression.

mwaskom commented 3 weeks ago

I'm inclined to agree with the statsmodels folks that there is no one obviously best algorithm to use here, and also that scaling the data should be a very easy workaround to avoid this if you run into it.