mlr-org / mlr3torch

Deep learning framework for the mlr3 ecosystem based on torch
https://mlr3torch.mlr-org.com

implement plotter for history state in mlr3viz #227

Open sebffischer opened 2 months ago

sebffischer commented 2 months ago

old code:

    #' @description Plots the history.
    #' @param measures (`character()`)\cr
    #'   Which measures to plot. No default.
    #' @param set (`character(1)`)\cr
    #'   Which set to plot. Either `"train"` or `"valid"`. Default is `"valid"`.
    #' @param epochs (`integer()`)\cr
    #'   An integer vector restricting which epochs to plot. Default is `NULL`, which plots all epochs.
    #' @param theme ([ggplot2::theme()])\cr
    #'   The theme, [ggplot2::theme_minimal()] is the default.
    #' @param ... (any)\cr
    #'   Currently unused.
    plot = function(measures, set = "valid", epochs = NULL, theme = ggplot2::theme_minimal(), ...) {
      assert_choice(set, c("valid", "train"))
      data = self[[set]]
      assert_subset(measures, colnames(data))

      if (is.null(epochs)) {
        data = data[, c("epoch", measures), with = FALSE]
      } else {
        assert_integerish(epochs, unique = TRUE)
        data = data[get("epoch") %in% epochs, c("epoch", measures), with = FALSE]
      }

      if ((!nrow(data)) || (ncol(data) < 2)) {
        stopf("No eligible measures to plot for set '%s'.", set)
      }

      epoch = score = measure = .data = NULL
      if (ncol(data) == 2L) {
        ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = .data[[measures]])) +
          ggplot2::geom_line() +
          ggplot2::geom_point() +
          ggplot2::labs(
            x = "Epoch",
            y = measures,
            title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
          ) +
          theme
      } else {
        data = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score")
        ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = score, color = measure)) +
          viridis::scale_color_viridis(discrete = TRUE) +
          ggplot2::geom_line() +
          ggplot2::geom_point() +
          ggplot2::labs(
            x = "Epoch",
            y = "Score",
            title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
          ) +
          theme
      }
    }

sebffischer commented 2 months ago

This should dispatch on LearnerTorch.
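
A rough sketch of what dispatching on LearnerTorch could look like, written as an S3 method for the ggplot2::autoplot() generic. This is not the actual mlr3viz implementation. The helper get_history() and the `history` field on the mock object are hypothetical and only stand in for wherever the history state really lives on a trained learner. The sketch always reshapes to long format instead of keeping the separate single-measure branch from the old code.

```r
library(ggplot2)
library(data.table)

# Hypothetical accessor: the location of the history on a trained learner
# is an assumption; replace this with the real lookup.
get_history = function(learner, set) {
  learner$history[[set]]
}

# S3 method for ggplot2::autoplot(), dispatching on class "LearnerTorch".
autoplot.LearnerTorch = function(object, measures, set = "valid", epochs = NULL,
  theme = theme_minimal(), ...) {
  set = match.arg(set, c("valid", "train"))
  data = as.data.table(get_history(object, set))
  stopifnot(all(measures %in% setdiff(colnames(data), "epoch")))

  # Optionally restrict to the requested epochs.
  if (!is.null(epochs)) {
    data = data[data$epoch %in% epochs, ]
  }
  data = data[, c("epoch", measures), with = FALSE]
  if (nrow(data) == 0L) {
    stop(sprintf("No eligible measures to plot for set '%s'.", set))
  }

  # Long format: one row per (epoch, measure) pair.
  long = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score")
  ggplot(long, aes(x = epoch, y = score, color = measure)) +
    geom_line() +
    geom_point() +
    labs(x = "Epoch", y = "Score",
      title = switch(set, valid = "Validation", train = "Training")) +
    theme
}

# Mock object standing in for a trained learner (illustration only).
mock = structure(
  list(history = list(
    valid = data.table(epoch = 1:3, acc = c(0.6, 0.7, 0.8), ce = c(0.9, 0.7, 0.5)),
    train = data.table(epoch = 1:3, acc = c(0.7, 0.8, 0.9), ce = c(0.8, 0.5, 0.3))
  )),
  class = "LearnerTorch"
)

p = autoplot(mock, measures = c("acc", "ce"), set = "valid")
```

With this shape, users would call autoplot(learner, measures = ...) directly on the learner rather than a plot() method on the history object.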