mlverse / torch

R Interface to Torch
https://torch.mlverse.org

torch_save() and torch_load() do not work for large tensors #1210

Open · D-Maar opened this issue 1 week ago

D-Maar commented 1 week ago

If I save a large tensor (5000x10x300x300) with torch_save() and then try to load it with torch_load(), I get the following error:

Error in cpp_tensor_load(obj$values, device, base64) : 
  PytorchStreamReader failed reading zip archive: failed finding central directory
Exception raised from valid at ../caffe2/serialize/inline_container.cc:183 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7f584f9563cb in /home/.../R/x86_64-pc-linux-gnu-library/4.2/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xce (0x7f584f951d9e in /home/.../R/x86_64-pc-linux-gnu-library/4.2/torch/lib/libc10.so)
frame #2: caffe2::serialize::PyTorchStreamReader::valid(char const*, char const*) + 0x3ca (0x7f575f31631a in /home/.../R/x86_64-pc-linux-gnu-library/4.2/torch/lib/libtorch_cpu.so)
frame #3: caffe2::serialize::PyTorchStreamReader::init() + 0xad (0x7f575f316bad in /home/.../R/x86_64-pc-linux-gnu-library/

If I do the same with a smaller tensor (500x10x300x300), it works fine.

Example:

# Does not work: 5000 x 10 x 300 x 300 = 4.5e9 elements
dims <- c(5000, 10, 300, 300)
tensor <- torch::torch_tensor(array((1:prod(dims)) / length(dims), dims), dtype = torch::torch_float16())
torch::torch_save(tensor, "./test.pt")
tensor <- torch::torch_load("./test.pt")  # fails with the error above

# Does work: 500 x 10 x 300 x 300 = 4.5e8 elements
dims <- c(500, 10, 300, 300)
tensor <- torch::torch_tensor(array((1:prod(dims)) / length(dims), dims), dtype = torch::torch_float16())
torch::torch_save(tensor, "./test.pt")
tensor <- torch::torch_load("./test.pt")  # loads without error

Also:

  1. Passing compress = FALSE to torch_save() did not fix it.
  2. The resulting file is far too small in the failing case (445 KB with compression; 400.5 MB without compression).
  3. Serializing the tensor with torch:::torch_serialize() and loading the result directly with torch::torch_load(), keeping it in RAM without ever writing it to disk, also fails with the same error.
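For scale, the raw storage sizes of the two tensors work out as below, assuming 2 bytes per float16 element and ignoring any archive overhead. The comparison against 2^32 is only an arithmetic observation, not a confirmed cause:

$$
\begin{aligned}
5000 \cdot 10 \cdot 300 \cdot 300 &= 4.5\times10^{9}\ \text{elements} \;\Rightarrow\; 9\times10^{9}\ \text{bytes} > 2^{32} = 4{,}294{,}967{,}296\ \text{bytes} \\
500 \cdot 10 \cdot 300 \cdot 300 &= 4.5\times10^{8}\ \text{elements} \;\Rightarrow\; 9\times10^{8}\ \text{bytes} < 2^{32} \\
9\times10^{9} \bmod 2^{32} &= 410{,}065{,}408\ \text{bytes} \approx 400{,}454\ \text{KiB}
\end{aligned}
$$

Only the failing tensor exceeds 2^32 bytes, and the remainder modulo 2^32 is numerically close to the 400.5 MB uncompressed file size reported above. That would be consistent with a byte count being truncated to 32 bits somewhere in the save path, but this is a hypothesis drawn from the numbers alone; the error message does not confirm it.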