mlverse / luz

Higher Level API for torch
https://mlverse.github.io/luz/

How to just continue training #113

Closed GitHubGeniusOverlord closed 1 year ago

GitHubGeniusOverlord commented 2 years ago

Hello!

I scanned all the documentation, but could not find anything on this. There is the callback luz_callback_resume_from_checkpoint, which allows continuing from one specific checkpoint. But what if I just want to continue training from the latest state? How can this be done?

Maybe we could add an example for this to the documentation. Thank you!

dfalbel commented 2 years ago

We currently don't have a nice way of doing this. It's something we would like to have, but there is no consensus yet on what should count as 'continuing training' (for example, whether the optimizer state should also be restored).

You can do something like this, though, if you don't care about the optimizer state. Otherwise, the recommended way is to use the luz_callback_resume_from_checkpoint callback and checkpoint the model state that you want to recover from (a sketch of that approach follows the two steps below).

  1. Train a model:
library(luz)
library(torch)

model <- nn_module(
  initialize = function() {
    self$linear <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$linear(x)
  }
)

fitted <- model %>% 
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>% 
  set_hparams() %>% 
  set_opt_hparams(lr = 0.01) %>% 
  fit(list(
    torch_randn(100, 10),  # x: 100 observations, 10 features
    torch_randn(100, 1)    # y: 100 matching targets
  ))
  2. Create a wrapper model and train it:
# Wrapper module that takes an already-trained module and reuses it,
# so fitting the wrapper effectively continues training those weights.
model2 <- nn_module(
  initialize = function(model) {
    self$model <- model
  },
  forward = function(x) {
    self$model(x)
  }
)

fitted <- model2 %>% 
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>% 
  set_hparams(model = fitted$model) %>% 
  set_opt_hparams(lr = 0.01) %>% 
  fit(list(
    torch_randn(100, 10),  # same shapes as before: x and y batch sizes match
    torch_randn(100, 1)
  ))
GitHubGeniusOverlord commented 1 year ago

I guess this is outdated, as newer versions of luz have added ways to continue training. Closing, therefore.
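
For readers landing here later, a minimal sketch of what the newer approach looks like, assuming luz_callback_auto_resume() (available in more recent luz releases) is the update referred to; see the luz documentation for the exact interface:

library(luz)
library(torch)

# Training state is saved to 'state.pt'; re-running the same fit()
# after an interruption resumes from that state instead of starting over.
autoresume <- luz_callback_auto_resume(path = "state.pt")

fitted <- model %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  fit(
    list(torch_randn(100, 10), torch_randn(100, 1)),
    callbacks = list(autoresume)
  )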