pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

safe_torch_load failed when resume from checkpoint #1142

Open ScottHoang opened 3 months ago

ScottHoang commented 3 months ago

Location: torchtune/utils/_checkpointing/_checkpointer.py line: 438

Error: _pickle.UnpicklingError: Weights only load failed. Re-running torch.load with weights_only set to False will likely succeed, but it can result in arbitrary code execution. Do it only if you get the file from a trusted source. WeightsUnpickler error: Unsupported class omegaconf.listconfig.ListConfig

Suggested fix: set weights_only to False.
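
For context, here is a minimal standalone sketch (not torchtune code; the file name and values are made up) of why the weights-only unpickler rejects the checkpoint: it only allows a small set of types, and an OmegaConf ListConfig stored in the state dict is not one of them.

```python
import torch
from omegaconf import OmegaConf

# Sketch: save a state dict that contains an OmegaConf ListConfig.
state = {"adapter_config": OmegaConf.create(["q_proj", "v_proj"])}
torch.save(state, "recipe_state.pt")

# Fails with the UnpicklingError above: ListConfig is not an allowed type.
torch.load("recipe_state.pt", weights_only=True)

# Loads, but executes arbitrary pickle bytecode; only safe for trusted files.
torch.load("recipe_state.pt", weights_only=False)
```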

ebsmothers commented 3 months ago

Hi @ScottHoang, thanks for creating the issue. Can you share the command you ran and how you saved the checkpoint file you're trying to load? I'm a bit surprised it's trying to load a ListConfig; I didn't think we were saving that in our intermediate checkpoints.

ScottHoang commented 3 months ago

Hi @ebsmothers, I am running a custom recipe inherited from Lora_finetune_distributed. Everything else is kept the same except for _setup_model(...).

ScottHoang commented 3 months ago

The save_checkpoint function looks like this:

```python
def save_checkpoint(
    self,
    epoch: int,
) -> None:
    """
    Checkpoint the state of the recipe. The constructed checkpoint state dict
    contains the following information:
    ...
```

ebsmothers commented 3 months ago

Hi @ScottHoang, sorry for the delay here. I notice that your save_checkpoint method includes this line:

checkpoint_dict.update({utils.ADAPTER_CONFIG: self.adapter_settings})

I am curious how you've defined self.adapter_settings in your custom recipe. Is it different from the value of adapter_config here? I would suggest inspecting the type of each value you're adding to checkpoint_dict to see whether it matches what you expect. For what it's worth, we do test resuming from an intermediate checkpoint in our LoRA recipe (see here) to ensure everything works as expected.
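
One pattern that would avoid the problem (a sketch only, assuming self.adapter_settings comes straight from the YAML config as an OmegaConf container) is converting it to plain Python containers before adding it to checkpoint_dict:

```python
from omegaconf import DictConfig, ListConfig, OmegaConf

# Hypothetical guard inside save_checkpoint: convert OmegaConf containers to
# plain lists/dicts so the saved file stays loadable with weights_only=True.
adapter_settings = self.adapter_settings
if isinstance(adapter_settings, (DictConfig, ListConfig)):
    adapter_settings = OmegaConf.to_container(adapter_settings, resolve=True)

checkpoint_dict.update({utils.ADAPTER_CONFIG: adapter_settings})
```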

ScottHoang commented 3 months ago

@ebsmothers indeed, self.adapter_settings contains all of the configs specific to my adapters from the YAML config. But based on this, I assume it isn't saved in "recipe_state.pt" itself? https://github.com/pytorch/torchtune/blob/069b12bef0b9cf735d5fb7cdc4192bfbf9abd764/torchtune/utils/_checkpointing/_checkpointer.py#L593 I also checked the file itself and found no 'adapter_config' key saved, but I still can't load it with weights_only=True.

ebsmothers commented 3 months ago

@ScottHoang oh yeah, good point. Sorry, I had forgotten that the issue was with loading from recipe_state.pt when I made that comment. Can you just add a print statement here? Something like:

```python
for k, v in state_dict.items():
    print(k, v, type(v))
```

That should help to pin down what's actually getting saved that's causing the problem.
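
If the culprit turns out to be nested inside another container, a small recursive check (a hypothetical helper, not part of torchtune) can narrow it down further:

```python
from omegaconf import DictConfig, ListConfig

def find_omegaconf_values(obj, path="state_dict"):
    # Recursively report any OmegaConf containers hiding in the recipe state.
    if isinstance(obj, (DictConfig, ListConfig)):
        print(f"{path}: {type(obj).__name__}")
    elif isinstance(obj, dict):
        for k, v in obj.items():
            find_omegaconf_values(v, f"{path}[{k!r}]")
    elif isinstance(obj, (list, tuple)):
        for i, v in enumerate(obj):
            find_omegaconf_values(v, f"{path}[{i}]")

find_omegaconf_values(state_dict)
```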

Btw, alternatively we could consider supporting turning off weights_only loading of intermediate state. But the whole reason we have it on is to ensure that everything in the intermediate state is as expected (since it will impact your resumed training), so I think in this case it's actually helping to catch something being saved incorrectly.
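
If keeping the ListConfig in the checkpoint were actually intentional, another option (a sketch of a general PyTorch mechanism available in newer releases, roughly 2.4+, not something torchtune does today) is to allowlist the class for the weights-only unpickler. Note that OmegaConf objects may pull in additional internal classes that would also need allowlisting, so converting to plain containers before saving is probably the cleaner fix:

```python
import torch
from omegaconf import ListConfig

# Allowlist ListConfig for the weights-only unpickler (PyTorch >= 2.4).
# Depending on how OmegaConf pickles its containers, more classes may
# need to be added here as well.
torch.serialization.add_safe_globals([ListConfig])
state_dict = torch.load("recipe_state.pt", weights_only=True)
```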