tidymodels / brulee

High-Level Modeling Functions with 'torch'
https://brulee.tidymodels.org/

fix optimizer bug in #61 (#70)

Closed: topepo closed this 10 months ago

topepo commented 10 months ago

One issue I'm having with SGD is missing metrics:

library(tidymodels)
library(brulee)

tidymodels_prefer()

data(ames, package = "modeldata")

# Model the sale price on the log10 scale
ames$Sale_Price <- log10(ames$Sale_Price)

# Random 2,000-row training set; the remaining rows form the test set
set.seed(122)
in_train <- sample(1:nrow(ames), 2000)
ames_train <- ames[ in_train,]
ames_test  <- ames[-in_train,]

# Penalized linear regression on two predictors, fit with the SGD optimizer
set.seed(1)
brulee_linear_reg(x = as.matrix(ames_train[, c("Longitude", "Latitude")]),
                  y = ames_train$Sale_Price,
                  penalty = 0.10, epochs = 10, batch_size = 64,
                  optimizer = "SGD", verbose = TRUE)
#> Warning: Current loss in NaN. Training wil be stopped.
#> Linear regression
#> 
#> 2,000 samples, 2 features, numeric outcome 
#> weight decay: 0.1 
#> batch size: 64 
#> scaled validation loss after 1 epoch: NaN

Created on 2023-11-01 with reprex v2.0.2

topepo commented 10 months ago

Made a separate issue for the SGD note above.

github-actions[bot] commented 10 months ago

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.