pymc-labs / pymc-marketing

Bayesian marketing toolbox in PyMC. Media Mix (MMM), customer lifetime value (CLV), buy-till-you-die (BTYD) models and more.
https://www.pymc-marketing.io/
Apache License 2.0
705 stars, 198 forks

#990 - Plot FourierBase along date rather than index #1068

Closed: Ishaanjolly closed this pull request 1 month ago

Ishaanjolly commented 1 month ago
Enhancement: Plot FourierBase along date rather than index

Description

I added the following `sample_curve` method to `FourierBase` in `pymc_marketing/mmm/fourier.py`:

```python
import datetime

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr


# Excerpt: defined as a method of FourierBase. YearlyFourier and MonthlyFourier
# are its subclasses in the same module, so the isinstance checks resolve at call time.
def sample_curve(
    self,
    parameters: az.InferenceData | xr.Dataset,
    use_dates: bool = False,
    start_date: datetime.datetime | None = None,
) -> xr.DataArray:
    """Create full period of the Fourier seasonality.

    Parameters
    ----------
    parameters : az.InferenceData | xr.Dataset
        Inference data or dataset containing the Fourier parameters.
        Can be posterior or prior.
    use_dates : bool, optional
        If True, use datetime coordinates for the x-axis. Defaults to False.
    start_date : datetime.datetime, optional
        Starting date for the Fourier curve. If not provided and use_dates is True,
        it is derived from the current year (yearly) or current month (monthly).
        Defaults to None.

    Returns
    -------
    xr.DataArray
        Full period of the Fourier seasonality.

    """
    # One point per day over the full period (e.g. 366 points for 365.25 days)
    full_period = np.arange(int(self.days_in_period) + 1)

    coords = {}
    if use_dates:
        if start_date is None:
            # Derive start_date from the type of Fourier seasonality
            today = datetime.datetime.now()
            if isinstance(self, YearlyFourier):
                start_date = datetime.datetime(year=today.year, month=1, day=1)
            elif isinstance(self, MonthlyFourier):
                start_date = datetime.datetime(
                    year=today.year, month=today.month, day=1
                )
            else:
                raise ValueError("Unknown Fourier type for deriving start_date")

        # Daily dates with the same length as full_period
        date_range = pd.date_range(
            start=start_date,
            periods=len(full_period),
            freq="D",
        )
        coords["date"] = date_range.to_numpy()
        # The model input is the day of year of each date
        dayofyear = date_range.dayofyear.to_numpy()
    else:
        coords["day"] = full_period
        dayofyear = full_period

    # Carry over the remaining coordinates of the parameters, skipping the
    # sampling dims and the Fourier-mode coordinate named after the prefix
    for key, values in parameters[self.variable_name].coords.items():
        if key in {"chain", "draw", self.prefix}:
            continue
        coords[key] = values.to_numpy()

    with pm.Model(coords=coords):
        name = f"{self.prefix}_trend"
        pm.Deterministic(
            name,
            self.apply(dayofyear=dayofyear),
            dims=tuple(coords.keys()),
        )

        return pm.sample_posterior_predictive(
            parameters,
            var_names=[name],
        ).posterior_predictive[name]
```
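To make the coordinate logic easy to check without PyMC, here is a minimal standalone sketch of only the x-axis part of `sample_curve`. The function `x_axis_coords` and its `kind` and `today` parameters are hypothetical: `kind` stands in for the `isinstance` checks, and `today` lets the caller replace `datetime.datetime.now()`. It returns the coordinate name, the coordinate values, and the `dayofyear` array that would be passed to `apply`.

```python
from __future__ import annotations

import datetime

import numpy as np
import pandas as pd


def x_axis_coords(
    days_in_period: float,
    kind: str,
    use_dates: bool = False,
    start_date: datetime.datetime | None = None,
    today: datetime.datetime | None = None,
) -> tuple[str, np.ndarray, np.ndarray]:
    """Hypothetical helper mirroring the coordinate logic of sample_curve."""
    n_days = int(days_in_period) + 1
    if not use_dates:
        # Index-based x-axis: 0, 1, ..., n_days - 1 doubles as the model input
        full_period = np.arange(n_days)
        return "day", full_period, full_period

    if start_date is None:
        # Injecting `today` keeps the derivation testable; the PR calls now()
        today = today or datetime.datetime.now()
        if kind == "yearly":
            start_date = datetime.datetime(year=today.year, month=1, day=1)
        elif kind == "monthly":
            start_date = datetime.datetime(year=today.year, month=today.month, day=1)
        else:
            raise ValueError(f"Unknown Fourier type for deriving start_date: {kind!r}")

    # Date-based x-axis: the model input is the day of year of each date
    date_range = pd.date_range(start=start_date, periods=n_days, freq="D")
    return "date", date_range.to_numpy(), date_range.dayofyear.to_numpy()


name, dates, doy = x_axis_coords(
    365.25, "yearly", use_dates=True, start_date=datetime.datetime(2023, 1, 1)
)
```

This makes one subtlety visible: for a yearly period starting on 1 January of a non-leap year, the date-based input runs 1, 2, ..., 365, 1 (the last date is 1 January of the next year), while the index-based input runs 0, 1, ..., 365, so the two x-axes do not feed `apply` exactly the same values.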

I also added the following check inside each `plot_*` function, because I was not able to define new attributes on the class:

```python
# Pick the x-axis coordinate produced by sample_curve
if "date" in curve.coords:
    x_coord_name = "date"
elif "day" in curve.coords:
    x_coord_name = "day"
else:
    raise ValueError("Curve must have either 'day' or 'date' as a coordinate")
```
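Since the same check is repeated in every `plot_*` function, one option is to factor it into a single function that needs no new attribute on the class. A minimal sketch; `get_x_coord_name` is a hypothetical name, and it accepts anything that supports `in`, such as an xarray `DataArray.coords` mapping or a plain dict.

```python
from collections.abc import Container


def get_x_coord_name(coords: Container[str]) -> str:
    """Hypothetical helper: pick the x-axis coordinate of a sampled curve.

    "date" is checked before "day", matching the order of the check above.
    """
    for name in ("date", "day"):
        if name in coords:
            return name
    raise ValueError("Curve must have either 'day' or 'date' as a coordinate")
```

Inside a plot function this would read `x_coord_name = get_x_coord_name(curve.coords)`.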

Related Issue: #990

📚 Documentation preview 📚: https://pymc-marketing--1068.org.readthedocs.build/en/1068/

codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 55.26316% with 17 lines in your changes missing coverage. Please review.

Project coverage is 95.23%. Comparing base (d05c2d8) to head (c1be3a5). Report is 2 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pymc_marketing/mmm/fourier.py | 55.26% | 17 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #1068      +/-   ##
==========================================
- Coverage   95.85%   95.23%   -0.62%
==========================================
  Files          39       39
  Lines        3934     3969      +35
==========================================
+ Hits         3771     3780       +9
- Misses        163      189      +26
```

| Flag | Coverage Δ | |
|---|---|---|
| | `95.23% <55.26%> (-0.62%)` | :arrow_down: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-labs#carryforward-flags-in-the-pull-request-comment) to find out more.
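The reported percentages can be reproduced from the counts in the diff table. A small sketch: the 38 changed lines (21 of them hit) are inferred from "55.26316% with 17 lines missing" rather than stated in the report, and `percent_truncated` is a hypothetical helper. The displayed 95.85% for main is consistent with truncating to two decimals, since rounding 3771/3934 would give 95.86%.

```python
def percent_truncated(hits: int, lines: int) -> float:
    """Coverage in percent, truncated (not rounded) to two decimals."""
    # Integer arithmetic in basis points, floored, avoids float artifacts
    return (hits * 10_000 // lines) / 100


base = percent_truncated(3771, 3934)   # 95.85 (main, d05c2d8)
head = percent_truncated(3780, 3969)   # 95.23 (head, c1be3a5)
patch = percent_truncated(21, 38)      # 55.26 (21 of 38 changed lines hit)
delta = round(head - base, 2)          # -0.62

# Rounding instead of truncating would report 95.86 for main
rounded_base = round(3771 / 3934 * 100, 2)
```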
