pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/
Other

Show dims in model textual representation #7154

Closed: ricardoV94 closed this issue 4 months ago

ricardoV94 commented 5 months ago

Description

Currently, the textual representation of a model carries no information about the dimensionality of its variables:

import pymc as pm
from pymc.printing import str_for_model

with pm.Model(coords={"trial": range(10)}) as m:
    x = pm.Normal("x")
    y = pm.Normal("y", dims=["trial"])

print(str_for_model(m))

Output:

x ~ Normal(0, 1)
y ~ Normal(0, 1)
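The information needed is already tracked by the model, so the printer would only have to look it up. A minimal sketch, assuming plain dicts that stand in for the model's coords and for a name-to-dims mapping in the spirit of Model.named_vars_to_dims (variables declared without dims are assumed to be absent from it). The helper dims_and_shape is hypothetical and not part of PyMC:

```python
# Plain dicts standing in for what the model above tracks (an assumption):
# the coords passed to pm.Model, and a name -> dims mapping in the spirit of
# Model.named_vars_to_dims, where variables declared without dims are absent.
coords = {"trial": range(10)}
named_vars_to_dims = {"y": ("trial",)}


def dims_and_shape(name, named_vars_to_dims, coords):
    """Hypothetical helper: return (dims, shape) for a variable.

    Scalars get empty tuples; every dim is assumed to have a coords entry.
    """
    dims = named_vars_to_dims.get(name, ())
    shape = tuple(len(coords[dim]) for dim in dims)
    return dims, shape


for name in ("x", "y"):
    dims, shape = dims_and_shape(name, named_vars_to_dims, coords)
    print(f"{name}: dims={dims}, shape={shape}")
```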

It would be nice if the output looked something like one of the following:

1.

x ~ Normal(0, 1)
y ~ Normal(0, 1, dims=[trial])

2.

x ~ Normal(0, 1)
y ~ Normal(0, 1) [trial]

3.

x ~ Normal(0, 1)
y[trial] ~ Normal(0, 1)
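The three candidate formats differ only in where the dims go. A minimal sketch that renders each of them from a (name, distribution string, dims) triple; render is a hypothetical helper, and it assumes the distribution string ends with a closing parenthesis so that style 1 can splice dims=[...] into the call:

```python
def render(name, dist, dims, style):
    """Hypothetical helper: one summary line in proposed style 1, 2 or 3."""
    if not dims:  # scalar variables print exactly as they do today
        return f"{name} ~ {dist}"
    dims_str = ", ".join(dims)
    if style == 1:  # dims spliced into the call; assumes dist ends with ")"
        args = dist[:-1]
        sep = "" if args.endswith("(") else ", "  # handle e.g. "Flat()"
        return f"{name} ~ {args}{sep}dims=[{dims_str}])"
    if style == 2:  # dims as a trailing annotation
        return f"{name} ~ {dist} [{dims_str}]"
    if style == 3:  # dims indexing the variable name
        return f"{name}[{dims_str}] ~ {dist}"
    raise ValueError(f"unknown style: {style!r}")


variables = [("x", "Normal(0, 1)", ()), ("y", "Normal(0, 1)", ("trial",))]
for style in (1, 2, 3):
    print(f"{style}.")
    for name, dist, dims in variables:
        print(render(name, dist, dims, style))
```

Styles 2 and 3 leave the distribution string untouched, whereas style 1 has to edit inside the call.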

Do you like any of these? Any better suggestions?

ricardoV94 commented 4 months ago

Here is a manually curated mock-up of one idea:

                intercept_sd ~ TruncatedNormal(0.2, 0.5, 0, inf)                                    :()
                   intercept ~ ZeroSumNormal(intercept_sd)                                          :(description,)
            log_price_effect ~ LogNormal(0, 1)                                                      :()
               volume_effect ~ Normal(0, 1)                                                         :()
         price_volume_effect ~ Normal(0, 1)                                                         :()
manufacturer_importance::raw ~ Beta(1, 10)                                                          :(manufacturer*,)
     segment_importance::raw ~ Beta(1, 10)                                                          :(manufacturer*, segment*)
           no_choice_utility ~ Normal(4, 0.75)                                                      :()
                 market_size ~ TruncatedGamma(...)                                                  :()
               total_sales_k ~ Exponential(1e+03)                                                   :()
               sales_units_k ~ InverseGamma(6, 2e+04)                                               :()

                base_utility ~ f(price_volume_effect, volume_effect, log_price_effect, intercept)   :(date, item)
                     utility ~ f(base_utility)                                                      :(date, item)
     manufacturer_importance ~ f(manufacturer_importance::raw)                                      :(manufacturer,)
          segment_importance ~ f(segment_importance::raw, manufacturer_importance::raw)             :(manufacturer, segment)
                market_share ~ f(manufacturer_importance::raw, segment_importance::raw, utility)    :(date, item)
            total_sales_mean ~ f(market_size, no_choice_utility, utility)                           :(date,)

                 total_sales ~ NegativeBinomial(total_sales_k, f(total_sales_mean, total_sales_k))  :(date,)
                 sales_units ~ DirichletMN(f(total_sales), f(market_share))                         :(date, item)

The three sections, separated by blank lines, are the free_RVs, the Deterministics, and the observed_RVs; the tuple after each ":" lists the variable's dims.
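The layout of this mock-up can be sketched with plain string formatting: names right-aligned on "~", every right-hand side padded to a shared width so the dims column lines up, and a blank line between sections. This is a minimal sketch; render_sections and dims_repr are hypothetical helpers, and the right-hand sides (including the f(...) summaries of Deterministics) are taken as ready-made strings rather than derived from a model:

```python
def dims_repr(dims):
    """Format dims like a tuple literal without quotes: (), (a,), (a, b)."""
    inner = ", ".join(dims)
    return f"({inner},)" if len(dims) == 1 else f"({inner})"


def render_sections(sections):
    """Render sections of (name, rhs, dims) rows as one aligned text block."""
    sections = [section for section in sections if section]  # skip empty groups
    rows = [row for section in sections for row in section]
    if not rows:
        return ""
    name_w = max(len(name) for name, _, _ in rows)
    rhs_w = max(len(rhs) for _, rhs, _ in rows)
    blocks = []
    for section in sections:
        lines = [
            f"{name:>{name_w}} ~ {rhs:<{rhs_w}}  :{dims_repr(dims)}"
            for name, rhs, dims in section
        ]
        blocks.append("\n".join(lines))
    # A blank line separates free RVs, Deterministics and observed RVs.
    return "\n\n".join(blocks)


sections = [
    [  # free RVs
        ("intercept_sd", "TruncatedNormal(0.2, 0.5, 0, inf)", ()),
        ("intercept", "ZeroSumNormal(intercept_sd)", ("description",)),
    ],
    [("utility", "f(base_utility)", ("date", "item"))],  # Deterministics
    [  # observed RVs
        (
            "total_sales",
            "NegativeBinomial(total_sales_k, f(total_sales_mean, total_sales_k))",
            ("date",),
        )
    ],
]
print(render_sections(sections))
```

Computing the f(...) argument lists from the model graph is the harder part and is not attempted here.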