mlverse / luz

Higher Level API for torch
https://mlverse.github.io/luz/

R crashes when attempting to save a model or make checkpoints before moving tensors to CPU #131

Closed: ppdebem closed this issue 1 year ago

ppdebem commented 1 year ago

My current environment versions are as follows:

R crashes when executing the luz_save function at the end of the basic CNN example, and the same happens with my personal models. Attempting to save checkpoints through the luz_callback_model_checkpoint function also results in a crash. This happens regardless of whether I run R through the command line, Rgui, or RStudio, and it apparently leaves no warnings or logs. Because of that lack of warnings and crash logs, I am not entirely sure whether the crashes while saving GPU tensors are a problem on my side or torch-related (torch_save also crashes R if I don't move the tensors to the CPU first).
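
For reference, here is a minimal sketch of the torch_save behaviour I am describing, assuming a CUDA-capable machine (the tempfile path is just a placeholder):

library(torch)

x <- torch_randn(10, 10, device = "cuda")

# torch_save(x, tempfile())                     # this is the call that crashes my session
torch_save(x$to(device = "cpu"), tempfile())    # moving the tensor to the CPU first works fine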

While investigating, I tried modifying the functions to move the state_dict tensors to the CPU before saving, and that seemed to solve the issue. Simply adding $to(device = 'cpu') to the model passed to model_to_raw fixed saving the fitted object through the luz_save function:

luz_save <- function(obj, path, ...) {
  ellipsis::check_dots_empty()
  # dangling environments might be in the `obj` search path causing problems
  # during saving. `gc()` is a good practice to make sure they are cleaned up
  # before saving.
  gc()

  if (!inherits(obj, "luz_module_fitted"))
    rlang::abort("luz_save only works with 'luz_module_fitted_objects' and got {class(obj)[1]}")

  # avoid warning because luz will always be available when reloading
  # because we reload with `luz_load()`.
  suppressWarnings({
    serialized_model <- model_to_raw(obj$model$to(device = 'cpu')) # Move tensors to cpu before saving
    obj$ctx$.serialized_model <- serialized_model
    obj$ctx$.serialization_version <- 2L
    o <- saveRDS(obj, path)
  })

  invisible(o)
}
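
For context, this is roughly how I am calling it; net and train_dl are placeholders for my actual module and dataloader, and the loss/optimizer are just examples:

library(torch)
library(luz)

# `net` and `train_dl` stand in for my real nn_module and dataloader
fitted <- net |>
  setup(loss = nn_cross_entropy_loss(), optimizer = optim_adam) |>
  fit(train_dl, epochs = 1)

luz_save(fitted, "fitted_model.rds")  # no longer crashes with the patched function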

Solving it for the checkpoint callback requires a bit more work, since the state dict tensors have to be moved inside the luz_checkpoint function. Something like the following seems to work at first glance:

# Function to move tensors to CPU recursively through the lists
move_to_cpu <- function(state_dict){
  rapply(state_dict, function(x){
    x$to(device = 'cpu')
  }, classes = 'torch_tensor', how = 'replace')
}

luz_checkpoint <- function(ctx, path) {
  state <- list()

  # grab the current epoch and the training records
  state[["epoch"]] <- ctx$epoch
  state[["records"]] <- ctx$records

  # grab model state
  state[["model"]] <- ctx$model$state_dict()

  # grab optimizer state
  state[["optimizers"]] <- lapply(ctx$optimizers, function(x) x$state_dict()) 

  # traverse callbacks looking for the `state_dict()` method.
  state[["callbacks"]] <- lapply(ctx$callbacks, function(x) {
    if (is.null(x$state_dict))
      NULL
    else
      x$state_dict()
  })

  state <- move_to_cpu(state) # Apply the function to move every torch_tensor to cpu device
  torch_save(state, path)
}
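
With the patched luz_checkpoint, a run like the following (same placeholder names as above) checkpoints without crashing for me:

fitted <- net |>
  setup(loss = nn_cross_entropy_loss(), optimizer = optim_adam) |>
  fit(
    train_dl, epochs = 5,
    callbacks = list(luz_callback_model_checkpoint(path = "checkpoints/"))
  )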

Please let me know if this is only a problem on my side and/or how I can properly check that. My solutions might also not be the best, as I am not a very experienced programmer, so feel free to improve them if necessary.

dfalbel commented 1 year ago

Hi @ppdebem,

Thanks very much for the detailed report. I could reproduce and fix the problem. In the current dev version of torch we have made significant changes to serialization: we now use safetensors as the serialization format, powered by the safetensors package. The problem was that safetensors was not copying CUDA tensors to the CPU before saving, which caused the segfault.

Unfortunately, you also need to install the dev version of another R package:

remotes::install_github("mlverse/safetensors")
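
With the dev versions installed, a quick check along these lines should no longer segfault (a sketch, assuming a CUDA device is available):

library(torch)

x <- torch_randn(10, 10, device = "cuda")
torch_save(x, tempfile())  # should now complete without crashing the session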

Please re-open if you still have issues.