Open wd60622 opened 1 month ago
Any thoughts on this and what is useful to save off and load back in? @ColtAllen @louismagowan
The following is already supported with the MMM:
import json

import mlflow

with mlflow.start_run():
    # model.idata.attrs holds the serialized model configuration
    configuration = model.idata.attrs
    with open("configuration.json", "w") as f:
        json.dump(configuration, f)
    mlflow.log_artifact("configuration.json")
Note that some of the values are already stored as strings because of the netCDF format.
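If the string-valued attributes need to be turned back into Python objects after loading, something like the following might work. This is only a sketch: `decode_attrs` is a hypothetical helper (not part of pymc-marketing or MLflow), the example dictionary is made up for illustration, and it assumes the affected values are JSON-encoded strings, which may not hold for every attribute.

```python
import json
from typing import Any, Mapping


def decode_attrs(attrs: Mapping[str, Any]) -> dict[str, Any]:
    """Best-effort decoding of string-valued attributes (hypothetical helper).

    Assumes string values are JSON-encoded; anything that fails to parse
    is returned unchanged.
    """
    decoded: dict[str, Any] = {}
    for key, value in attrs.items():
        if isinstance(value, str):
            try:
                decoded[key] = json.loads(value)
            except json.JSONDecodeError:
                # Not valid JSON, so keep the raw string
                decoded[key] = value
        else:
            decoded[key] = value
    return decoded


# Illustrative values only, not real model attributes
example_attrs = {
    "sampler_config": '{"draws": 1000, "tune": 500}',
    "note": "plain text",
}
print(decode_attrs(example_attrs))
```

One caveat with this approach: a plain string that happens to be valid JSON (for example `"123"`) would be converted too, so it may be safer to decode only the keys you know were serialized.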
Amazing! 🙌 I used PyMC (custom MMM) + MLflow, and it was great to track experiments!
Very cool!
Couple of things that spring to mind:
We should submit a feature request to MLflow to start supporting PyMC models natively.
It might be worth adding prefixes to your params and metrics
# Specify options for MCMC
SAMPLER_CONFIG = {
    "draws": 1_000,
    "tune": 500,
    "chains": 6,
    "target_accept": 0.9,
    "progressbar": True,
    "nuts_sampler": "numpyro",
    "random_seed": SEED,
}
# Prefix each key so sampler params are easy to tell apart from other logged params
SAMPLER_CONFIG_LOGGING = {"sampler_config_" + key: val for key, val in SAMPLER_CONFIG.items()}
**There are lots of nice metrics, params and graphs that I find useful to add**
- I have an idea for a PR on some evaluation and diagnostic metrics that I think could be logged (we do this atm)
- e.g. Something like this
# Initiate the MLflow run context
with mlflow.start_run(run_name=RUN_NAME) as run:
    # Log git hash
    git_commit_hash = get_git_revision_hash()  # func to get hash of current notebook state
    if git_commit_hash:
        # Log the git commit hash as a tag
        mlflow.set_tag("git_commit_id", git_commit_hash)

    # Log the pre-processing / modelling decisions taken
    mlflow.log_params(FEATURES)
    mlflow.log_params(DIM_REDUCTION_CONFIG)
    mlflow.log_params(SEASONALITY_CONFIG)
    mlflow.log_params(TRAIN_TEST_CONFIG)
    mlflow.log_params(SAMPLER_CONFIG_LOGGING)

    # Log model metrics
    mlflow.log_metrics(model_metrics)
    mlflow.log_metrics(model_diagnostics)

    # Log whatever artifacts you want
    mlflow.log_figure(prior_plot, "graphs/prior_plot.png")
    mlflow.log_figure(adstock_alphas_plot, "graphs/adstock_alphas_plot.png")
    mlflow.log_figure(sat_lams_plot, "graphs/sat_lams_plot.png")
    mlflow.log_figure(coeffs_intercept_plot, "graphs/coeffs_intercept_plot.png")
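The `get_git_revision_hash` helper isn't defined in the snippet above. One way it could look is sketched below, assuming `git` is on the PATH and the code runs from inside a repository. This is a guess at the helper, not the original implementation, and it returns the hash of the last commit only, so uncommitted notebook changes would not be reflected.

```python
import subprocess
from typing import Optional


def get_git_revision_hash() -> Optional[str]:
    """Return the hash of the current commit, or None if it cannot be read."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # Not inside a git repository, or git is not installed
        return None
    # Treat empty output as "no hash available"
    return result.stdout.strip() or None
```

Returning `None` instead of raising keeps the `if git_commit_hash:` check in the run context meaningful when the code is executed outside a repository.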
So yeah - lots of cool things we could do here! Very interested to hear about what you had in mind for pymc-marketing X mlflow 😁
I've got a couple PR ideas that might crossover with this stuff, so super happy to work on it with you too! (provided work doesn't get too busy haha)
The `model.idata.attrs` is a serialized format of the model which could be used or exposed better. For instance, MLflow