tidymodels / brulee

High-Level Modeling Functions with 'torch'
https://brulee.tidymodels.org/

Simplify `matrix_to_dataset()` using `torch::tensor_dataset()` #40

Closed: hsbadr closed this issue 2 years ago

hsbadr commented 3 years ago

I suggest simplifying matrix_to_dataset() by using torch::tensor_dataset(), which also checks the dimensions of its inputs.

This would replace the following lines of R/convert_data.R (https://github.com/tidymodels/lantern/blob/db48129496d9fe1ffcf85e1a595ca9d5c8e85461/R/convert_data.R#L15-L42) with something like:

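matrix_to_dataset <- function(x, y) {
  # Predictors become a floating-point tensor
  x <- torch::torch_tensor(x)
  if (is.factor(y)) {
    # Classification: 1-based integer codes of the factor levels,
    # stored as int64 (long) class targets
    y <- as.integer(y)
    y <- torch::torch_tensor(y, dtype = torch::torch_long())
  } else {
    # Regression: numeric outcome as a floating-point tensor
    y <- torch::torch_tensor(y)
  }
  # tensor_dataset() pairs the tensors row-wise and checks their dimensions
  torch::tensor_dataset(x = x, y = y)
}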
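A quick way to exercise the proposal is to build a small matrix and factor outcome and index the resulting dataset. This is a sketch that assumes the torch R package is installed. The toy data are made up for illustration, and the function body is repeated so the snippet runs on its own. It calls the dataset's .length() and .getitem() methods directly.

```r
library(torch)

# Proposed helper, repeated here so the snippet is self-contained
matrix_to_dataset <- function(x, y) {
  x <- torch_tensor(x)
  if (is.factor(y)) {
    # Factor levels become 1-based integer class codes stored as int64
    y <- torch_tensor(as.integer(y), dtype = torch_long())
  } else {
    y <- torch_tensor(y)
  }
  tensor_dataset(x = x, y = y)
}

# Made-up toy data: 4 rows, 2 predictors, 2-level factor outcome
x <- matrix(c(1, 2, 3, 4, 5, 6, 7, 8), ncol = 2)
y <- factor(c("a", "b", "a", "b"))

ds <- matrix_to_dataset(x, y)
n_rows <- ds$.length()   # number of rows in the dataset
item <- ds$.getitem(1)   # named list with elements x and y for row 1
as_array(item$x)         # first row of the predictor matrix
item$y$item()            # class code of the first outcome value
```

Because the tensors are passed by name, each item keeps the names x and y, which makes the batches easy to unpack in a training loop.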

If you agree, I can create a PR.

github-actions[bot] commented 2 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.