neural-structured-additive-learning / safareg

GNU General Public License v3.0

Example code with MovieLens #1

Open tdeenes opened 1 year ago

tdeenes commented 1 year ago

In the paper 'Factorized Structured Regression for Large-Scale Varying Coefficient Models' you analyzed the MovieLens benchmark data. It would be extremely helpful to add a vignette including that analysis (or a simplified version of it).

Context: I would like to use the GAM approach in a recommender task where we have time-varying user features and static item features. mgcv::bam or lme4::glmer cannot be used because the number of users is >150k and the number of items is ~600. safareg seems capable of modeling this problem, but when I tried it on a highly simplified example I got completely wrong results, so I suspect I am doing something wrong in the model definition (see here: https://github.com/neural-structured-additive-learning/deepregression/issues/13).
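For illustration, the (infeasible) varying-coefficient specification I have in mind would look roughly like this in mgcv (the data frame ratings and its columns are placeholders):

library(mgcv)
# random effects for users and items plus a time-varying user effect;
# s(timestamp, by = userId) builds one smooth per user level,
# which does not scale to >150k users
fit <- bam(rating ~ s(userId, bs = "re") + s(movieId, bs = "re") +
             s(timestamp, by = userId),
           data = ratings, discrete = TRUE)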

davidruegamer commented 1 year ago

Hi @tdeenes, thanks for your interest, and good point. We will include it in our vignette.

In the meantime, here is a reproducible example:

### Reprex of one FaStR model on the movies data set
### Takes around 10-12 GB in total and a few hours
### for fitting (mgcv::bam in contrast would require 
### TBs of RAM, let alone the run time)

library(mgcv)
library(data.table)
library(tensorflow)   # provides the `tf` object used for the optimizer below
library(deepregression)
library(safareg)

# loading data will take about 2-3 GB
traintest <- readRDS("train_test.RDS")
train_set <- traintest[[1]]
test_set <- traintest[[2]]
# shift timestamps so that the training set starts at zero
# (the same offset is applied to the test set)
mints <- min(train_set$timestamp)
train_set$timestamp <- train_set$timestamp - mints
test_set$timestamp <- test_set$timestamp - mints

train_set$userId <- as.factor(train_set$userId)
train_set$movieId <- as.factor(train_set$movieId)

# align test-set factor levels with the training levels
# (users/movies unseen in training become NA)
test_set$userId <- factor(test_set$userId, levels = levels(train_set$userId))
test_set$movieId <- factor(test_set$movieId, levels = levels(train_set$movieId))

RMSE <- function(true_ratings, predicted_ratings){
  sqrt(mean((true_ratings - predicted_ratings)^2))
}

# model settings (this is one of the largest models in terms
# of latent factors)

bs = 5000L # batch size
dim_fm = 17L # dimension of factorization for interaction
dim_vfm = 10L # dimension of factorization for time-varying interaction
k = 15 # number of spline knots
# no extra regularization (implicit via early stopping);
# "NULL" is a character string because it is pasted into the formula below
la_user = "NULL"
la_movie = "NULL"

# Define the spline
sterm = paste0("s(timestamp, k = ", k, ")")

form = paste0("~ fac(userId, la = ", la_user, ") + fac(movieId, la = ", la_movie, 
              ") + facz(userId, by = movieId, dim = ", 
              dim_fm, ") + ", sterm, " + vc(", sterm, ", by = userId)",
              " + vc(", sterm, ", by = movieId)",
              " + vfacz(", sterm, ", by = c(userId, movieId), dim = ", dim_vfm, ")")

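# with the settings above, the assembled formula expands to
# ~ fac(userId, la = NULL) + fac(movieId, la = NULL) +
#   facz(userId, by = movieId, dim = 17) + s(timestamp, k = 15) +
#   vc(s(timestamp, k = 15), by = userId) +
#   vc(s(timestamp, k = 15), by = movieId) +
#   vfacz(s(timestamp, k = 15), by = c(userId, movieId), dim = 10)
cat(form) # print the formula string to verify
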
# build the model first (takes around 5-10 min and 6 GB of RAM on my machine)
mod <- deepregression(y = train_set$rating,
                      list_of_formulas = list(as.formula(form),
                                              ~ 1),
                      data = train_set,
                      optimizer = tf$keras$optimizers$Adam(learning_rate = 1e-5),
                      additional_processors = list(fac = fac_processor,
                                                   facz = fz_processor,
                                                   vc = am_processor,
                                                   vfacz = vf_processor),
                      penalty_options = penalty_control(df = 15)
)

# fit only if no saved weights exist
if(!file.exists("movies_bm.hdf5")){

  # memory is released after model building and then required again
  # (roughly the same amount) while fitting;
  # this will also take some time (see console output)
  # -> usually converges after a few iterations
  hist <- mod %>% fit(epochs = 20, batch_size = bs,
                      early_stopping = TRUE, patience = 1L,
                      verbose = TRUE)

  # save the weights in case prediction crashes the session
  save_model_weights_hdf5(mod$model, filepath = "movies_bm.hdf5")

}else{

  # if the model was already trained, load the saved weights
  mod$model$load_weights(filepath = "movies_bm.hdf5", by_name = FALSE)

}

# predictions
pred <- predict(mod, newdata = test_set, batch_size = bs)

# check RMSE
RMSE(test_set$rating, pred)

Note that this is quite a large model and, despite being very efficient compared to other approaches, the software requires around 10 GB of RAM and a few hours for fitting (this is due to the data size, not the number of users/items). I have uploaded the training-test split here on my GDrive. I would recommend subsampling the training data first to check that the code also works for you.
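For example (a sketch; the 10% fraction is arbitrary, and the factor levels defined above should be kept so the test set stays aligned):

set.seed(42)
idx <- sample(nrow(train_set), floor(0.1 * nrow(train_set)))
train_sub <- train_set[idx, ]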

Best, David

tdeenes commented 1 year ago

Thank you, @davidruegamer , that was very helpful. I was able to run the script without issues. However, the final result is very far from the one published in the paper:

> RMSE(test_set$rating, pred)
[1] 1.567889

Based on the substantial drop in the loss even from epoch 19 to epoch 20, I assume this is because the number of epochs should be set to a much larger value for a serious analysis:

Epoch 19/20
1459/1459 [==============================] - 2785s 2s/step - loss: 2.2518
1459/1459 [==============================] - 3007s 2s/step - loss: 2.2518 - val_loss: 5.3859
Epoch 20/20
1459/1459 [==============================] - 2569s 2s/step - loss: 2.0908
1459/1459 [==============================] - 2730s 2s/step - loss: 2.0908 - val_loss: 5.1933

Can you confirm if the poor RMSE is expected with these settings?
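For reference, a sketch of how training could be resumed from the saved weights (the epoch and patience values are guesses):

mod$model$load_weights(filepath = "movies_bm.hdf5", by_name = FALSE)
hist <- mod %>% fit(epochs = 100, batch_size = bs,
                    early_stopping = TRUE, patience = 3L,
                    verbose = TRUE)
save_model_weights_hdf5(mod$model, filepath = "movies_bm.hdf5")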

davidruegamer commented 1 year ago

Thanks for checking! Indeed, I think this particular model specification should yield a minimum RMSE of around 1.09 when trained for more epochs. Sorry for being a bit handwavy about the number of epochs. And yes, this value is still far from the one published in the paper. In the paper we reported $D=1$ as the optimal value for the movies data set, which means we used

... vfacz(", sterm, ", by = c(userId, movieId), simple = TRUE, dim = ", dim_vfm, ")" ... 

with dim_vfm = 1 (and simple = TRUE as an additional argument), but I need to check our server logs again to see whether dim_fm was also 1L.
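For completeness, with these settings the formula construction would become (whether dim_fm should also be 1L is unconfirmed):

dim_vfm = 1L
form = paste0("~ fac(userId, la = ", la_user, ") + fac(movieId, la = ", la_movie,
              ") + facz(userId, by = movieId, dim = ", dim_fm, ") + ", sterm,
              " + vc(", sterm, ", by = userId)",
              " + vc(", sterm, ", by = movieId)",
              " + vfacz(", sterm, ", by = c(userId, movieId), simple = TRUE, dim = ", dim_vfm, ")")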