pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

saving render_model() output to the desired file path #1831

Status: Open. znwang25 opened this issue 1 month ago.

znwang25 commented 1 month ago

When calling numpyro.render_model(model, filename=my_path), I expected the generated file to be saved at my_path, but it is not.

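Reading the source code, it seems to deliberately use filename.stem instead of the user-provided file path, so any directory in the path is dropped and the output is written relative to the current working directory. Is this the intended behavior? Being unable to save the file to the path the user requested makes the filename option largely useless.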

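from pathlib import Path

from numpyro.infer.inspect import (
    generate_graph_specification,
    get_model_relations,
    render_graph,
)


def render_model(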
    model,
    model_args=None,
    model_kwargs=None,
    filename=None,
    render_distributions=False,
    render_params=False,
):
    """
    Wrap all functions needed to automatically render a model.

    .. warning:: This utility does not support the
        :func:`~numpyro.contrib.control_flow.scan` primitive.
        If you want to render a time-series model, you can try
        to rewrite the code using Python for loop.

    :param model: Model to render.
    :param model_args: Positional arguments to pass to the model.
    :param model_kwargs: Keyword arguments to pass to the model.
    :param str filename: File to save rendered model in.
    :param bool render_distributions: Whether to include RV distribution annotations in the plot.
    :param bool render_params: Whether to show params in the plot.
    """
    relations = get_model_relations(
        model,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )
    graph_spec = generate_graph_specification(relations, render_params=render_params)
    graph = render_graph(graph_spec, render_distributions=render_distributions)

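    if filename is not None:
        filename = Path(filename)
        # filename.stem is only the final path component without its suffix,
        # so any parent directories in `filename` are discarded here.
        graph.render(
            filename.stem, view=False, cleanup=True, format=filename.suffix[1:]
        )  # remove leading period from suffix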

    return graph
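
To see what the current code actually hands to graph.render, the path handling can be reproduced with pathlib alone. The helper name current_render_args is hypothetical; it only mirrors the two expressions used in the excerpt above, and graphviz itself is not involved.

```python
from pathlib import Path


def current_render_args(filename):
    """Hypothetical helper mirroring the arguments the current code passes to graph.render."""
    path = Path(filename)
    # .stem is the last path component without its suffix: "figures/" is lost.
    base = path.stem
    # .suffix includes the leading period, so slice it off to get the format.
    fmt = path.suffix[1:]
    return base, fmt


base, fmt = current_render_args("figures/model.png")
print(base, fmt)  # model png
```

Only the bare name "model" reaches graph.render, which is why the file does not end up under figures/.
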

fehiepsi commented 1 month ago

Nice catch. I think we can pass the user-provided filename through as is and take the format from Path(filename).suffix. Do you want to submit the fix?

znwang25 commented 1 month ago

I think this will fix it:

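    if filename is not None:
        filename = Path(filename)
        # with_suffix('') drops only the extension and keeps the parent directories;
        # the format is passed separately via `format`.
        graph.render(
            filename.with_suffix(''), view=False, cleanup=True, format=filename.suffix[1:]
        )  # remove leading period from suffix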
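
A minimal sketch of the argument handling in the proposed fix, again using pathlib only. proposed_render_args is a hypothetical helper, and raising ValueError for a filename without an extension is an added assumption not discussed in the thread: with no suffix there is no format to infer, so the sketch fails early instead of passing an empty format string to graph.render.

```python
from pathlib import Path


def proposed_render_args(filename):
    """Hypothetical helper mirroring the arguments the proposed fix passes to graph.render."""
    path = Path(filename)
    if not path.suffix:
        # Assumption: reject names without an extension rather than
        # passing format="" on to graph.render.
        raise ValueError(f"filename {filename!r} needs an extension such as .png or .pdf")
    # with_suffix("") strips only the extension and keeps "figures/".
    return path.with_suffix(""), path.suffix[1:]


base, fmt = proposed_render_args("figures/model.png")
print(base, fmt)  # on POSIX: figures/model png
```

Compared with the current code, the directory now survives and only the extension is moved into the format argument.
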

fehiepsi commented 1 month ago

Is it necessary to remove suffix in the filename?

znwang25 commented 1 month ago

Without removing it, filename="Example.png" generates "Example.png.png", because graphviz appends the format extension to the name it is given. I haven't tested extensively, but removing the suffix works for my use cases.
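
The double extension can be illustrated with a simplified model of how graphviz names its output, assuming its default behavior of writing the rendered file as the given filename plus "." plus the format (renderer and formatter options ignored). rendered_output_name is a hypothetical stand-in for that logic, not a graphviz API.

```python
from pathlib import Path


def rendered_output_name(render_filename, fmt):
    """Hypothetical, simplified model of graphviz's default output naming:
    the rendered file is the given filename with ".<format>" appended."""
    return f"{render_filename}.{fmt}"


user_filename = Path("Example.png")
fmt = user_filename.suffix[1:]

# Passing the filename unchanged doubles the extension.
print(rendered_output_name(user_filename, fmt))  # Example.png.png
# Stripping the suffix first yields the name the user asked for.
print(rendered_output_name(user_filename.with_suffix(""), fmt))  # Example.png
```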

On Mon, Jul 8, 2024 at 1:28 PM Du Phan wrote:

> Is it necessary to remove suffix in the filename?

fehiepsi commented 1 month ago

Thanks! Sorry for the slow response. Do you want to submit the fix?