iterative / dvclive

📈 Log and track ML metrics, parameters, models with Git and/or DVC
https://dvc.org/doc/dvclive
Apache License 2.0

log_sklearn_plot: exception when plot params are passed. #812

Closed: ocraft closed this issue 3 months ago

ocraft commented 3 months ago

When plot parameters (title, x_label, etc.) are passed, for example:

from dvclive import Live

with Live() as live:
  y_true = [0, 0, 1, 1]
  y_pred = [1, 0, 1, 0]
  y_score = [0.1, 0.4, 0.35, 0.8]
  live.log_sklearn_plot("roc", y_true, y_score)
  live.log_sklearn_plot(
    "confusion_matrix", y_true, y_pred, name="cm.json", title="Test")

the exception "TypeError: got an unexpected keyword argument 'title'" is raised.

It's because of this line in the log_sklearn_plot method:

sklearn_kwargs = {
    k: v for k, v in kwargs.items() if k not in plot_config or k != "normalized"
}

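The condition if k not in plot_config or k != "normalized" is true for every key other than "normalized", so plot options such as title are still forwarded to the sklearn function. Excluding them requires "and" instead of "or".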
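A minimal, self-contained sketch of the two filters side by side. The plot_config and kwargs dicts here are hypothetical stand-ins (in dvclive, plot_config is built earlier inside log_sklearn_plot); only the comprehension logic mirrors the line quoted above.

```python
# Hypothetical stand-ins: plot options already collected into plot_config
# (including "normalized"), plus the raw kwargs the caller passed.
plot_config = {"title": "Test", "x_label": "x", "y_label": "y", "normalized": False}
kwargs = {"title": "Test", "normalized": True, "sample_weight": None}


def buggy_filter(kwargs, plot_config):
    # Original condition: with `or`, the right-hand side is true for every
    # key except "normalized", so "title" slips through.
    return {k: v for k, v in kwargs.items() if k not in plot_config or k != "normalized"}


def fixed_filter(kwargs, plot_config):
    # Keep only keys that are neither plot options nor "normalized".
    return {k: v for k, v in kwargs.items() if k not in plot_config and k != "normalized"}


print(buggy_filter(kwargs, plot_config))  # {'title': 'Test', 'sample_weight': None}
print(fixed_filter(kwargs, plot_config))  # {'sample_weight': None}
```

Since "normalized" is already a key of plot_config here, the k != "normalized" clause is redundant in the fixed version; it is kept only to show the minimal change from "or" to "and".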

dberenbaum commented 3 months ago

Good catch! It looks like normalized doesn't even need to be included here, since it's already added to plot_config above. I also don't see why normalized is in kwargs at all instead of being a named argument. The other plot_config options (title, x_label, and y_label) seem to be undocumented, and I'm not sure they are really needed. I would suggest making normalized a named argument when fixing this.
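One way the "named argument" suggestion could look, sketched as a standalone helper. PLOT_OPTIONS and split_plot_kwargs are hypothetical names for illustration, not part of dvclive's API; drop_intermediate is just an example of a keyword that sklearn's roc_curve accepts.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical: the plot options mentioned in this thread.
PLOT_OPTIONS = ("title", "x_label", "y_label")


def split_plot_kwargs(
    normalized: Optional[bool] = None, **kwargs: Any
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Separate plot-rendering options from arguments meant for sklearn.

    Because `normalized` is a named parameter, it can never end up in
    **kwargs, so the sklearn-bound dict only has to exclude PLOT_OPTIONS.
    """
    plot_config = {k: kwargs[k] for k in PLOT_OPTIONS if k in kwargs}
    if normalized is not None:
        plot_config["normalized"] = normalized
    sklearn_kwargs = {k: v for k, v in kwargs.items() if k not in plot_config}
    return plot_config, sklearn_kwargs


plot_config, sklearn_kwargs = split_plot_kwargs(
    normalized=True, title="Test", drop_intermediate=False
)
print(plot_config)     # {'title': 'Test', 'normalized': True}
print(sklearn_kwargs)  # {'drop_intermediate': False}
```

With this shape there is a single exclusion rule (not in plot_config), which avoids the kind of compound condition that caused the bug.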

@ocraft Are you interested in submitting a PR?

ocraft commented 3 months ago

Sure, I've fixed that in the PR: https://github.com/iterative/dvclive/pull/813. Interestingly, there was already a test for this, but it only covered "confusion_matrix" with "title", and that plot accepts such an argument, so the test passed. The plot parameters themselves should stay, since they are useful; for example, the title is needed to distinguish two plots of the same type in the same report. The "normalized" parameter is used in other parts of the application, so I didn't change anything else.

dberenbaum commented 3 months ago

Interesting, thanks. The reason it passes for confusion_matrix is that this plot doesn't actually call any sklearn method, and the sklearn call is where the error gets raised.
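A small sketch of why a test on confusion_matrix alone could not catch the leak. Everything here is a hypothetical stand-in: fake_roc_curve only mimics the keyword-only signature style of sklearn.metrics.roc_curve, and the two plot functions imitate "calls no sklearn function" versus "forwards kwargs to sklearn".

```python
# Stand-in with a keyword-only signature similar to sklearn.metrics.roc_curve;
# it is NOT sklearn itself, just enough to reproduce the failure mode.
def fake_roc_curve(y_true, y_score, *, pos_label=None, sample_weight=None, drop_intermediate=True):
    return "roc computed"


def confusion_matrix_plot(y_true, y_pred, **sklearn_kwargs):
    # Hypothetical: calls no metric function, so stray kwargs are silently ignored.
    return "confusion matrix written"


def roc_plot(y_true, y_score, **sklearn_kwargs):
    # Hypothetical: forwards kwargs to the metric function, where unknown keys raise.
    return fake_roc_curve(y_true, y_score, **sklearn_kwargs)


leaked = {"title": "Test"}  # what the buggy filter lets through

print(confusion_matrix_plot([0, 1], [1, 0], **leaked))  # no error

error_message = None
try:
    roc_plot([0, 1], [0.2, 0.8], **leaked)
except TypeError as exc:
    error_message = str(exc)
print("TypeError:", error_message)
```

A regression test along these lines would need to exercise at least one plot kind that actually calls into sklearn (such as "roc") with a plot option like title.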