Closed hsbadr closed 2 years ago
I suggest simplifying matrix_to_dataset() by using torch::tensor_dataset(), which also checks the dimensions.
This will replace the following lines https://github.com/tidymodels/lantern/blob/db48129496d9fe1ffcf85e1a595ca9d5c8e85461/R/convert_data.R#L15-L42 with something like:
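matrix_to_dataset <- function(x, y) {
  x <- torch::torch_tensor(x)
  if (is.factor(y)) {
    # Factor outcomes become integer class codes stored as a long tensor
    y <- as.numeric(y)
    y <- torch::torch_tensor(y, dtype = torch::torch_long())
  } else {
    y <- torch::torch_tensor(y)
  }
  # tensor_dataset() checks that x and y have the same size in the first dimension
  torch::tensor_dataset(x = x, y = y)
}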
If you agree, I can create a PR.
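For illustration, here is a minimal sketch of how the proposed function might behave, assuming the torch package is installed. The toy data (x, y) and the tryCatch() wrapper are made up for this example and are not part of the package.

```r
# Proposed helper, repeated here so the example is self-contained
matrix_to_dataset <- function(x, y) {
  x <- torch::torch_tensor(x)
  if (is.factor(y)) {
    y <- torch::torch_tensor(as.numeric(y), dtype = torch::torch_long())
  } else {
    y <- torch::torch_tensor(y)
  }
  torch::tensor_dataset(x = x, y = y)
}

# Hypothetical toy data: 10 rows, 2 predictors, a two-level factor outcome
x <- matrix(rnorm(20), nrow = 10)
y <- factor(rep(c("a", "b"), times = 5))

ds <- matrix_to_dataset(x, y)
n_rows <- length(ds)        # number of observations in the dataset
first <- ds$.getitem(1)     # list with elements named x and y

# A row-count mismatch should be rejected by tensor_dataset() itself
mismatch_caught <- tryCatch(
  {
    matrix_to_dataset(x, y[1:5])
    FALSE
  },
  error = function(e) TRUE
)
```

With this approach the dimension check comes from tensor_dataset(), so matrix_to_dataset() would not need its own validation code.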
This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.