mlfoundations / open_clip

An open source implementation of CLIP.

Cannot train again on pretrained checkpoint due to change in default `weights_only=True` #998

Open ishaaq opened 3 days ago

ishaaq commented 3 days ago

I am getting a pickle.UnpicklingError when trying to train again on a previously trained checkpoint with open_clip v2.27.0+.

This is similar to https://github.com/mlfoundations/open_clip/issues/966 - but the problem persists even in the latest v2.29.0.

Minimal repro: in v2.26.1, the following used to work:

python -m open_clip_train.main \
  --dataset-type synthetic \
  --train-num-samples 16 \
  --warmup 1 \
  --batch-size 4 \
  --epochs 1 \
  --model ViT-B-32 \
  --name train1 \
  --pretrained laion400m_e31

# train again, but this time using a checkpoint from the previous training:
python -m open_clip_train.main \
  --dataset-type synthetic \
  --train-num-samples 16 \
  --warmup 1 \
  --batch-size 4 \
  --epochs 1 \
  --model ViT-B-32 \
  --name train2 \
  --pretrained ./logs/train1/checkpoints/epoch_1.pt

... but from v2.27.0 onwards it fails:

2024-11-20,09:28:39 | INFO | Loading pretrained ViT-B-32 weights (./logs/train1/checkpoints/epoch_1.pt).
Traceback (most recent call last):
  File "/home/ec2-user/.pyenv/versions/3.8.20/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ec2-user/.pyenv/versions/3.8.20/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/ec2-user/.venv/lib/python3.8/site-packages/open_clip_train/main.py", line 508, in <module>
    main(sys.argv[1:])
  File "/home/ec2-user/.venv/lib/python3.8/site-packages/open_clip_train/main.py", line 223, in main
    model, preprocess_train, preprocess_val = create_model_and_transforms(
  File "/home/ec2-user/.venv/lib/python3.8/site-packages/open_clip/factory.py", line 414, in create_model_and_transforms
    model = create_model(
  File "/home/ec2-user/.venv/lib/python3.8/site-packages/open_clip/factory.py", line 320, in create_model
    load_checkpoint(model, checkpoint_path)
  File "/home/ec2-user/.venv/lib/python3.8/site-packages/open_clip/factory.py", line 169, in load_checkpoint
    state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)
  File "/home/ec2-user/.venv/lib/python3.8/site-packages/open_clip/factory.py", line 139, in load_state_dict
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
  File "/home/ec2-user/.venv/lib/python3.8/site-packages/torch/serialization.py", line 1096, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options
    (1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
    (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
    WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
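
For readers unfamiliar with where "Unsupported global" comes from: a weights-only load is an allowlisted unpickle. The following is a minimal standard-library sketch of that idea, not PyTorch's actual implementation. The names ALLOWED, AllowlistUnpickler, and is_loadable are hypothetical, and fractions.Fraction merely stands in for a non-allowlisted type such as the numpy scalar reported above.

```python
import io
import pickle
from collections import OrderedDict
from fractions import Fraction

# Hypothetical allowlist: only these (module, name) globals may be resolved.
ALLOWED = {("collections", "OrderedDict")}


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that refuses to import any global outside ALLOWED."""

    def find_class(self, module, name):
        if (module, name) in ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"Unsupported global: {module}.{name}")


def is_loadable(data: bytes) -> bool:
    """Return True if the payload unpickles using only allowlisted globals."""
    try:
        AllowlistUnpickler(io.BytesIO(data)).load()
        return True
    except pickle.UnpicklingError:
        return False


# A bare "state dict" of plain containers and floats passes the allowlist.
ok_payload = pickle.dumps(OrderedDict(layer=[0.1, 0.2]))
# One stray non-allowlisted object (Fraction here) makes the whole load fail.
bad_payload = pickle.dumps({"state_dict": OrderedDict(), "extra": Fraction(1, 2)})
```

The point of the sketch is that a single object of an unlisted type anywhere in the file is enough to reject the load, which matches the symptom in the traceback.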

rwightman commented 1 day ago

@ishaaq TL;DR: you can hack the create_model_and_transforms call in main.py right now to include load_weights_only=False. We need to add an arg with a sensible name to allow this override from the command line.

Current train checkpoints are not weights-only, so unless the checkpoint is cleaned to remove the extra items saved alongside the model state dict, load_weights_only=False needs to be set in the factory, and that arg isn't exposed to main.py. I thought the usual use case was to use --resume, but I guess if one wants to use a train checkpoint as an initial checkpoint for a new train session, that doesn't work...
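
A minimal sketch of the "cleaning" option mentioned above, under stated assumptions: it assumes the training checkpoint is a dict that nests the model weights under a "state_dict" key (alongside items such as the epoch and optimizer state) and that weights saved from a DistributedDataParallel model may carry a "module." prefix. The helper names are hypothetical and not part of open_clip.

```python
from typing import Any, Dict, Mapping


def extract_model_state_dict(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Return only the model weights from a training checkpoint dict."""
    # Assumption: weights live under "state_dict"; otherwise treat the
    # whole mapping as an already-bare state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    prefix = "module."  # added when the model was wrapped in DDP
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def strip_checkpoint_file(src: str, dst: str) -> None:
    """One-off conversion of a trusted training checkpoint to bare weights."""
    import torch  # imported lazily so the helper above works without torch

    # weights_only=False can execute arbitrary code: use only on files
    # produced by your own training runs.
    checkpoint = torch.load(src, map_location="cpu", weights_only=False)
    torch.save(extract_model_state_dict(checkpoint), dst)
```

The file written by strip_checkpoint_file would then contain only tensors in a plain dict, which is the shape the thread describes for the regular pretrained weights.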

ishaaq commented 1 day ago

Hmm, so it's appropriate/recommended (for security) to leave the flag defaulted to True for training on pretrained models but the only way to train on a checkpoint is to turn it off? Surely the security concerns don't go away with using checkpoints?

Thanks, yeah - I did hack the call to set it as a workaround for now.

rwightman commented 1 day ago

@ishaaq I wrote everything below and then thought: what PyTorch version are you using? I just tried torch.load(..., weights_only=True) on a recent train checkpoint (loaded in PyTorch 2.5; the training used 2.2, I think). It worked.

There were a number of late additions to safe globals defaults in pytorch as this change evolved on their end. I don't feel we're saving any funky types in our train checkpoints by default, so curious what it's breaking on or if pytorch 2.5 fixes it?


Pretrained model weights are hosted and distributed online; training checkpoints are artifacts of your own training, generated locally (they could be shared, but we don't distribute them). So the 'resume' functionality expects to operate on the output of your own local training.

I don't see requiring a --pretrained-force-unsafe-load flag, or something along those lines, as being dangerous. It can be made clear that it's 'not safe' and that you should only use it to load your own training checkpoints as pretrained models.

I see two paths:

  1. What I mentioned: in the train code, pass an option flag to create_model that disables weights_only when you intend to load a training checkpoint as pretrained weights. By default it's left as True, and all of the normal pretrained weights work because they are just bare model state dicts, stripped of their training state.
  2. Modify the --resume functionality with extra flags to ignore the epoch or not load optimizer state. This mostly works, but it then wouldn't work with the resizing, etc. functionality that is integrated into the factory.
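
The first path could be sketched roughly as follows. This is an illustration only: --pretrained-force-unsafe-load is the placeholder name floated above and is not an existing option, build_parser and factory_kwargs are hypothetical helpers, and load_weights_only is the factory keyword referred to earlier in this thread.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="ViT-B-32")
    parser.add_argument("--pretrained", default="")
    # Hypothetical flag: opt out of the safe default explicitly.
    parser.add_argument(
        "--pretrained-force-unsafe-load",
        action="store_true",
        help="Load --pretrained with weights_only disabled. Not safe: "
        "use only for checkpoints from your own training runs.",
    )
    return parser


def factory_kwargs(args: argparse.Namespace) -> dict:
    """Translate CLI args into keyword arguments for the model factory."""
    return {
        "pretrained": args.pretrained,
        # Safe by default; disabled only when the user asks for it.
        "load_weights_only": not args.pretrained_force_unsafe_load,
    }


def main(argv=None):
    import open_clip  # imported lazily; requires open_clip_torch

    args = build_parser().parse_args(argv)
    return open_clip.create_model_and_transforms(args.model, **factory_kwargs(args))
```

With this shape the regular pretrained tags keep the safe default, and the unsafe path has to be requested by name on the command line.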

ishaaq commented 1 day ago

In my minimal repro I had opted to do nothing else other than create a venv and then pip install open_clip_torch[training] and let it decide dependency versions. The torch version it opted to download was v2.5.1:

Collecting open_clip_torch[training]
  Using cached open_clip_torch-2.29.0-py3-none-any.whl (1.5 MB)
...
Collecting torch>=1.9.0
  Using cached torch-2.5.1-cp39-cp39-manylinux1_x86_64.whl (906.5 MB)

... so yes, the issue exists even in torch 2.5.1

Back to your answer though... I understand it's not safe to load anything other than weights from an external/untrusted pretrained source, but are you saying that it's still safe to trust our own training checkpoints even if the training that generated those checkpoints was based on an untrusted pretrained source? That's what my minimal repro above is doing.

rwightman commented 19 hours ago

@ishaaq And in your example, was the checkpoint you were trying to load also created by that env, or did it come from somewhere else? You trained a ViT-B-32 for an epoch and then it failed? Weird, I don't see how a numpy global got in there. Ideally we shouldn't be generating train checkpoints that don't load with weights_only=True, and I'm not quite sure how that's happening right now.

And yes, it isn't necessarily 'safe' to load a train checkpoint with weights_only=False generated by code you did not audit yourself. Not aware of any abuse there but it's possible.
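
One way to find out which entry drags in the numpy global is to walk the checkpoint and list the type of every leaf value. The sketch below assumes the checkpoint is a nested structure of dicts, lists, and tuples; leaf_types and inspect_checkpoint are hypothetical helper names, and the file has to be loaded with weights_only=False once, so it should only be run on a checkpoint you trust.

```python
from typing import Any, Dict, Optional


def leaf_types(obj: Any, path: str = "", found: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Map each leaf type name ('module.TypeName') to the first path where it occurs."""
    if found is None:
        found = {}
    if isinstance(obj, dict):
        for key, value in obj.items():
            leaf_types(value, f"{path}/{key}", found)
    elif isinstance(obj, (list, tuple)):
        for index, value in enumerate(obj):
            leaf_types(value, f"{path}/{index}", found)
    else:
        name = f"{type(obj).__module__}.{type(obj).__qualname__}"
        found.setdefault(name, path)  # keep only the first occurrence
    return found


def inspect_checkpoint(path: str) -> Dict[str, str]:
    """Load a trusted checkpoint and report the leaf types it contains."""
    import torch  # imported lazily so leaf_types works without torch

    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    return leaf_types(checkpoint)
```

Any entry in the result whose type name starts with "numpy." would point at the value that the weights-only loader rejects.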

ishaaq commented 11 hours ago

The minimal repro is exactly that - i.e. the same env - I spun up an EC2 GPU instance and pip installed open_clip and ran both the first training run (which completed successfully) and then the second using the checkpoint from the first (which failed with the pickling error).