pymc-labs / pymc-marketing

Bayesian marketing toolbox in PyMC. Media Mix (MMM), customer lifetime value (CLV), buy-till-you-die (BTYD) models and more.
https://www.pymc-marketing.io/
Apache License 2.0
647 stars, 173 forks

Is there a way to get the plots for adstock and saturation functions from the inference data object #699

Open shuvayan opened 3 months ago

shuvayan commented 3 months ago

Hello Bayesians,

I came across the image below on the PyMC Labs blog (https://www.pymc-labs.com/blog-posts/modelling-changes-marketing-effectiveness-over-time/):

Screenshot 2024-05-24 at 8 46 18 AM

Is there a way to get these plots for each media channel, along with separate plots for adstock? They would be really helpful for explaining the model to clients.

Please do let me know.

juanitorduz commented 3 months ago

Hi @shuvayan! We do not have these exact plots. Nevertheless, we do have many alternatives (see https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html). For example, we have the local contribution plots (to understand weekly spending):

image

You can also pass a fit, as in https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_budget_allocation_example.html:

image

We also have global ones to understand the complete (historical) effect:

image

I hope this helps :)

shuvayan commented 3 months ago

Thank you for the reply, @juanitorduz. Can I see the saturation point on these plots? I would like to be able to tell the client: "Your ad spend effect decays by x% after n weeks and reaches the saturation point after m weeks." Is this possible from the plots and the inference data returned by the model?

juanitorduz commented 3 months ago

For the local plots, I think it is possible since we do compute these inflection points; see https://github.com/pymc-labs/pymc-marketing/blob/main/pymc_marketing/mmm/base.py#L964. You can use the highlighted function to get the points yourself. I will check why they are not appearing in the plot.
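For the "decays by x% after n weeks" part of the question, here is a minimal sketch. It assumes a geometric adstock with decay rate `alpha`, so an impulse of spend keeps a fraction `alpha ** n` of its initial effect n weeks later (for n below the model's maximum lag). The helper names and the value of `alpha` are hypothetical and not part of pymc-marketing; in practice `alpha` would come from the posterior for a given channel.

```python
import math


def carryover_remaining(alpha: float, weeks: int) -> float:
    """Fraction of an impulse's effect still present `weeks` periods later."""
    return alpha**weeks


def weeks_until_decayed(alpha: float, decayed_fraction: float) -> int:
    """Smallest whole number of weeks after which at least
    `decayed_fraction` of the initial effect has decayed."""
    # Solve alpha**n <= 1 - decayed_fraction for n and round up.
    return math.ceil(math.log(1.0 - decayed_fraction) / math.log(alpha))


alpha = 0.5  # hypothetical value, e.g. a posterior mean for one channel

print(carryover_remaining(alpha, 2))    # 0.25, i.e. 75% decayed after 2 weeks
print(weeks_until_decayed(alpha, 0.9))  # 4, i.e. at least 90% decayed after 4 weeks
```

Note that, under these assumptions, saturation depends on the spend level rather than on elapsed time, so "reaches the saturation point after m weeks" is probably better phrased in terms of weekly spend.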

shuvayan commented 2 months ago

I am currently using the code below to get the response curves, implemented here using mock data:

```python
from functools import partial

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Define adstock and saturation functions
def adstock(x, alpha):
    """Geometric adstock: y[t] = x[t] + alpha * y[t - 1]."""
    y = np.zeros(len(x), dtype=float)
    y[0] = x[0]  # the first period has no carryover
    for i in range(1, len(x)):
        y[i] = x[i] + alpha * y[i - 1]
    return y


def saturation_function(x, lam, beta):
    return beta / (1 + np.exp(-lam * x))


# Extract posterior-mean parameters from the inference data object
def get_params(idata, channel):
    alpha = idata.posterior['alpha'].mean(dim=['chain', 'draw']).sel(channel=channel).values
    lam = idata.posterior['lam'].mean(dim=['chain', 'draw']).sel(channel=channel).values
    beta = idata.posterior['beta_channel'].mean(dim=['chain', 'draw']).sel(channel=channel).values
    return alpha, lam, beta


# Assume df is your dataframe containing the spend data
# Replace with your actual dataframe loading code
# df = pd.read_csv('your_dataframe.csv')

# Example dataframe structure (replace with your actual data)
data = {
    'Spend_META': [100, 150, 200, 250, 300, 350],
    'Spend_TikTok': [80, 120, 160, 200, 240, 280],
    'Spend_Search': [70, 110, 150, 190, 230, 270],
    'Spend_OLV': [60, 100, 140, 180, 220, 260],
    'Spend_CTV': [50, 90, 130, 170, 210, 250],
    'Spend_Display': [40, 80, 120, 160, 200, 240]
}
df = pd.DataFrame(data)

# Create spend dictionary from dataframe
spend = {col: df[col].values for col in df.columns}

# Define your idata loading method; idata must exist before the next step
# idata = az.from_netcdf('path_to_idata_file.nc')  # example loading, adjust to your actual loading method

# Extract parameters for each channel
# (the names must match the `channel` coordinate in idata)
channels = ['Spend_META', 'Spend_TikTok', 'Spend_Search', 'Spend_OLV', 'Spend_CTV', 'Spend_Display']
params = {channel: get_params(idata, channel) for channel in channels}

# Step size and x range for curves
step_size = 0.05
xx = np.arange(0, max(max(spend[channel]) for channel in channels) * 1.1, step_size)

# Define curve functions and apply adstock transformation for each channel
curves = {}
for channel in channels:
    alpha, lam, beta = params[channel]
    adstocked_xx = adstock(xx, alpha)
    curve_fn = partial(saturation_function, lam=lam, beta=beta)
    curves[channel] = curve_fn(adstocked_xx)


# Plot functions
def plot_actual_curves(ax: plt.Axes, linestyle: str | None = None) -> plt.Axes:
    for i, (channel, curve) in enumerate(curves.items()):
        ax.plot(xx, curve, label=channel, color=f"C{i}", linestyle=linestyle)
    return ax


def plot_reference(ax: plt.Axes) -> plt.Axes:
    ax.plot(xx, xx, label="y=x", color="black", linestyle="--")
    return ax


ax = plt.gca()
plot_actual_curves(ax)
plot_reference(ax)
ax.set(
    xlabel="channel spend",
    ylabel="channel contribution",
    title="Actual Saturation Curves (Unobserved)",
)
ax.legend()
plt.show()
```

Is this the correct way? I am looking to generate something like the graph below:

Screenshot 2024-06-23 at 9 31 12 AM

wd60622 commented 2 months ago

One thing you might not want to do is adstock the linspace / arange, because the adstock function takes previous spends into account. To the function, the grid looks like spend that keeps increasing over time rather than constant spend at a given value. You might want to apply only the saturation transformation instead.
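To illustrate the point, here is a small sketch with plain NumPy. The parameter values are hypothetical stand-ins for posterior means. The logistic form used here, `(1 - exp(-lam * x)) / (1 + exp(-lam * x))`, is my assumption about the saturation used by the model, so check it against the transformer in your installed version. The fitted parameters may also refer to scaled inputs, in which case the grid should be scaled the same way.

```python
import numpy as np


def geometric_adstock(x, alpha):
    """Geometric adstock: y[t] = x[t] + alpha * y[t - 1]."""
    y = np.zeros(len(x), dtype=float)
    carry = 0.0
    for i, value in enumerate(x):
        carry = value + alpha * carry
        y[i] = carry
    return y


def logistic_saturation(x, lam):
    """Assumed logistic saturation; equals 0 at x = 0 and tends to 1."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


# Hypothetical parameter values for one channel
alpha, lam, beta = 0.5, 0.01, 2.0

spend_grid = np.linspace(0, 400, 5)  # 0, 100, 200, 300, 400

# Adstocking the grid treats it as a time series of growing spend,
# so the values drift away from the spend levels they should represent:
adstocked_grid = geometric_adstock(spend_grid, alpha)
print(adstocked_grid)  # values: 0, 100, 250, 425, 612.5

# Saturation only: contribution as a function of the spend level itself
curve = beta * logistic_saturation(spend_grid, lam)
print(curve)
```

Plotting `curve` against `spend_grid` then gives one saturation curve per channel, with the x-axis still meaning "spend in a single period".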