huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Unable to read older pytorch pickle files #1348

Open joeyballentine opened 8 months ago

joeyballentine commented 8 months ago

It seems candle's pickle support currently assumes the newer pickle format torch switched to some time ago (at least, I'm assuming this is the case). This is a slight issue, as many of the .pth files I'd like to load with candle use the older format.

Specifically, I'm getting a Zip(InvalidArchive("Could not find central directory end")) when trying to use pickle::read_all on many older .pth files. Newer ones seem to work fine, but those are also unusable due to the load key issue I've mentioned previously.

My hunch is that this is due to torch switching how they pickle their .pth files. I remember at some point torch.save had an option to use the newer zip format, so I'm guessing these old ones are using something other than zip to store the pickle data, and thus the zip library can't read them properly. I'm not sure how to confirm this though.
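One way to check this hunch is to ask Python's standard library whether the file is a zip archive at all. The sketch below is only an illustration: `describe_checkpoint` and its return labels are hypothetical names, and the "pickle-stream" branch merely checks that the file starts with a pickle PROTO opcode, which is an assumption about how older files begin rather than a full format check.

```python
import zipfile


def describe_checkpoint(path):
    """Classify a checkpoint file as 'zip', 'pickle-stream', or 'unknown'.

    Hypothetical helper for illustration; the labels are not torch or candle terms.
    """
    # is_zipfile looks for the zip end-of-central-directory record,
    # the same structure the zip error in this issue says is missing.
    if zipfile.is_zipfile(path):
        return "zip"
    with open(path, "rb") as fobj:
        head = fobj.read(2)
    # Pickle protocol 2 and later start with the PROTO opcode (0x80)
    # followed by the protocol number.
    if len(head) == 2 and head[0] == 0x80 and 2 <= head[1] <= 5:
        return "pickle-stream"
    return "unknown"
```

Calling `describe_checkpoint` on one of the failing .pth files and on a newer one would show whether they really differ in container format.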

I've added support for .pth reading to my candle ESRGAN repo, where this can be pretty easily recreated. One of these model files should trigger it, as they are older files:

If you need a smaller repro, I can try to provide one when I have time tomorrow.

joeyballentine commented 8 months ago

Did a little digging, and here is the legacy model loading code in pytorch, for reference

LaurentMazare commented 8 months ago

Thanks for reporting this. I don't think the issue is within the zip crate: the files appear not to be in the zip format at all, but rather to be just a stream of serialized pickles. E.g. the following manages to read the start of the file (but fails with some torch error that I haven't investigated).

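import pickle

filename = '4x_Deviance_60000G.pth'

def persistent_load(persid):
    # Print each persistent id rather than resolving it to an object.
    print(persid)

with open(filename, 'rb') as fobj:
    p = pickle.Unpickler(fobj)
    p.persistent_load = persistent_load

    # Keep loading pickled objects back to back until the stream runs out.
    while True:
        try:
            data = p.load()
        except EOFError:
            break
        print(data)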

In principle, we should be able to tweak the candle_core::pickle module to try to read this format, but it's hard to know how much work that would be without giving it a try, as we don't support the whole space of pickle files at all. As this format seems deprecated in favor of the new zip-based approach, I would be tempted not to spend time supporting it, but if you want to give it a try, certainly feel free.
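The "stream of serialized pickles" layout can be illustrated with the standard library alone. The sketch below is not candle or torch code: `RecordingUnpickler` and `read_pickle_stream` are hypothetical names, and returning the persistent id as a placeholder is an assumption made only so the load can proceed without resolving any real tensor storage.

```python
import pickle


class RecordingUnpickler(pickle.Unpickler):
    """Unpickler that records persistent ids instead of resolving them."""

    def __init__(self, fobj):
        super().__init__(fobj)
        self.seen_ids = []

    def persistent_load(self, pid):
        # Remember the id and hand it back as a stand-in for the real object.
        self.seen_ids.append(pid)
        return pid


def read_pickle_stream(fobj):
    """Load consecutive pickles from fobj until the stream is exhausted.

    Returns the list of loaded objects and the persistent ids encountered.
    """
    unpickler = RecordingUnpickler(fobj)
    objects = []
    while True:
        try:
            objects.append(unpickler.load())
        except EOFError:
            # Raised when no further pickle starts at the current position.
            break
    return objects, unpickler.seen_ids
```

A reader for the older files would need the same two ingredients, a loop over back-to-back pickles and a hook for persistent ids, plus whatever logic maps those ids to tensor data, which this sketch leaves out.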

joeyballentine commented 8 months ago

Yeah, I can understand that it probably isn't worth the effort. Have you considered trying to use one of the serde-pickle packages? I tried implementing read_all using serde-pickle-rs, and making all the tensor-related code serializable and deserializable seemed to go smoothly, but unfortunately I ran into an issue with some complicated Rust dyn typing I'm not familiar with.

Anyway, it's not too big of a deal. I can just expect any users to convert the models to safetensors beforehand. Given this issue and the load key one, it's just the simpler method.