h2oai / h2o-3

H2O is an Open Source, Distributed, Fast & Scalable Machine Learning Platform: Deep Learning, Gradient Boosting (GBM) & XGBoost, Random Forest, Generalized Linear Modeling (GLM with Elastic Net), K-Means, PCA, Generalized Additive Models (GAM), RuleFit, Support Vector Machine (SVM), Stacked Ensembles, Automatic Machine Learning (AutoML), etc.
http://h2o.ai
Apache License 2.0
6.93k stars 2k forks source link

Enable saving explain objects #15585

Open tomasfryda opened 1 year ago

tomasfryda commented 1 year ago

Prototype:

from io import StringIO, BytesIO
def render(obj, file, rows=100, fig_format="svg"):
    """Write an HTML rendering of an explanation object to ``file``.

    Containers are walked recursively: for a dict only the values are
    rendered (keys are ignored), tuples and lists are rendered item by item.
    Any other object is dispatched on its type, in this order: explanation
    header, explanation description, two-dimensional table, confusion matrix
    (rendered through its ``table`` attribute), anything exposing
    ``_str_html_()``, and finally decorated plot results, whose figure is
    embedded inline. Objects matching none of these are reported on stdout
    and otherwise skipped, so a single unknown item does not abort the
    whole export.

    Args:
        obj: The explanation object, or a dict/tuple/list of such objects.
        file: Destination exposing ``write(str)`` (e.g. an open text file).
        rows: Row limit forwarded to ``_as_display`` when rendering tables.
        fig_format: How figures are embedded: ``"png"`` (base64 ``<img>``)
            or ``"svg"`` (markup written directly into the document).

    Raises:
        ValueError: If ``fig_format`` is neither ``"png"`` nor ``"svg"``.
    """
    if fig_format not in ("png", "svg"):
        # ValueError is a subclass of Exception, so callers that caught the
        # previous generic Exception keep working.
        raise ValueError("fig_format can be just 'png' or 'svg'.")

    if isinstance(obj, dict):
        for item in obj.values():
            render(item, file, rows, fig_format=fig_format)
    elif isinstance(obj, (tuple, list)):
        for item in obj:
            render(item, file, rows, fig_format=fig_format)
    elif isinstance(obj, h2o.explanation.Header):
        file.write(f"<h{obj.level}>{obj.content}</h{obj.level}>\n")
    elif isinstance(obj, h2o.explanation.Description):
        file.write(f"<blockquote>{obj.content}</blockquote>\n")
    elif isinstance(obj, h2o.two_dim_table.H2OTwoDimTable):
        file.write(obj._as_display(rows=rows, prefer_pandas=False)._str_html_())
    elif isinstance(obj, h2o.model.confusion_matrix.ConfusionMatrix):
        # A confusion matrix is rendered via the table it wraps.
        render(obj.table, file, rows, fig_format=fig_format)
    elif hasattr(obj, "_str_html_"):
        file.write(obj._str_html_())
    elif h2o.plot.is_decorated_plot_result(obj):
        # Serialize the figure into memory, then embed it in the document.
        with BytesIO() as buffer:
            obj.figure().savefig(buffer, format=fig_format)
            image_bytes = buffer.getvalue()
        if fig_format == "png":
            import base64

            encoded = base64.b64encode(image_bytes).decode()
            file.write(f'<img src="data:image/png;base64, {encoded}" />')
        else:
            # fig_format was validated above, so this is the "svg" case;
            # the saved bytes are text markup and can be inlined as-is.
            file.write(image_bytes.decode())
    else:
        print(f"Unsupported {obj.__class__}")

def save_explain(explanation, file, max_rows=100, max_cols=200, fig_format="svg"):
    """Save an explanation as a standalone HTML document.

    The document consists of a fixed header with inline CSS, the output of
    ``render(explanation, ...)`` as the body, and the closing tags. The
    rendering runs inside ``h2o.display.local_context(rows=max_rows,
    cols=max_cols, use_pandas=False)``.

    Args:
        explanation: Object (or dict/tuple/list of objects) handed to
            ``render``.
        file: Path of the HTML file to create; an existing file is
            overwritten.
        max_rows: Passed as ``rows`` to the display context and to ``render``.
        max_cols: Passed as ``cols`` to the display context.
        fig_format: Figure embedding format forwarded to ``render``
            (``"png"`` or ``"svg"``).
    """
    with h2o.display.local_context(rows=max_rows, cols=max_cols, use_pandas=False):
        # Write UTF-8 explicitly: the platform default encoding (e.g. cp1252
        # on Windows) can fail on, or garble, non-ASCII content. The
        # <meta charset> below tells the browser to decode it the same way.
        with open(file, "w", encoding="utf-8") as f:
            f.write("""
            <html>
              <head>
                <meta charset="utf-8">
                <title>Explanation</title>
                <style>
                caption {
                  white-space: nowrap;
                  caption-side: top;
                  text-align: left;
                  margin: 0;
                  font-size: larger;
                }
                th {
                  border-bottom: 1px solid rgba(0, 0, 0, 0.87);
                }
                th,
                td {
                  text-align: right;
                  padding: 0.5ex 1em;
                }
                tr:nth-child(even) {
                  background: rgb(245, 245, 245);
                }
                tr:hover {
                  background: rgb(225, 245, 254);
                }
                blockquote {
                  display: block;
                  margin: 1em 2em;
                  padding: 0 1em;
                  border-left: 5px solid rgb(224, 224, 224);
                  margin-block-start: 1em;
                  margin-block-end: 1em;
                  margin-inline-start: 40px;
                  margin-inline-end: 40px;
                }
                h1, h2, h3, h4, h5, h6 {
                  font-family: sans-serif;
                }
                </style>
              </head>
              <body>""")
            render(explanation, f, rows=max_rows, fig_format=fig_format)
            f.write("</body>\n</html>")

# How to use it
# Each call writes one standalone HTML file; every explanation is exported
# twice, once with figures embedded as PNG and once as SVG.
# NOTE(review): exp_regression and exp_binary are not defined in this
# snippet — presumably explanation objects created earlier; confirm.
save_explain(exp_regression, "out_png.html", fig_format="png")
save_explain(exp_regression, "out_svg.html", fig_format="svg")
save_explain(exp_binary, "out2_png.html", fig_format="png")
save_explain(exp_binary, "out2_svg.html", fig_format="svg")