pymc-labs / pymc-marketing

Bayesian marketing toolbox in PyMC. Media Mix (MMM), customer lifetime value (CLV), buy-till-you-die (BTYD) models and more.
https://www.pymc-marketing.io/
Apache License 2.0
645 stars, 172 forks

Prior support for "consumed" dims #841

Open wd60622 opened 1 month ago

wd60622 commented 1 month ago

Some distributions will likely not work because of the check that the child's dims are a superset of each parent's dims. For instance,

from pymc_marketing.prior import Prior, UnsupportedShapeError

p = Prior("Dirichlet", a=[1, 1, 1], dims="probs")
try:
    # "probs" is consumed by the Categorical, but the check still requires it in the child's dims
    Prior("Categorical", p=p, dims="trial")
except UnsupportedShapeError as e:
    print(e)

This check could be relaxed using information from the rv_op, e.g. pm.Categorical.rv_op has ndim_supp=0 and ndims_params=(1,), or by parsing its NumPy-like signature.
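A minimal sketch of the signature-based relaxation, assuming the op exposes a NumPy gufunc-style signature string such as "(p)->()" for Categorical. The helpers `parse_signature` and `required_child_dims` are hypothetical names, not part of PyMC-Marketing. The idea: the rightmost dims of each parameter are its core dims; batch dims must always appear in the child's dims, but core dims are only required when the signature carries them into the output.

```python
import re


def parse_signature(signature: str) -> tuple[list[tuple[str, ...]], list[tuple[str, ...]]]:
    """Split a gufunc-style signature like "(p)->()" into input and output core dims."""
    inputs, outputs = signature.split("->")

    def groups(side: str) -> list[tuple[str, ...]]:
        # Each "(...)" group holds the comma-separated core dim letters of one argument
        return [
            tuple(letter.strip() for letter in group.split(",") if letter.strip())
            for group in re.findall(r"\(([^)]*)\)", side)
        ]

    return groups(inputs), groups(outputs)


def required_child_dims(
    signature: str, param_index: int, param_dims: tuple[str, ...]
) -> tuple[str, ...]:
    """Hypothetical helper: parent dims that must also appear in the child's dims.

    Batch dims are always required; core dims only if the signature keeps
    them in the output (otherwise they are consumed).
    """
    inputs, outputs = parse_signature(signature)
    core_letters = inputs[param_index]
    n_core = len(core_letters)
    if len(param_dims) < n_core:
        raise ValueError(f"Parameter needs at least {n_core} named core dim(s)")

    # The rightmost n_core dims of the parameter are its core dims
    split = len(param_dims) - n_core
    batch_dims, core_dims = param_dims[:split], param_dims[split:]
    output_letters = {letter for output in outputs for letter in output}
    kept = tuple(
        name for letter, name in zip(core_letters, core_dims) if letter in output_letters
    )
    return batch_dims + kept


# Categorical consumes "prob"; Dirichlet keeps it; an elementwise op keeps everything
print(required_child_dims("(p)->()", 0, ("geo", "prob")))  # ('geo',)
print(required_child_dims("(a)->(a)", 0, ("prob",)))  # ('prob',)
print(required_child_dims("(),()->()", 0, ("prob",)))  # ('prob',)
```

The relaxed check would then compare `set(required_child_dims(...))` against the child's dims instead of the full set of parent dims, so `Prior("Categorical", p=p, dims="trial")` with p on dims "probs" would pass.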

Previous (incorrect) examples

EDIT: The following are wrong because Categorical has the signature (p)->().

```python
from pymc_marketing.prior import Prior

# Does work
p = Prior("Dirichlet", a=[1, 2, 3], dims="prob")
y = Prior("Categorical", p=p, dims=("trial", "prob"))

coords = {
    "trial": [0, 1, 2, 3, 4],
    "prob": ["A", "B", "C"],
}
samples = y.sample_prior(coords=coords)
```

The transpose doesn't work either:

```python
# Doesn't work
z = Prior("Categorical", p=p, dims=("prob", "trial"))
try:
    z.sample_prior(coords=coords)
except ValueError as e:
    print(e)

# But would with other distributions
mu = Prior("Normal", dims="prob")
x = Prior("Normal", mu=mu, dims=("prob", "trial"))
samples = x.sample_prior(coords=coords)
```

Ref: https://github.com/pymc-devs/pymc/discussions/7416#discussioncomment-10084272

ricardoV94 commented 1 month ago

y = Prior("Categorical", p=p, dims=("trial", "prob")) is not really meaningful dim-wise. The prob dimension of p gets consumed when it goes into a Categorical. The signature is (p)->(), not something like (p)->(p) where it would be preserved in the output. So it doesn't make sense to have a prob dim in the output.

This should make some sense. In the core case (without batch dims), you may have a vector of 100 probs that is consumed to generate a single scalar (an integer between 0 and 99). You could ask to draw a batch of length 100 (dim probs) of these numbers, but that would generally be strange.
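To make the core case concrete, here is a small NumPy sketch; NumPy's `Generator.choice` stands in for PyMC's Categorical purely for illustration. A vector of 100 probabilities is consumed to produce one integer, and asking for 100 draws adds a new batch axis rather than preserving the probs dim.

```python
import numpy as np

rng = np.random.default_rng(42)

# Core case: p has shape (100,), i.e. a single "probs" dim
p = rng.dirichlet(np.ones(100))

# One categorical draw consumes the whole vector and returns a scalar in [0, 99]
single = rng.choice(100, p=p)

# A batch of 100 draws: the new length-100 axis is a batch dim (e.g. "trial"),
# not the consumed "probs" dim, even though the lengths happen to match
batch = rng.choice(100, size=100, p=p)

print(np.ndim(single), batch.shape)  # 0 (100,)
```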

wd60622 commented 1 month ago

Yes, totally. Makes sense. Let me change the scenarios in the initial description then. The "consumed" case fails because of the internal checks on parent dim names; supporting it would require a logic change based on rv_op.signature.

EDIT: The first message has been modified

wd60622 commented 1 month ago

Some scenarios that should work:

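import numpy as np
from pymc_marketing.prior import Prior

# Broadcasting support: "prob" is consumed, "geo" is a batch dim of p
p = Prior("Dirichlet", a=np.ones((2, 5)), dims=("geo", "prob"))
y = Prior("Categorical", p=p, dims=("geo", "trial"))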

The information from rv_op could potentially warn if a dim name should be provided. For instance,

# Can provide a user warning
p = Prior("Dirichlet", a=[1, 1, 1])
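A minimal sketch of such a warning, assuming only the op's `ndim_supp` is needed (DirichletRV reports ndim_supp=1, so its draws have one core dim that deserves a name). `warn_if_core_dims_unnamed` is a hypothetical helper that takes the integer directly so it runs without PyMC; it is not an existing PyMC-Marketing function.

```python
import warnings


def warn_if_core_dims_unnamed(dist_name: str, dims: tuple[str, ...], ndim_supp: int) -> bool:
    """Hypothetical helper: warn when fewer dims are named than the draw's core dims.

    Returns True if a warning was issued.
    """
    if len(dims) < ndim_supp:
        warnings.warn(
            f"{dist_name} draws have {ndim_supp} core dim(s) but only {len(dims)} "
            "dim name(s) were given; consider naming them, e.g. dims='prob'.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False


# Prior("Dirichlet", a=[1, 1, 1]) has no dims, but the Dirichlet draw has a core dim
warn_if_core_dims_unnamed("Dirichlet", (), ndim_supp=1)  # emits a UserWarning
warn_if_core_dims_unnamed("Dirichlet", ("prob",), ndim_supp=1)  # no warning
```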
ricardoV94 commented 1 month ago


Is that example with the dims "geo", "trial" in that order on purpose? You would need to transpose them for it to work, since it must broadcast to the left.

To be clear: when defining a Categorical distribution whose p has dims (geo, probs), the output could have dims (geo,), that is, a single vector; or (trial, geo), which would be a matrix of observations; or anything with batch dims to the left of geo, such as (year, ..., trial, geo).
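The broadcasting rule can be checked with a small NumPy sketch of a batched categorical draw via inverse-CDF sampling. This is a stand-in chosen for illustration, not PyMC's implementation, and `categorical_draws` is a hypothetical name. With p on dims (geo, probs), the consumed probs axis is reduced away, and the requested batch shape must line up with geo on the right, so (trial, geo) works while (geo, trial) does not broadcast.

```python
import numpy as np


def categorical_draws(rng: np.random.Generator, p: np.ndarray, size: tuple[int, ...]) -> np.ndarray:
    """Draw categorical samples; the last axis of p is the consumed core dim.

    size plays the role of the output dims and must broadcast with p.shape[:-1]
    on the right, i.e. extra batch dims go to the left.
    """
    cdf = np.cumsum(p, axis=-1)
    cdf[..., -1] = 1.0  # guard against round-off so draws stay in [0, k - 1]
    u = rng.random(size)[..., None]  # trailing axis lines up with the core dim
    # Count how many CDF steps each uniform exceeds; this removes the core axis
    return (u > cdf).sum(axis=-1)


rng = np.random.default_rng(0)
p = rng.dirichlet(np.ones(5), size=2)  # dims (geo, probs) -> shape (2, 5)

draws = categorical_draws(rng, p, size=(3, 2))  # dims (trial, geo)
print(draws.shape)  # (3, 2)

try:
    categorical_draws(rng, p, size=(2, 3))  # dims (geo, trial): geo is not rightmost
except ValueError as err:
    print("does not broadcast:", err)
```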

wd60622 commented 1 month ago


Yup, the order is on purpose, to support broadcasting even in the scenario of "consumed" dims. The rightmost dim(s?) are always the consumed ones, right? At least when ndim_supp=0 and at least one entry of ndims_params is greater than zero.

Here is code to list the distributions that might need some additional investigation:

from pytensor.tensor.random.basic import RandomVariable

import pymc as pm
from pymc.distributions.distribution import DistributionMeta

# Collect every distribution class exposed at the top level of pymc
lookup = {}
for name in dir(pm):
    obj = getattr(pm, name)
    if isinstance(obj, DistributionMeta):
        lookup[name] = obj


def needs_investigation(rv_op) -> bool:
    """True unless the op maps scalar parameters to a scalar draw."""
    return any(ndims != 0 for ndims in rv_op.ndims_params) or rv_op.ndim_supp != 0


rv_op_lookup = {}
for name, value in lookup.items():
    rv_op = getattr(value, "rv_op", None)

    # Another case to investigate: ops that are not plain RandomVariables
    if not isinstance(rv_op, RandomVariable):
        continue

    if not needs_investigation(rv_op):
        continue

    rv_op_lookup[name] = rv_op

Results in:

{'CAR': CARRV(name=car,ndim_supp=1,ndims_params=(1, 2, 0, 0),dtype=floatX,inplace=False),
 'Categorical': CategoricalRV(name=categorical,ndim_supp=0,ndims_params=(1,),dtype=int64,inplace=False),
 'Dirichlet': DirichletRV(name=dirichlet,ndim_supp=1,ndims_params=(1,),dtype=floatX,inplace=False),
 'ICAR': ICARRV(name=icar,ndim_supp=1,ndims_params=(2, 1, 1, 0, 0, 0),dtype=floatX,inplace=False),
 'Interpolated': InterpolatedRV(name=interpolated,ndim_supp=0,ndims_params=(1, 1, 1),dtype=floatX,inplace=False),
 'KroneckerNormal': KroneckerNormalRV(name=kroneckernormal,ndim_supp=1,ndims_params=(1, 0, 2),dtype=floatX,inplace=False),
 'MatrixNormal': MatrixNormalRV(name=matrixnormal,ndim_supp=2,ndims_params=(2, 2, 2),dtype=floatX,inplace=False),
 'Multinomial': MultinomialRV(name=multinomial,ndim_supp=1,ndims_params=(0, 1),dtype=int64,inplace=False),
 'MvNormal': MvNormalRV(name=multivariate_normal,ndim_supp=1,ndims_params=(1, 2),dtype=floatX,inplace=False),
 'MvStudentT': MvStudentTRV(name=multivariate_studentt,ndim_supp=1,ndims_params=(0, 1, 2),dtype=floatX,inplace=False),
 'StickBreakingWeights': StickBreakingWeightsRV(name=stick_breaking_weights,ndim_supp=1,ndims_params=(0, 0),dtype=floatX,inplace=False),
 'Wishart': WishartRV(name=wishart,ndim_supp=2,ndims_params=(0, 2),dtype=floatX,inplace=False)}

The current implementation already supports many distributions (any that PyMC-Marketing supported previously).
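As a rough first pass over that list, a parameter can only have a consumed core dim by count when it has more core dims than the draw (ndims_params[i] > ndim_supp). The sketch below applies that count-based heuristic to the values printed above. It is an assumption for triage, not a rule from PyMC: it only compares counts, so it can miss consumed axes when the counts happen to match (for example, MatrixNormal's square row covariance).

```python
# (ndim_supp, ndims_params) copied from the rv_op listing above
RV_INFO = {
    "CAR": (1, (1, 2, 0, 0)),
    "Categorical": (0, (1,)),
    "Dirichlet": (1, (1,)),
    "ICAR": (1, (2, 1, 1, 0, 0, 0)),
    "Interpolated": (0, (1, 1, 1)),
    "KroneckerNormal": (1, (1, 0, 2)),
    "MatrixNormal": (2, (2, 2, 2)),
    "Multinomial": (1, (0, 1)),
    "MvNormal": (1, (1, 2)),
    "MvStudentT": (1, (0, 1, 2)),
    "StickBreakingWeights": (1, (0, 0)),
    "Wishart": (2, (0, 2)),
}


def params_with_consumed_dims(ndim_supp: int, ndims_params: tuple[int, ...]) -> list[int]:
    """Indices of parameters with more core dims than the draw (count-based heuristic)."""
    return [i for i, ndims in enumerate(ndims_params) if ndims > ndim_supp]


for name, (ndim_supp, ndims_params) in RV_INFO.items():
    flagged = params_with_consumed_dims(ndim_supp, ndims_params)
    if flagged:
        print(f"{name}: parameter indices {flagged} may consume dims")
```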

ricardoV94 commented 1 month ago

Yup, sounds like you're on it. Yeah the rightmost dims are the "core" ones.

One day, when we implement dims in PyTensor, we'll probably need extra kwargs for the user to tell us which dims should be used for the core case. For example, a Categorical may need to be written like pt.random.categorical(p=p, dims=("geo", "trial"), p_dims="probs")

With named dims, order loses any meaning, so we can't rely on it to disambiguate core dims from batch dims.

In this case we could probably introspect the dims of p and find the one that is neither "geo" nor "trial". But just as we don't require size to always be provided, we may not want to require dims either, in which case we would infer them from the dims of the parameters. This is all about a future PyTensor API, not about the work here in pymc-marketing; I'm just using this as an excuse to think about it.
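The introspection idea can be sketched as a set difference: whichever dims of the parameter do not appear in the requested output dims are taken as the consumed core dims. `infer_consumed_dims` is a hypothetical helper. It recovers which dims are core, but not their order when a parameter has several, which matters for non-symmetric matrix inputs.

```python
def infer_consumed_dims(param_dims: tuple[str, ...], output_dims: tuple[str, ...]) -> tuple[str, ...]:
    """Hypothetical helper: parameter dims missing from the output, in the parameter's order."""
    if len(set(param_dims)) != len(param_dims):
        raise ValueError(f"Duplicate dims are not allowed: {param_dims}")
    output = set(output_dims)
    return tuple(dim for dim in param_dims if dim not in output)


# p of a Categorical has dims ("geo", "probs"); the draws have dims ("geo", "trial")
print(infer_consumed_dims(("geo", "probs"), ("geo", "trial")))  # ('probs',)
```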

I don't know whether you want to go the route here of asking which dims are the core ones, or rely on the dims missing from the output to find out which one is "probs". I'm wondering about matrix core inputs, like cov. In some cases order may matter, so even if you can reason about which dims are core, you may still need to know which one goes first. cov is not a problem because it has to be symmetric, but that need not be the case for other matrix inputs.

The other tricky thing is multiple parametrizations. For example, a user may define MvNormal with the Cholesky factor, which is not an (m, m) matrix. Under the hood PyMC converts it to a covariance, so the Op signature is always correct. But here you're acting before PyMC does that conversion, when you're trying to dimshuffle the dims for the user.

Those edge cases are where we start wanting PyTensor to handle dims natively.

wd60622 commented 1 month ago

Some things to note:

As for the problem of consumed dims under the Prior class API: is there a preference for p_dims as a separate parameter, or for the consumed dims being part of the already existing dims parameter? I see advantages to both.

ricardoV94 commented 1 month ago

Duplicate dims are not allowed anywhere. Even if a matrix is square, its dims must have different names.

ricardoV94 commented 1 month ago

p = Prior("Dirichlet", a=[1, 2, 3])
data = Prior("Categorical", p=p, dims="trial")

Right, but that will fail when they are specified and/or not aligned according to the non-dim semantics of PyMC, right?

ricardoV94 commented 1 month ago

For your pragmatic question, I don't know. I suggest trying whatever feels better and seeing how it goes.