stared / livelossplot

Live training loss plot in Jupyter Notebook for Keras, PyTorch and others
https://p.migdal.pl/livelossplot
MIT License

Option to skip `plt.show()` so that plots can later be modified #73

Closed. salu133445 closed this issue 4 years ago

salu133445 commented 5 years ago

Hi, thanks for the really nice package. I would like to suggest adding an option to skip plt.show() in draw_plots() so that the plots can later be modified. That is, something like

By skipping plt.show(), users can then, for example, change the styles, modify the labels or add annotations, and call plt.show() afterward. For instance,

    liveloss.draw(show=False)
    plt.xlabel('step')
    plt.show()

Thanks!

stared commented 5 years ago

@salu133445 I have a long-planned rewrite of the plotting so that it would be easier to customize plots.

In the meantime, if you create a consistent way to pass show, I would be happy to accept your PR.

stared commented 4 years ago

With 0.5.2 it is possible to set custom hooks that run before and after plotting (thanks to @Bartolo1024). The defaults are:

    def _default_after_subplot(self, ax: plt.Axes, group_name: str, x_label: str):
        """Add title xlabel and legend to single chart
        Args:
            ax: matplotlib Axes
            group_name: name of metrics group (eg. Accuracy, Recall)
            x_label: label of x axis (eg. epoch, iteration, batch)
        """
        ax.set_title(group_name)
        ax.set_xlabel(x_label)
        ax.legend(loc='center right')

    def _default_before_plots(self, fig: plt.Figure, num_of_log_groups: int) -> None:
        """Set matplotlib window properties
        Args:
            fig: matplotlib Figure
            num_of_log_groups: number of log groups
        """
        clear_output(wait=True)
        figsize_x = self.max_cols * self.cell_size[0]
        figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
        fig.set_size_inches(figsize_x, figsize_y)

    def _default_after_plots(self, fig: plt.Figure):
        """Set properties after charts creation
        Args:
            fig: matplotlib Figure
        """
        fig.tight_layout()

These defaults can be overridden with PlotLosses(outputs=[MatplotlibPlot(before_plots=..., after_plots=...)]).

See https://github.com/stared/livelossplot/blob/master/examples/various_options.ipynb for inspiration. It includes examples of how to change labels and other chart properties with the after_subplot keyword argument.
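For the use case in the original request (changing the x-axis label), a minimal sketch of a custom after_subplot hook might look like the following. It assumes the hook is called with the same arguments as _default_after_subplot minus self, and that MatplotlibPlot is importable from livelossplot.outputs as in 0.5.x. The names custom_after_subplot and make_plotlosses are made up for illustration.

```python
def custom_after_subplot(ax, group_name, x_label):
    """Per-chart hook: same job as the default, but with a fixed x label.

    Assumption: livelossplot calls this as hook(ax, group_name, x_label),
    mirroring _default_after_subplot without `self`.
    """
    ax.set_title(group_name)
    ax.set_xlabel("step")  # ignore the incoming x_label and use 'step' instead
    ax.legend(loc="upper right")


def make_plotlosses():
    """Build a PlotLosses whose matplotlib output uses the custom hook."""
    # Imported lazily so the hook above can be reused or tested
    # without livelossplot installed.
    from livelossplot import PlotLosses
    from livelossplot.outputs import MatplotlibPlot

    return PlotLosses(outputs=[MatplotlibPlot(after_subplot=custom_after_subplot)])
```

The object returned by make_plotlosses() would then be used in place of a plain PlotLosses() in the training loop. Note that this changes the label while each chart is being built. It does not defer plt.show().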

Note: right now plt.show() is not (yet?) in after_plots. If you have a use case to separate it, we would be happy to do so.