tidymodels / brulee

High-Level Modeling Functions with 'torch'
https://brulee.tidymodels.org/
Other

Loss computations fail with SGD #73

Open topepo opened 1 year ago

topepo commented 1 year ago

This is partly due to randomness: with different random numbers (e.g. a different seed), the fit may not fail.

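library(tidymodels)
library(brulee)

# Resolve function-name conflicts in favor of tidymodels packages
tidymodels_prefer()

data(ames, package = "modeldata")

# Model the sale price on the log10 scale
ames$Sale_Price <- log10(ames$Sale_Price)

# Random split: 2,000 rows for training, the remaining rows for testing
set.seed(122)
in_train <- sample(1:nrow(ames), 2000)
ames_train <- ames[ in_train,]
ames_test  <- ames[-in_train,]

# Penalized linear regression on the two coordinates, fit with minibatch SGD;
# whether the loss becomes NaN depends on the seed set here
set.seed(1)
brulee_linear_reg(x = as.matrix(ames_train[, c("Longitude", "Latitude")]),
                  y = ames_train$Sale_Price,
                  penalty = 0.10, epochs = 10, batch_size = 64, 
                  optimizer = "SGD", verbose = TRUE)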
#> Warning: Current loss in NaN. Training wil be stopped.
#> Linear regression
#> 
#> 2,000 samples, 2 features, numeric outcome 
#> weight decay: 0.1 
#> batch size: 64 
#> scaled validation loss after 1 epoch: NaN

Created on 2023-11-02 with reprex v2.0.2
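One possible mechanism, offered only as an assumption and not as a diagnosis of brulee's internals: plain minibatch SGD diverges when the learning rate is too large for the scale of the predictors, and the loss then overflows to `Inf`/`NaN`. This base-R sketch uses a hypothetical helper, `sgd_linear_reg()` (not part of brulee), on synthetic coordinates of roughly the same magnitude as Ames longitude/latitude. It mirrors the reprex settings (`penalty = 0.1`, `epochs = 10`, `batch_size = 64`) and stops with a warning once the end-of-epoch loss is no longer finite. The learning rate of 0.01 is an arbitrary illustrative choice.

```r
# Hypothetical helper (not brulee's implementation): minibatch SGD for a
# linear regression with an L2 penalty on the slopes, stopping early when the
# end-of-epoch training loss is no longer finite.
sgd_linear_reg <- function(x, y, penalty = 0.1, epochs = 10, batch_size = 64,
                           learn_rate = 0.01, seed = 1) {
  set.seed(seed)
  x <- cbind(1, x)                      # prepend an intercept column
  beta <- rep(0, ncol(x))               # start all coefficients at zero
  losses <- numeric(0)
  for (epoch in seq_len(epochs)) {
    idx <- sample(nrow(x))              # reshuffle rows every epoch
    batches <- split(idx, ceiling(seq_along(idx) / batch_size))
    for (b in batches) {
      xb <- x[b, , drop = FALSE]
      resid <- drop(xb %*% beta) - y[b]
      # Gradient of mean squared error plus penalty * sum(slopes^2);
      # the intercept is not penalized
      grad <- 2 * drop(crossprod(xb, resid)) / length(b) +
        2 * penalty * c(0, beta[-1])
      beta <- beta - learn_rate * grad
    }
    loss <- mean((drop(x %*% beta) - y)^2)
    losses <- c(losses, loss)
    if (!is.finite(loss)) {             # catches both Inf and NaN
      warning("Current loss is not finite; training will be stopped.")
      break
    }
  }
  list(coefficients = beta, loss = losses)
}

# Synthetic data on roughly the Ames coordinate scale (not the real data)
set.seed(122)
n <- 2000
x_raw <- cbind(lon = runif(n, -93.70, -93.58), lat = runif(n, 41.98, 42.07))
y <- 5.2 + rnorm(n, sd = 0.1)

# Raw (uncentered) predictors: each step overshoots and the loss overflows
fit_raw <- sgd_linear_reg(x_raw, y)
fit_raw$loss

# Standardized predictors with the same learning rate stay finite
fit_scaled <- sgd_linear_reg(scale(x_raw), y)
fit_scaled$loss
```

The report does not establish whether brulee rescales predictors before SGD or whether step size is the cause here; the sketch only shows how an overshooting update reproduces the same "loss is NaN, training stopped" symptom.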