google / lightweight_mmm

LightweightMMM 🦇 is a lightweight Bayesian Marketing Mix Modeling (MMM) library that allows users to easily train MMMs and obtain channel attribution information.
https://lightweight-mmm.readthedocs.io/en/latest/index.html
Apache License 2.0
829 stars 172 forks

Any way to save mmm.print_summary()? #274

Open cincysam6 opened 7 months ago

cincysam6 commented 7 months ago

Does anyone know if there is a simple way to save the mmm.print_summary() results to either an image or a CSV/text file? I would love to capture that part in some kind of file once my model has finished.

becksimpson commented 7 months ago

@cincysam6 A lot of the time I just eyeball it. Since it wraps a fairly encapsulated NumPyro function, I think the simplest option, short of taking a screenshot, is to capture stdout and then process it however you like. Here is a possible example that stores the printed summary in a pandas DataFrame, which can then be saved. You might have to adapt it to your NumPyro version; I don't know whether they have kept the table formatting consistent across releases.

from contextlib import redirect_stdout
import io

import pandas as pd

# Capture what mmm.print_summary() writes to stdout
f = io.StringIO()
with redirect_stdout(f):
    mmm.print_summary()
s = f.getvalue()

# Custom processing
# (Due to noise in the output I couldn't get pandas.read_csv to parse the string buffer directly.)
# Rows are separated by '\n' and columns by runs of whitespace.
lines = s.split('\n')
cols = lines[1].split()  # header row, e.g. mean, std, median, 5.0%, 95.0%, n_eff, r_hat
print(lines[-2])  # "Number of divergences: ..." line
rows = []
for ln in lines[2:]:
    # str.split() also discards the left padding, so every table row becomes
    # [param_name, *stats]; indexed names such as gamma_seasonality[0,1]
    # contain no spaces and stay a single token. Other lines are skipped.
    split_row = ln.split()
    if len(split_row) == len(cols) + 1:
        rows.append(split_row)

df = pd.DataFrame(rows, columns=['Param', *cols])
df.to_csv('mmm_summary.csv', index=False)  # example file name
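If a plain text file is enough (the question mentions CSV/text), a minimal sketch is to capture the output once and write it straight to disk without any parsing. `save_print_summary` and its default file name are hypothetical helpers, not part of LightweightMMM; the sketch only assumes `mmm` has a `print_summary()` method that writes with Python's `print`, since `redirect_stdout` does not catch output written at the C level.

```python
import io
from contextlib import redirect_stdout


def save_print_summary(mmm, path="mmm_summary.txt"):
    """Write whatever mmm.print_summary() prints to a plain-text file.

    Hypothetical helper. Returns the captured text so it can also be parsed
    afterwards, e.g. into a pandas DataFrame as above.
    """
    buffer = io.StringIO()
    # Temporarily send print() output to the in-memory buffer
    with redirect_stdout(buffer):
        mmm.print_summary()
    text = buffer.getvalue()
    # Save the table exactly as it would have appeared in the console
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(text)
    return text
```

For example, `text = save_print_summary(mmm, "mmm_summary.txt")` after fitting keeps the raw table on disk, and `text` can be fed to the parsing code above instead of calling `print_summary()` a second time.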