ScottHoang opened this issue 3 months ago
Hi @ScottHoang, thanks for creating the issue. Can you share the command you ran and how you saved the checkpoint file you're trying to load? I'm a bit surprised it's trying to load a `ListConfig`; I didn't think we would be saving that in our intermediate checkpoints.
Hi @ebsmothers, I am running a custom recipe inherited from `lora_finetune_distributed`. Everything else is kept the same except for `_setup_model(...)`.
The save function looks like this:
```python
def save_checkpoint(
    self,
    epoch: int,
) -> None:
    """
    Checkpoint the state of the recipe. The constructed checkpoint state dict
    contains the following information:
    - Relevant recipe state if training is not complete

    Checkpointer will save the merged weights, adapter weights and recipe state in
    different checkpoint files. To correctly resume from training, the adapter weights
    and recipe state must be provided along with the base model weights.
    """
    # final dict passed onto the checkpointer
    checkpoint_dict = {}

    intermediate_checkpoint = epoch + 1 < self.total_epochs

    # To prevent GPU memory from spiking during checkpoint save,
    # we consolidate the full model and optim state dicts on CPU for rank 0
    with FSDP.state_dict_type(
        self._model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state_dict = self._model.state_dict()
        if intermediate_checkpoint:
            opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
        else:
            opt_state_dict = None

    # Now that we have the model and opt state dict, create the actual checkpoint dict
    # to be sent to the checkpointer and ultimately written to file
    if self._is_rank_zero:
        # Filter out the adapter keys and weights from the model state dict. These will
        # be saved separately
        adapter_key_filter = lambda x: x in self.adapter_params
        adapter_state_dict = {
            k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
        }
        checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict})
        checkpoint_dict.update({utils.ADAPTER_CONFIG: self.adapter_settings})
        checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict})

        # if training is in-progress, checkpoint the optimizer state and recipe state
        # as well.
        if intermediate_checkpoint:
            checkpoint_dict.update(
                {
                    utils.OPT_KEY: opt_state_dict,
                    utils.SEED_KEY: self.seed,
                    utils.EPOCHS_KEY: self.epochs_run,
                    utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                    utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
                }
            )

        self._checkpointer.save_checkpoint(
            state_dict=checkpoint_dict,
            epoch=epoch,
            intermediate_checkpoint=intermediate_checkpoint,
        )
```
Hi @ScottHoang, sorry for the delay here. I notice that your `save_checkpoint` method includes this line:

```python
checkpoint_dict.update({utils.ADAPTER_CONFIG: self.adapter_settings})
```

I am curious how you've defined `self.adapter_settings` in your custom recipe. Is it different than the value of `adapter_config` here? I would suggest inspecting the type of each value you're adding to `checkpoint_dict` to see if it matches what you expect. For what it's worth, we do test resuming from an intermediate checkpoint in our LoRA recipe (see here) to ensure everything works as expected.
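If `self.adapter_settings` is read straight from the YAML config, it will be an omegaconf container rather than a plain dict or list. A minimal sketch of normalizing such a value before it enters `checkpoint_dict` (the `to_plain` helper is hypothetical, not part of torchtune):

```python
from omegaconf import DictConfig, ListConfig, OmegaConf

def to_plain(value):
    """Convert omegaconf containers to plain Python containers so the
    resulting checkpoint survives torch.load(..., weights_only=True)."""
    if isinstance(value, (DictConfig, ListConfig)):
        # resolve=True also materializes any ${...} interpolations
        return OmegaConf.to_container(value, resolve=True)
    return value

# Hypothetical usage inside the custom recipe's save_checkpoint:
# checkpoint_dict.update({utils.ADAPTER_CONFIG: to_plain(self.adapter_settings)})
```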
@ebsmothers indeed, `self.adapter_settings` contains all the configs specific to my adapters in the YAML config. But based on this, I assume it isn't saved in `recipe_state.pt` itself? https://github.com/pytorch/torchtune/blob/069b12bef0b9cf735d5fb7cdc4192bfbf9abd764/torchtune/utils/_checkpointing/_checkpointer.py#L593 I also checked the file itself and found no `adapter_config` key saved, but I still can't load it with `weights_only=True`.
@ScottHoang oh yeah, good point. Sorry, I had forgotten that the issue was with loading from `recipe_state.pt` when making that comment. Can you just add a print statement here? Something like

```python
for k, v in state_dict.items():
    print(k, v, type(v))
```

That should help to pin down what's actually getting saved that's causing the problem.
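A slightly fuller variant of that diagnostic, which flags entries the `weights_only` unpickler is likely to reject (the `ALLOWED` tuple below is an illustrative approximation of the safelist, not the exact one):

```python
import torch

# Drop this in the same place as the print statement above; state_dict is
# the dict about to be written to recipe_state.pt.
ALLOWED = (int, float, bool, str, bytes, type(None), dict, list, tuple, torch.Tensor)
for k, v in state_dict.items():
    marker = "" if isinstance(v, ALLOWED) else "  <-- likely to break weights_only load"
    print(k, type(v), marker)
```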
Btw, alternatively we could consider supporting turning off `weights_only` loading of intermediate state. But the whole reason we have this on is to ensure that everything in the intermediate state is as expected (since it'll impact your resumed training), so I think in this case it's actually helping to catch something being saved incorrectly.
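For reference, the failure mode can be reproduced in isolation; a minimal sketch (hypothetical values, assuming omegaconf is installed):

```python
import tempfile

import torch
from omegaconf import OmegaConf

# An omegaconf ListConfig hiding inside the saved dict is enough to
# break weights_only loading.
cfg = OmegaConf.create({"target_modules": ["q_proj", "v_proj"]})
with tempfile.NamedTemporaryFile(suffix=".pt") as f:
    torch.save({"adapter_config": cfg.target_modules}, f.name)
    # Raises: _pickle.UnpicklingError: Weights only load failed ...
    # Unsupported class omegaconf.listconfig.ListConfig
    torch.load(f.name, weights_only=True)
```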
Location: `torchtune/utils/_checkpointing/_checkpointer.py`, line 438

Error:

```
_pickle.UnpicklingError: Weights only load failed. Re-running torch.load with weights_only set to False will likely succeed, but it can result in arbitrary code execution. Do it only if you get the file from a trusted source. WeightsUnpickler error: Unsupported class omegaconf.listconfig.ListConfig
```

Suggested fix: set `weights_only` to `False`.
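If the omegaconf value cannot easily be kept out of the checkpoint, newer PyTorch releases (2.4+) also allow explicitly allowlisting a trusted class instead of disabling `weights_only` entirely; a sketch (the preferred fix is still to convert to plain containers before saving):

```python
import torch
from omegaconf.listconfig import ListConfig

# PyTorch 2.4+ only: mark ListConfig as safe for the weights_only unpickler.
torch.serialization.add_safe_globals([ListConfig])
state = torch.load("recipe_state.pt", weights_only=True)
```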