ishaaq opened this issue 3 days ago
@ishaaq TL;DR: you can hack the create_model_and_transforms call in main.py right now to include load_weights_only=False; we need to add an arg with a sensible name to allow this override from the command line.
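Roughly, the hack is a one-argument change to the existing factory call (a sketch, not an exact patch; everything other than the load_weights_only line is whatever main.py already passes):

```python
# In open_clip_train/main.py, inside the existing create_model_and_transforms call:
model, preprocess_train, preprocess_val = create_model_and_transforms(
    args.model,
    args.pretrained,
    # ... keep all of the existing keyword arguments as they are ...
    load_weights_only=False,  # manual override until a CLI arg exists for this
)
```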
Current train checkpoints are not weights-only, so without cleaning the checkpoint to remove the extra items in the state dict, load_weights_only=False needs to be set in the factory, but that arg isn't exposed to main.py. I thought the usual use case was to use --resume, but I guess if you want to use a train checkpoint as the initial checkpoint for a new training session, that doesn't work...
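For reference, 'cleaning' the checkpoint just means keeping only the model weights; assuming the usual layout where the model is stored under a 'state_dict' key, something like this would do it:

```python
import torch

# Strip a full train checkpoint (epoch counter, optimizer state, scaler, ...)
# down to just the model weights so it can be loaded with weights_only=True.
ckpt = torch.load("epoch_1.pt", map_location="cpu", weights_only=False)  # trusted, local file
torch.save(ckpt["state_dict"], "epoch_1_weights_only.pt")
```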
Hmm, so it's appropriate/recommended (for security) to leave the flag defaulted to True for training on pretrained models, but the only way to train from a checkpoint is to turn it off? Surely the security concerns don't go away when using checkpoints?
Thanks, yeah - I did hack the call to set it as a workaround for now.
@ishaaq I wrote all of the below and then thought: what PyTorch version are you using? I just tried torch.load(..., weights_only=True) on a recent train checkpoint (loaded in PyTorch 2.5; the training run used 2.2, I think) and it worked. There were a number of late additions to the safe-globals defaults in PyTorch as this change evolved on their end. I don't think we're saving any funky types in our train checkpoints by default, so I'm curious what it's breaking on, or whether PyTorch 2.5 fixes it.
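If you want to poke at it directly, a quick check along these lines against the offending checkpoint should show exactly which global the safe unpickler is rejecting (the path here is just a placeholder):

```python
import torch

ckpt_path = "logs/run1/checkpoints/epoch_1.pt"  # placeholder: the checkpoint that fails for you

try:
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    print("loaded with weights_only=True; top-level keys:", list(ckpt.keys()))
except Exception as err:  # pickle.UnpicklingError names the disallowed global
    print("weights_only=True load failed:", err)
```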
Pretrained model weights are hosted and distributed online; training checkpoints are artifacts of your own training that were generated locally (though they could be shared, we don't distribute them). So the 'resume' functionality expects to be operating on the output of your own local training.

I don't see requiring a --pretrained-force-unsafe-load or something along those lines as being dangerous; it can be made clear that it's 'not safe' and that you should only do it to load your own training checkpoints as pretrained models.
I see two paths
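If we go the flag route, the wiring would be something like this (the name is just the placeholder from above, nothing is decided):

```python
# In open_clip_train/params.py (sketch; flag name is a placeholder):
parser.add_argument(
    "--pretrained-force-unsafe-load",
    default=False,
    action="store_true",
    help="Load --pretrained checkpoints with load_weights_only=False. "
         "Only use this for train checkpoints you generated and trust.",
)

# ... and in main.py, forward it to the factory:
#     load_weights_only=not args.pretrained_force_unsafe_load
```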
In my minimal repro I opted to do nothing other than create a venv and then pip install open_clip_torch[training], letting pip decide the dependency versions. The torch version it chose was v2.5.1:
```
Collecting open_clip_torch[training]
  Using cached open_clip_torch-2.29.0-py3-none-any.whl (1.5 MB)
...
Collecting torch>=1.9.0
  Using cached torch-2.5.1-cp39-cp39-manylinux1_x86_64.whl (906.5 MB)
```

... so yes, the issue exists even with torch 2.5.1.
Back to your answer though... I understand it's not safe to load anything other than weights from an external/untrusted pretrained source, but are you saying it's still safe to trust our own training checkpoints even if the training that generated those checkpoints was based on an untrusted pretrained source? That's what my minimal repro above is doing.
@ishaaq And in your example, was the checkpoint you were trying to load also created by that env, or did it come from somewhere else? You trained a ViT-B-32 for an epoch and then it failed? Weird, I don't see how a numpy global got in there. We ideally shouldn't be generating train checkpoints that don't load with weights_only=True, and I'm not quite sure how that's happening right now.

And yes, it isn't necessarily 'safe' to load a train checkpoint with weights_only=False if it was generated by code you did not audit yourself. I'm not aware of any abuse there, but it's possible.
The minimal repro is exactly that, i.e. the same env: I spun up an EC2 GPU instance, pip installed open_clip, ran the first training run (which completed successfully), and then ran the second using the checkpoint from the first (which failed with the pickling error).
I am getting a pickle.UnpicklingError when trying to train again on a previously trained checkpoint with open_clip v2.27.0+. This is similar to https://github.com/mlfoundations/open_clip/issues/966, but the problem persists even in the latest v2.29.0.

Minimal repro: in v2.26.1, the following used to work:

...

but from v2.27.0 onwards it fails:
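(The exact commands are not reproduced here; illustratively, the failing step boils down to the factory loading a locally produced train checkpoint as the pretrained weights, roughly like the sketch below, with a hypothetical checkpoint path.)

```python
import open_clip

# Hypothetical path to the checkpoint written by the first training run.
ckpt_path = "logs/run1/checkpoints/epoch_1.pt"

# In v2.26.1 this worked; from v2.27.0 the factory loads the checkpoint with
# torch.load(..., weights_only=True), which raises pickle.UnpicklingError
# because the train checkpoint contains more than bare tensors.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained=ckpt_path,
)
```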