Closed: ppdebem closed this issue 1 year ago
Hi @ppdebem,
Thanks very much for the detailed report. I could reproduce and fix the problem. In the current dev version of torch we have made significant changes to serialization: we now use safetensors as the serialization format, powered by the safetensors package. The problem was that safetensors was not copying CUDA tensors to the CPU before saving, which caused the segfault.
Unfortunately, you also need to install the dev version of another R package:
```r
remotes::install_github("mlverse/safetensors")
```
Please re-open if you still have issues.
My current environment versions are as follows:
R crashes when executing the `luz_save` function at the end of the basic CNN example, and the same happens when using my personal models. Attempting to save checkpoints through the `luz_callback_model_checkpoint` function also results in a crash. This happens regardless of whether I'm using R through the command line, Rgui, or RStudio, and apparently leaves no warnings or logs. Given the lack of warnings and crash logs, I am not entirely sure whether the crashes while saving GPU tensors are a problem on my side or torch related (`torch_save` also crashes R if I don't move the tensors to the CPU first).

While investigating, I attempted to modify the functions to move the state_dict tensors to the CPU before saving, and that seemed to solve the issue. Simply adding
`$to(device = 'cpu')`
while using the `model_to_raw` function solved the issue when trying to save the fitted object through the `luz_save` function. Solving it for the checkpoint callback requires a bit more messing around, as you also have to move the state dict tensors to the CPU inside the `luz_checkpoint` function; something along those lines seems to work at first glance.

Please let me know whether this is only a problem on my side and/or how to properly check that. My solutions might also not be the best, as I'm not a good programmer, so feel free to improve them if necessary.
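For reference, the general shape of the workaround is just to copy every tensor in the state dict to the CPU before handing it to the serializer. This is only a minimal sketch, not luz's actual internals: the `cpu_state_dict` helper below is hypothetical, and where exactly the call is inserted (`model_to_raw`, `luz_checkpoint`, or before `torch_save`) depends on which save path is crashing.

```r
library(torch)

# Hypothetical helper: return a copy of a state dict with every tensor
# moved to the CPU, so that serialization never touches CUDA memory
# (which is what was triggering the segfault).
cpu_state_dict <- function(state_dict) {
  lapply(state_dict, function(x) x$to(device = "cpu"))
}

# Example usage before saving model weights directly with torch_save();
# `model` stands in for any fitted torch/luz nn_module on the GPU.
# torch_save(cpu_state_dict(model$state_dict()), "weights.pt")
```

The same idea applies inside the checkpoint callback: grab the state dict, run it through a CPU-copying step like the one above, and only then serialize.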