Closed mdancho84 closed 9 months ago
Plot correlation funnel
import pandas as pd
import numpy as np
from plotnine import ggplot, aes, geom_vline, geom_point, geom_text, labs, theme_minimal, theme, element_text
def plot_correlation_funnel(data, interactive=False, limits=(-1, 1), alpha=1):
    """Build a correlation-funnel scatter plot with plotnine.

    Each row of ``data`` becomes one point positioned at its ``correlation``
    value (x) and ``feature`` name (y), with a dashed red reference line at
    zero correlation.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain a ``feature`` column and a numeric ``correlation`` column.
        The frame is never modified.
    interactive : bool, default False
        When True, a ``label_text`` column (feature name plus the correlation
        rounded to three decimals) is attached to a copy of ``data`` and mapped
        to the ``text`` aesthetic. When False, feature names are drawn next to
        the points with ``geom_text``.
    limits : tuple of (float, float) or None, default (-1, 1)
        Limits of the x (correlation) axis. ``None`` lets plotnine choose them.
        Points outside the limits are dropped by the scale.
    alpha : float, default 1
        Opacity of the points.

    Returns
    -------
    plotnine.ggplot
        The assembled plot object; nothing is drawn until the caller renders it.

    Raises
    ------
    ValueError
        If ``data`` is not a ``pd.DataFrame``.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("plot_correlation_funnel(): Object is not of class `pd.DataFrame`.")

    # Local import: the module-level plotnine import does not bring in the
    # scale needed to honour `limits`.
    from plotnine import scale_x_continuous

    if interactive:
        # Work on a copy so the caller's DataFrame does not silently gain a
        # `label_text` column.
        data = data.copy()
        # Vectorised label construction; unlike a row-wise apply this also
        # behaves for an empty frame.
        data['label_text'] = (
            data['feature'].astype(str)
            + "\nCorrelation: "
            + data['correlation'].map("{:.3f}".format)
        )
        mapping = aes(x='correlation', y='feature', text='label_text')
    else:
        mapping = aes(x='correlation', y='feature', label='feature')

    # Layers shared by both modes.
    p = (
        ggplot(data)
        + mapping
        + geom_vline(xintercept=0, linetype='dashed', color='red')
        + geom_point(color='#2c3e50', alpha=alpha)
    )

    if not interactive:
        # The static plot labels every point with its feature name.
        p = p + geom_text(size=12, color='#2c3e50')

    # theme() must come after theme_minimal() so the axis-text override is kept.
    return (
        p
        + scale_x_continuous(limits=limits)
        + labs(title='Correlation Funnel')
        + theme_minimal()
        + theme(axis_text_x=element_text(size=12))
    )
# Example usage
#data = pd.read_csv('your_data.csv') # Replace 'your_data.csv' with your dataset file
#interactive_plot = plot_correlation_funnel(data, interactive=True)
#print(interactive_plot)
# For a non-interactive plot
# plot_correlation_funnel(data, interactive=False).draw()
Plotly
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
def plot_correlation_funnel(data, interactive=False, limits=(-1, 1), alpha=1):
    """Build a correlation-funnel scatter plot with plotly or matplotlib.

    Each row of ``data`` becomes one point positioned at its ``correlation``
    value (x) and ``feature`` name (y), with a dashed red reference line at
    zero correlation.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain a ``feature`` column and a numeric ``correlation`` column.
        The frame is never modified.
    interactive : bool, default False
        True builds a plotly figure; False builds a matplotlib figure.
    limits : tuple of (float, float) or None, default (-1, 1)
        Limits of the x (correlation) axis.
    alpha : float, default 1
        Opacity of the points.

    Returns
    -------
    plotly.graph_objects.Figure or matplotlib.figure.Figure
        The figure object. It is not displayed here; call ``fig.show()``
        (plotly) or ``plt.show()`` (matplotlib) to render it.

    Raises
    ------
    ValueError
        If ``data`` is not a ``pd.DataFrame``.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("plot_correlation_funnel(): Object is not of class `pd.DataFrame`.")

    if interactive:
        # Work on a copy so the caller's DataFrame does not silently gain a
        # `label_text` column.
        data = data.copy()
        # plotly renders line breaks in trace text only for "<br>", not "\n".
        data['label_text'] = (
            data['feature'].astype(str)
            + "<br>Correlation: "
            + data['correlation'].map("{:.3f}".format)
        )
        fig = px.scatter(
            data,
            x='correlation',
            y='feature',
            text='label_text',
            range_x=list(limits) if limits is not None else None,
            title='Correlation Funnel',
        )
        # No selector: passing `text` makes px.scatter emit "markers+text"
        # traces, so selecting on mode="markers" would style nothing.
        fig.update_traces(marker=dict(color='#2c3e50', opacity=alpha))
        fig.update_layout(
            # Vertical reference line at zero spanning the full plot height.
            shapes=[
                dict(
                    type='line',
                    x0=0,
                    x1=0,
                    y0=0,
                    y1=1,
                    yref='paper',
                    line=dict(color='red', dash='dash'),
                )
            ],
            showlegend=False,
        )
        fig.update_xaxes(title_text="Correlation")
        fig.update_yaxes(title_text="Feature")
        return fig

    fig, ax = plt.subplots()
    ax.scatter(data['correlation'], data['feature'], c='#2c3e50', alpha=alpha)
    # Label each point with its feature name; Axes.text requires the string
    # as its third positional argument.
    for corr, feature in zip(data['correlation'], data['feature']):
        ax.text(corr, feature, str(feature), size=12, color='#2c3e50')
    ax.axvline(x=0, linestyle='--', color='red')
    ax.set_xlim(limits)
    ax.set_xlabel('Correlation')
    ax.set_ylabel('Feature')
    ax.set_title('Correlation Funnel')
    # Return the Figure (as the interactive branch does) rather than the
    # None produced by plt.show().
    return fig
