pymc-labs / pymc-marketing

Bayesian marketing toolbox in PyMC. Media Mix (MMM), customer lifetime value (CLV), buy-till-you-die (BTYD) models and more.
https://www.pymc-marketing.io/
Apache License 2.0
587 stars 138 forks

MMM mixins as single sklearn transformer #407

Open wd60622 opened 8 months ago

wd60622 commented 8 months ago

The sklearn transformers that require state (MaxAbsScaler and StandardScaler) are currently part of the mixins, which makes it a bit difficult to support new data.

So the mixins will have to be changed in order to support the new data use case.

Instead of just incorporating this into the current mixins, I suggest relying on sklearn to handle the transformation pass directly, as I think it makes it much clearer what is happening to the data. The current implementation relies on mixins and on searching for methods based on tags, which is confusing and, imo, hard to maintain.

Sklearn already provides a fit and transform API, which is easy to adapt to the current flow. For instance, the X transform:

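from typing import Self  # Python 3.11+; use typing_extensions on older versions

import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import MaxAbsScaler, StandardScaler

channel_cols = ["Facebook", "Google"]
control_cols = ["event_1", "event_2"]
date_col = "date_week"
n_order = 2  # placeholder value for illustration

# sklearn convention: mixins go to the left of BaseEstimator
class FourierTransformer(TransformerMixin, BaseEstimator):
    # To be created: Fourier features derived from the date column
    def __init__(self, n_order: int):
        self.n_order = n_order
    def fit(self, X, y=None) -> Self: 
        ...
    def transform(self, X) -> pd.DataFrame: 
        ...
    # Needed so that set_output(transform="pandas") is available on this transformer
    def get_feature_names_out(self, input_features=None):
        ...

# This pipeline can be built based on the channel_cols, control_cols, and date_col provided
pipeline = ColumnTransformer([
    ("channel_scaling", MaxAbsScaler(), channel_cols),
    ("control_scaling", StandardScaler(), control_cols),
    # A scalar column name passes a 1-D Series to the transformer; use [date_col] for a DataFrame
    ("date_modes", FourierTransformer(n_order=n_order), date_col)
]).set_output(transform="pandas")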

# On fit during self.preprocess
pipeline.fit_transform(X_train)

# Should be used on new data. i.e. _data_setter for predict_posterior, etc
pipeline.transform(X_test)

After the fit, pipeline.named_transformers_ in combination with get_feature_names_out() can be used to retrieve the column names required for the coordinates.

Any thoughts on this? Using sklearn should give users feedback about missing columns similar to what is currently part of the mixin logic.

wd60622 commented 8 months ago

This would also touch on #402

ricardoV94 commented 8 months ago

See also discussion to get rid of Sklearn transformers altogether: https://github.com/pymc-labs/pymc-marketing/pull/386#issuecomment-1771891034