mlverse / torch

R Interface to Torch
https://torch.mlverse.org

Cannot save any torch tensor or `nn_module` #1112

Open rdinnager opened 9 months ago

rdinnager commented 9 months ago

Ever since I updated torch to the latest version, I have not been able to save any torch object using the default serialization method. I have instead had to fall back on options(torch.serialization_version = 2) to be able to save any tensors or model objects. Here is a reprex:

library(torch)

x <- torch_rand(100, 1000)
torch_save(x, "test.pth")
#> Error in `FUN()`:
#> ! `metadata` must be a named list of scalar characters.
#> Backtrace:
#>      ▆
#>   1. ├─torch::torch_save(x, "test.pth")
#>   2. └─torch:::torch_save.torch_tensor(x, "test.pth")
#>   3.   └─torch:::torch_save_to_file(...)
#>   4.     └─safetensors::safe_save_file(state_dict, path = con, metadata = metadata)
#>   5.       └─safetensors:::write_safe(tensors, metadata, con)
#>   6.         └─safetensors:::make_meta(tensors, metadata)
#>   7.           └─safetensors:::validate_metadata(metadata)
#>   8.             └─base::lapply(...)
#>   9.               └─safetensors (local) FUN(X[[i]], ...)
#>  10.                 └─cli::cli_abort("{.arg metadata} must be a named list of scalar characters.")
#>  11.                   └─rlang::abort(...)

## trying to use safetensors directly gives a different error:
safetensors::safe_save_file(x, "test.pth")
#> Error in for (tensor in tensors) {: invalid for() loop sequence

## This works (but seems very slow for some reason)
options(torch.serialization_version = 2)
torch_save(x, "test.pth")

Created on 2023-10-06 with reprex v2.0.2
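The direct `safetensors::safe_save_file(x, "test.pth")` call above seems to fail for a separate reason. As far as I can tell, `safe_save_file()` expects a named list of tensors rather than a single tensor, so its internal `for (tensor in tensors)` loop has nothing it can iterate over. Here is a minimal sketch, assuming safetensors 0.1.x and a working torch install. The list name `x` and the temporary file path are arbitrary choices.

```r
library(torch)
library(safetensors)

x <- torch_rand(100, 1000)

# safe_save_file() is given a *named list* of tensors, not a bare tensor;
# passing `x` on its own is what appears to trigger "invalid for() loop sequence".
path <- tempfile(fileext = ".safetensors")
safe_save_file(list(x = x), path)

# Reading the file back returns a named list; framework = "torch" asks for torch tensors.
loaded <- safe_load_file(path, framework = "torch")
```

This does not explain the `torch_save()` error. That error comes from the `metadata` torch passes to `safe_save_file()`, which is discussed below.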

Any ideas what is going wrong here?

dfalbel commented 9 months ago

Hi @rdinnager,

Thanks for reporting.

That's weird. I'd assume this is a mismatch between the torch version and the safetensors version, since I think I saw a similar issue at some point.

Can you try updating your safetensors package? I just tried the latest commit of torch with safetensors (both the CRAN release and the latest commit), and they seem to work correctly.
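Before updating, it helps to confirm exactly which versions are installed. The helper below, `installed_versions()`, is hypothetical and is not part of torch or safetensors. It reads versions with `utils::packageVersion()` without loading either package. The update commands are left as comments, and the GitHub location `mlverse/safetensors` is an assumption.

```r
# Hypothetical helper: report installed versions of the packages involved,
# returning NA for any package that is not installed.
installed_versions <- function(pkgs = c("torch", "safetensors")) {
  vapply(pkgs, function(p) {
    # system.file() returns "" when a package is not installed
    if (nzchar(system.file(package = p))) {
      as.character(utils::packageVersion(p))
    } else {
      NA_character_
    }
  }, character(1))
}

installed_versions()

# Update safetensors from CRAN:
# install.packages("safetensors")
# Or the development version (assumes the GitHub repo is mlverse/safetensors):
# remotes::install_github("mlverse/safetensors")
```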

rdinnager commented 9 months ago

I am using the latest version of both torch and safetensors from CRAN:

packageVersion("safetensors")
#> [1] '0.1.2'
packageVersion("torch")
#> [1] '0.11.0'

After looking through the recent commit history, I now think I need the development version of torch for this to work properly. I will try that!

dfalbel commented 8 months ago

Ohhh, I think that might be the case. You are right: you might need to downgrade safetensors or use the dev version of torch. I'm going to make a new torch release soon.
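The combination reported in this thread, CRAN torch 0.11.0 with safetensors 0.1.2, can be flagged programmatically before choosing between the two remedies. This is a sketch under stated assumptions. `is_affected_combo()` is a hypothetical helper. It assumes that torch releases after 0.11.0 include the fix and that safetensors releases before 0.1.2 accept torch's metadata, and neither assumption is verified here.

```r
# Hypothetical check for the version pair reported in this issue.
# ASSUMPTION: torch > 0.11.0 contains the fix, and safetensors < 0.1.2
# does not reject torch's metadata; neither is verified.
is_affected_combo <- function(torch_version, safetensors_version) {
  package_version(torch_version) <= "0.11.0" &&
    package_version(safetensors_version) >= "0.1.2"
}

# Remedies mentioned in the thread (left as comments; they need network access):
# remotes::install_github("mlverse/torch")  # dev torch, builds from source
# remotes::install_version("safetensors", version = "<a release older than 0.1.2>")
```

Usage: `is_affected_combo(as.character(packageVersion("torch")), as.character(packageVersion("safetensors")))`.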

rdinnager commented 6 months ago

Yes, I decided to wait for the new release and just use options(torch.serialization_version = 2) in the meantime. I find the new precompiled CUDA binary method of installation so convenient that I don't want to bother installing from source at the moment, which would require me to install a compatible CUDA version locally (and I'm not up for that right now ;) ).
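Instead of setting the option for the whole session, the legacy format can be applied only when the default serializer fails. Here is a minimal sketch. `torch_save_with_fallback()` is a hypothetical helper. The only torch pieces it relies on are `torch_save()`, `torch_load()`, and the `torch.serialization_version` option mentioned in this thread.

```r
library(torch)

# Hypothetical helper: try the default serializer first; if it errors,
# retry once with the legacy format (torch.serialization_version = 2).
torch_save_with_fallback <- function(obj, path) {
  tryCatch(
    torch_save(obj, path),
    error = function(e) {
      message("Default torch_save() failed (", conditionMessage(e),
              "); retrying with torch.serialization_version = 2")
      old <- options(torch.serialization_version = 2)
      # Restore the previous option value when this handler exits
      on.exit(options(old), add = TRUE)
      torch_save(obj, path)
    }
  )
  invisible(path)
}

x <- torch_rand(100, 1000)
path <- tempfile(fileext = ".pt")
torch_save_with_fallback(x, path)
y <- torch_load(path)
```

Restoring the option with `on.exit()` keeps the slower legacy format from silently applying to every later save once a fixed torch release is installed.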