# Example usage
#data = pd.read_csv('your_data.csv') # Replace 'your_data.csv' with your dataset file
#interactive_plot = plot_correlation_funnel(data, interactive=True)
# To display the interactive plot, you can use interactive_plot.show()
# For a non-interactive plot
# plot_correlation_funnel(data, interactive=False)
I have a basic example working.
# NON-TIMESERIES EXAMPLE ----
# Builds a small synthetic customer table (mixed numeric, categorical and
# boolean columns) and binarizes it as input for the correlation funnel.
import pandas as pd
import numpy as np
# NOTE(review): `tk` is not referenced below; the import is presumably needed
# for its side effect of adding `.binarize()` / `.glimpse()` to DataFrames —
# confirm against pytimetk.
import pytimetk as tk
# Seed NumPy's global RNG so every run draws the same fake data.
# The draws below consume the RNG in order, so reordering them changes the data.
np.random.seed(0)
# Number of rows (one per fake person) in the DataFrame
num_rows = 200
# Fake data, one entry per column. np.random.randint excludes the upper
# bound, e.g. Age is drawn from 18..64.
data = {
'Age': np.random.randint(18, 65, size=num_rows),
'Gender': np.random.choice(['Male', 'Female'], size=num_rows),
'Marital_Status': np.random.choice(['Single', 'Married', 'Divorced'], size=num_rows),
'City': np.random.choice(['New York', 'Los Angeles', 'Chicago', 'Houston', 'Miami'], size=num_rows),
'Years_Playing': np.random.randint(0, 30, size=num_rows),
'Average_Income': np.random.randint(20000, 100000, size=num_rows),
'Member_Status': np.random.choice(['Bronze', 'Silver', 'Gold', 'Platinum'], size=num_rows),
'Number_Children': np.random.randint(0, 5, size=num_rows),
'Own_House_Flag': np.random.choice([True, False], size=num_rows),
'Own_Car_Count': np.random.randint(0, 3, size=num_rows),
'PersonId': range(1, num_rows + 1), # Sequential identifier 1..num_rows, one per row
'Client': np.random.choice(['A', 'B'], size=num_rows) # Random client label, 'A' or 'B'
}
# Create a DataFrame from the column dictionary
df = pd.DataFrame(data)
# Binarize the data.
# NOTE(review): presumably bins numeric columns into `n_bins` groups, lumps
# categories rarer than `thresh_infreq` under `name_infreq`, and one-hot
# encodes the result — confirm against the pytimetk `binarize` documentation.
df_binarized = df.binarize(n_bins=4, thresh_infreq=0.01, name_infreq="-OTHER", one_hot=True)
# Quick look at the binarized frame.
df_binarized.glimpse()
``` {python}
# Correlate the binarized columns against the 'Member_Status__Platinum' column.
# NOTE(review): presumably returns one row per binarized feature with its
# correlation to the target — confirm against pytimetk `correlate`.
df_correlated = df_binarized.correlate(target='Member_Status__Platinum')
# Show the first 10 rows of the result.
df_correlated.head(10)
```
``` {python}
# Interactive funnel plot.
# NOTE(review): `height` is presumably the figure height in pixels — confirm.
df_correlated.plot_correlation_funnel(
interactive=True,
height=600
)
```
``` {python}
# Static funnel plot.
# NOTE(review): `height` is presumably the figure height in pixels — confirm.
df_correlated.plot_correlation_funnel(
interactive=False,
height = 900
)
```

## Plotly
![image](https://github.com/business-science/pytimetk/assets/13734662/756f4b10-962f-4c79-862e-ccb4d56d62af)
## Plotnine
For some reason, the arrows are showing up very thick in the `plotnine` `adjust_text` integration.
![image](https://github.com/business-science/pytimetk/assets/13734662/d33c8a6a-3155-4333-ba48-9588c8edfa52)
Starter code from Jared and Alex