google / lightweight_mmm

LightweightMMM 🦇 is a lightweight Bayesian Marketing Mix Modeling (MMM) library that allows users to easily train MMMs and obtain channel attribution information.
Apache License 2.0

Any way to save mmm.print_summary()? #274

Open cincysam6 opened 7 months ago

cincysam6 commented 7 months ago

Does anyone know if there is a simple way to save the mmm.print_summary() results to an image or a CSV/text file? I would love to capture that output in some kind of file once my model has finished training.

becksimpson commented 7 months ago

@cincysam6 A lot of the time I just eyeball it. Because it wraps a fairly encapsulated function from NumPyro, I think the simplest option, beyond taking a screenshot, is to capture stdout and then process it however you like. Here is a possible example that stores the printed summary in a pandas DataFrame, which you can then save. You might have to adapt it to your NumPyro version; I don't know whether its table formatting has stayed consistent across releases.

```python
from contextlib import redirect_stdout
import io

import pandas as pd

# Capture what print_summary() writes to stdout
f = io.StringIO()
with redirect_stdout(f):
    mmm.print_summary()  # mmm is your fitted LightweightMMM instance
s = f.getvalue()

# Custom processing
# (Due to noise in the output I couldn't get pandas.read_csv to parse the string buffer directly.)
# Rows are separated by newlines and columns by runs of whitespace.
lines = [ln for ln in s.split('\n') if ln.strip()]  # drop blank lines
cols = lines[0].split()  # header row: mean, std, median, 5.0%, 95.0%, n_eff, r_hat
print(lines[-1])  # last line: "Number of divergences: ..."
rows = []
for row in lines[1:]:
    # str.split() also ignores leading whitespace, so the longest parameter names
    # (e.g. gamma_seasonality[0,0]), which are printed without padding, are kept.
    split_row = row.split()
    if len(split_row) == len(cols) + 1:  # parameter name + one value per column
        rows.append(split_row)

df = pd.DataFrame(rows, columns=['Param', *cols])
df[cols] = df[cols].apply(pd.to_numeric)  # statistics as numbers, not strings
```
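To actually end up with files on disk, the same capture-and-split idea can be wrapped in two small helpers, which also lets you check the parsing on a sample string before running it on a real model. This is only a sketch: `capture_stdout`, `summary_to_frame`, `SAMPLE`, and the output file names are hypothetical, and the parser assumes NumPyro's layout (a header of statistic names, one whitespace-separated row per parameter with the name first, then a "Number of divergences" line).

```python
import io
from contextlib import redirect_stdout
from pathlib import Path

import pandas as pd


def capture_stdout(fn) -> str:
    """Run fn() and return everything it printed to stdout."""
    buf = io.StringIO()
    with redirect_stdout(buf):
        fn()
    return buf.getvalue()


def summary_to_frame(text: str) -> pd.DataFrame:
    """Parse print_summary()-style text into a DataFrame.

    Assumes the NumPyro layout: a header of statistic names, then one
    whitespace-separated row per parameter (name first), then a
    "Number of divergences: N" line.
    """
    lines = [ln.split() for ln in text.splitlines() if ln.strip()]
    cols = lines[0]
    # Keep only rows shaped like "name value value ..."; this skips the divergences line.
    rows = [ln for ln in lines[1:] if len(ln) == len(cols) + 1]
    df = pd.DataFrame(rows, columns=["Param", *cols])
    df[cols] = df[cols].apply(pd.to_numeric)
    return df


# Hypothetical text in the assumed layout (made-up numbers, not real model output).
SAMPLE = """
                          mean       std    median      5.0%     95.0%     n_eff     r_hat
         coef_media[0]      0.42      0.10      0.41      0.26      0.58   1210.33      1.00
gamma_seasonality[0,0]      0.12      0.05      0.12      0.04      0.20    987.61      1.00

Number of divergences: 0
"""

# With a fitted model you would call: text = capture_stdout(mmm.print_summary)
text = capture_stdout(lambda: print(SAMPLE, end=""))
df = summary_to_frame(text)

# Save both the raw printout and the parsed table (file names are arbitrary).
Path("print_summary.txt").write_text(text)
df.to_csv("print_summary.csv", index=False)
print(df)
```

Keeping the raw `.txt` alongside the CSV means you still have the full printout if the table layout in your NumPyro version doesn't match what the parser expects.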