pymc-labs / pymc-marketing

Bayesian marketing toolbox in PyMC. Media Mix (MMM), customer lifetime value (CLV), buy-till-you-die (BTYD) models and more.
https://www.pymc-marketing.io/
Apache License 2.0

Hyperparameter tuning #753

Open shuvayan opened 3 weeks ago

shuvayan commented 3 weeks ago

Hello Bayesian Wizards, I recently worked on a project where the client was not sure about the adstock decay rates. They said it should be somewhere between 30% and 50% for each of the 5 channels. As you can imagine, selecting a good set of priors to start with was non-trivial. I used Optuna to find the best set of parameters for my data, but I would like to understand whether there is a built-in way to achieve this within PyMC.

wd60622 commented 3 weeks ago

Hi @shuvayan,

One workflow that might be helpful is to look at the adstock curves directly. The code below, which uses the new adstock classes, plots how a given amount is adstocked through time. I find this much more intuitive than thinking about the parameters directly.

For instance, here are the default priors:

import matplotlib.pyplot as plt

from pymc_marketing.mmm import GeometricAdstock

transformation = GeometricAdstock(l_max=15)

prior = transformation.sample_prior()
curve = transformation.sample_curve(prior, amount=10)

_, axes = transformation.plot_curve(curve)
ax = axes[0]
ax.set(
    title="Adstock of spend of 10 at time=0 with l_max=15",
    ylabel="Adstocked value",
)
plt.show()

[Figure: adstock-of-10]
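For intuition about what each sampled curve represents, here is a plain-Python sketch of a geometric adstock impulse response for a fixed decay parameter. It is not the library's implementation: the helper name `geometric_adstock_impulse` is made up for illustration, and it assumes no normalization of the weights.

```python
def geometric_adstock_impulse(amount, alpha, l_max):
    """Carry-over of a single spend at time 0 over l_max periods.

    Assumption: unnormalized geometric weights, i.e. alpha**lag at each lag.
    """
    return [amount * alpha**lag for lag in range(l_max)]


# Compare the two ends of the 30% to 50% range mentioned in the question.
for alpha in (0.3, 0.5):
    curve = geometric_adstock_impulse(amount=10, alpha=alpha, l_max=15)
    print(f"alpha={alpha}:", [round(value, 3) for value in curve[:5]])
```

Each prior draw of alpha gives one such curve, so the plot is effectively a bundle of these curves weighted by the prior.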

Changing the priors then happens at initialization:

priors = {
    # Alternative informed prior
    "alpha": {"dist": "Beta", "kwargs": {"alpha": 4, "beta": 1}},
}

transformation = GeometricAdstock(l_max=15, priors=priors)

The curves then look like the figure below. This might be too flat, because the spend of 10 is spread across the whole time window instead of decaying quickly like before.

[Figure: adstock-of-10-new-prior]

NOTE: Being completely uncertain about adstock while using GeometricAdstock corresponds to a Beta(1, 1) prior, which is uniform on [0, 1].
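At the other extreme, a stated range such as "between 30% and 50%" can be turned into an informative Beta prior by moment matching. The sketch below uses only the standard library. It assumes that the client's decay rate maps directly onto the `alpha` parameter and that the range is roughly the mean plus or minus two standard deviations; both are assumptions to confirm with the client. The helper `beta_from_mean_sd` is hypothetical. PyMC's `pm.find_constrained_prior` helper is aimed at this kind of problem and may be worth checking as well.

```python
import random


def beta_from_mean_sd(mean, sd):
    """Moment-match Beta(alpha, beta) to a target mean and standard deviation."""
    concentration = mean * (1 - mean) / sd**2 - 1
    if concentration <= 0:
        raise ValueError("sd is too large for a Beta distribution with this mean")
    return mean * concentration, (1 - mean) * concentration


# Assumption: the range [0.3, 0.5] is treated as mean 0.4 with sd 0.05.
a, b = beta_from_mean_sd(mean=0.4, sd=0.05)

# Monte Carlo check of how much prior mass lands inside the stated range.
rng = random.Random(0)
draws = [rng.betavariate(a, b) for _ in range(20_000)]
mass_in_range = sum(0.3 <= d <= 0.5 for d in draws) / len(draws)

print(f"Beta(alpha={a:.1f}, beta={b:.1f}); mass in [0.3, 0.5] = {mass_in_range:.2f}")
```

The resulting values could then be passed as the `alpha` and `beta` kwargs in the same `priors` dictionary format shown above, and the prior curves re-plotted to check that they look plausible.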

If the shapes supported by GeometricAdstock are not enough, I'd check out the other adstock classes.

As for a programmatic way, it might depend on the exact constraints or assumptions. It could likely be optimized using the pytensor graph that can be created from the adstock transformations. What formulation did you use?

Let me know if this seems helpful!

shuvayan commented 2 weeks ago

Hello @wd60622, this is really informative and useful. I will try to understand it in detail and get it implemented. BTW, can you please provide a little more detail about the pytensor graph from the adstock transformations? Below is the configuration derived from the Optuna trials:

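import numpy as np
import optuna

from pymc_marketing.mmm import DelayedSaturatedMMM

# Assumes spend_cols (list of channel column names), X, and y are defined earlier.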

# Define the objective function to minimize divergences
def objective_minimize_divergences(trial):
    # Suggest values for all parameters
    prior_sigma = trial.suggest_float('prior_sigma', 0.1, 10.0)
    mu_values = [trial.suggest_float(f'mu_{ch}', 0.1, 10.0) for ch in spend_cols]
    alpha = trial.suggest_float('alpha', 0.1, 5.0)
    lam_alpha = trial.suggest_float('lam_alpha', 0.1, 10.0)
    lam_beta = trial.suggest_float('lam_beta', 0.1, 10.0)
    likelihood_sigma = trial.suggest_float('likelihood_sigma', 0.1, 10.0)
    gamma_control_mu = trial.suggest_float('gamma_control_mu', -5.0, 5.0)
    gamma_control_sigma = trial.suggest_float('gamma_control_sigma', 0.1, 10.0)
    gamma_fourier_mu = trial.suggest_float('gamma_fourier_mu', -5.0, 5.0)
    gamma_fourier_b = trial.suggest_float('gamma_fourier_b', 0.1, 10.0)
    intercept_tvp_m = trial.suggest_int('intercept_tvp_m', 10, 50)
    intercept_tvp_ls_mu = trial.suggest_float('intercept_tvp_ls_mu', 50, 200)
    intercept_tvp_ls_sigma = trial.suggest_float('intercept_tvp_ls_sigma', 0.1, 5.0)
    intercept_tvp_eta_lam = trial.suggest_float('intercept_tvp_eta_lam', 0.1, 10.0)

    # Define the model config with the suggested parameters
    my_model_config = {
        'intercept': {'dist': 'HalfNormal', 'kwargs': {'sigma': 0.05}},
        'beta_channel': {
            'dist': 'LogNormal',
            "kwargs": {"mu": np.array(mu_values), "sigma": prior_sigma}
        },
        'alpha': {'dist': 'Beta', 'kwargs': {'alpha': alpha, 'beta': 3}},
        'lam': {'dist': 'Gamma', 'kwargs': {'alpha': lam_alpha, 'beta': lam_beta}},
        'likelihood': {'dist': 'Normal', 'kwargs': {'sigma': {'dist': 'HalfNormal', 'kwargs': {'sigma': likelihood_sigma}}}},
        'gamma_control': {'dist': 'Normal', 'kwargs': {'mu': gamma_control_mu, 'sigma': gamma_control_sigma}},
        'gamma_fourier': {'dist': 'Laplace', 'kwargs': {'mu': gamma_fourier_mu, 'b': gamma_fourier_b}},
        'intercept_tvp_kwargs': {
            'm': intercept_tvp_m,
            'L': None,
            'eta_lam': intercept_tvp_eta_lam,
            'ls_mu': intercept_tvp_ls_mu,
            'ls_sigma': intercept_tvp_ls_sigma,
            'cov_func': None
        }
    }

    # Define the model
    ollies_mmm = DelayedSaturatedMMM(
        model_config=my_model_config,
        target_column='R',
        date_column='Date',
        validate_data=True,
        channel_columns=spend_cols,
        adstock_max_lag=10,
        yearly_seasonality=2,
        time_varying_intercept=False
    )

    try:
        # Fit the model on the dataset
        idata = ollies_mmm.fit(X, y, target_accept=0.95, chains=4, random_seed=43)
        # Count divergent transitions across all chains and draws
        return int(idata.sample_stats["diverging"].values.sum())
    except Exception:
        # Treat a failed fit as the worst possible trial
        return float('inf')

# Create and run the Optuna study for minimizing divergences
study_divergences = optuna.create_study(direction='minimize')
study_divergences.optimize(objective_minimize_divergences, n_trials=5)

Here, the Optuna study is designed to minimize the number of divergences across trials.

wd60622 commented 1 day ago

I would think there is a more direct way of optimizing the adstock parameter than using this loop.

Also, I'm not sure that the number of divergences in the model tells you whether it meets the client's requirements.

I would toy around with the different AdstockTransformation subclasses that we provide and see whether the general shapes they produce match the client's expectations of ad decay for each channel.

Some of the curves have distinct shapes that might be more intuitive to the client than the parameters themselves. I might also suggest computing summary stats from the adstock curves, like the half-life of the adstock curve or the cumulative adstocked value over time. One of those might resonate better.
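As a rough illustration of such stats for the geometric case, the sketch below computes the half-life and the cumulative share of the effect for a fixed decay parameter. The function names are made up for this example, the weights are assumed to be unnormalized powers of alpha, and `l_max=10` simply mirrors the `adstock_max_lag=10` used earlier in the thread.

```python
import math


def half_life(alpha):
    """Periods until a geometric carry-over falls to half of its initial value."""
    return math.log(0.5) / math.log(alpha)


def cumulative_share(alpha, l_max):
    """Fraction of the total adstocked effect delivered up to and including each lag."""
    weights = [alpha**lag for lag in range(l_max)]
    total = sum(weights)
    running, shares = 0.0, []
    for weight in weights:
        running += weight
        shares.append(running / total)
    return shares


for alpha in (0.3, 0.4, 0.5):
    shares = cumulative_share(alpha, l_max=10)
    print(
        f"alpha={alpha}: half-life={half_life(alpha):.2f} periods, "
        f"share delivered in the first two periods={shares[1]:.0%}"
    )
```

Statements like "half of the effect is gone after about one period" may be easier to validate with a client than a value of alpha.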