germanjke opened this issue 1 month ago
We haven't thoroughly tested with layer freezing + FSDP, so possible something complicated is going on here. However, we have seen this error when you try to load a checkpoint that doesn't have optimizer state. So it is possible that loading a checkpoint only containing optimizer state for some of the parameters does not work properly.
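If it helps anyone debugging this, here is a minimal sketch for checking what optimizer state a monolithic checkpoint actually contains (assuming the usual Composer layout with top-level 'state' and 'rng' keys; the path is a placeholder):

import torch

# Placeholder path to a monolithic Composer checkpoint.
ckpt = torch.load('/path/to/checkpoint/latest-rank0.pt', map_location='cpu')

# Composer typically nests each torch optimizer state dict under
# state -> optimizers -> <optimizer class name>, e.g. 'DecoupledAdamW'.
for optim_name, optim_sd in ckpt['state']['optimizers'].items():
    # A torch optimizer state dict holds 'state' (per-parameter tensors) and
    # 'param_groups'. A missing 'state' key, or state for only a subset of
    # the parameters (e.g. after freezing), is the kind of thing that can
    # break resumption.
    per_param_state = optim_sd.get('state', {})
    print(f'{optim_name}: state for {len(per_param_state)} parameters')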
I'm facing a similar issue with the latest release (0.8.0). When resuming from a monolithic checkpoint with HYBRID_SHARD, I get the following error (KeyError: 'state'):
[rank24]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank24]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank24]: │ ../../train/train.py:786 in <module> │
[rank24]: │ │
[rank24]: │ 783 │ cfg = om.merge(yaml_cfg, cli_cfg) │
[rank24]: │ 784 │ om.resolve(cfg) │
[rank24]: │ 785 │ assert isinstance(cfg, DictConfig) │
[rank24]: │ ❱ 786 │ main(cfg) │
[rank24]: │ 787 │
[rank24]: │ │
[rank24]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank24]: │ ../../train/train.py:717 in main │
[rank24]: │ │
[rank24]: │ 714 │ │
[rank24]: │ 715 │ # Build the Trainer │
[rank24]: │ 716 │ log.info('Building trainer...') │
[rank24]: │ ❱ 717 │ trainer = Trainer( │
[rank24]: │ 718 │ │ run_name=run_name, │
[rank24]: │ 719 │ │ seed=seed, │
[rank24]: │ 720 │ │ model=model, │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/trainer/trainer.py:1715 in │
[rank24]: │ __init__ │
[rank24]: │ │
[rank24]: │ 1712 │ │ │ │ if wandb.run is None: │
[rank24]: │ 1713 │ │ │ │ │ load_object_store.init(self.state, self.logger) │
[rank24]: │ 1714 │ │ │ _, _, parsed_load_path = parse_uri(load_path) │
[rank24]: │ ❱ 1715 │ │ │ self._rng_state = checkpoint.load_checkpoint( │
[rank24]: │ 1716 │ │ │ │ state=self.state, │
[rank24]: │ 1717 │ │ │ │ logger=self.logger, │
[rank24]: │ 1718 │ │ │ │ path=parsed_load_path, │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:531 in │
[rank24]: │ load_checkpoint │
[rank24]: │ │
[rank24]: │ 528 │ │ │ │ │ fsdp_sharded_state_dict_enabled=state.fsdp_sharde │
[rank24]: │ 529 │ │ │ │ │ deepspeed_sharded_checkpoint=is_model_deepspeed(s │
[rank24]: │ 530 │ │ │ │ ) │
[rank24]: │ ❱ 531 │ │ │ │ rng_state_dicts = _restore_checkpoint( │
[rank24]: │ 532 │ │ │ │ │ state, │
[rank24]: │ 533 │ │ │ │ │ logger, │
[rank24]: │ 534 │ │ │ │ │ composer_states_filepath, │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:999 in │
[rank24]: │ _restore_checkpoint │
[rank24]: │ │
[rank24]: │ 996 │ │ │ algorithm_passes=algorithm_passes, │
[rank24]: │ 997 │ │ ) │
[rank24]: │ 998 │ if not load_weights_only: │
[rank24]: │ ❱ 999 │ │ state.load_state_dict( │
[rank24]: │ 1000 │ │ │ state_dict['state'], │
[rank24]: │ 1001 │ │ │ logger, │
[rank24]: │ 1002 │ │ │ exclude_algorithms=exclude_algorithms, │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/core/state.py:1418 in │
[rank24]: │ load_state_dict │
[rank24]: │ │
[rank24]: │ 1415 │ │ │ if attribute_name == 'dataset_state': │
[rank24]: │ 1416 │ │ │ │ self._load_dataset_state(serialized_value) │
[rank24]: │ 1417 │ │ │ elif attribute_name == 'optimizers': │
[rank24]: │ ❱ 1418 │ │ │ │ self.load_optim_state(state) │
[rank24]: │ 1419 │ │ │ elif attribute_name == 'train_metrics': │
[rank24]: │ 1420 │ │ │ │ # Get current metrics object and populate each metric │
[rank24]: │ 1421 │ │ │ │ # in serialization with serialized data via load_stat │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/core/state.py:1331 in │
[rank24]: │ load_optim_state │
[rank24]: │ │
[rank24]: │ 1328 │ │ │ │ # errors) before discarding the output. Accordingly, │
[rank24]: │ 1329 │ │ │ │ # See: https://github.com/pytorch/pytorch/issues/1251 │
[rank24]: │ 1330 │ │ │ │ optim_state_dict = MagicMock() if optim_state_dict is │
[rank24]: │ ❱ 1331 │ │ │ │ set_optimizer_state_dict( │
[rank24]: │ 1332 │ │ │ │ │ model=self.model, │
[rank24]: │ 1333 │ │ │ │ │ optimizers=optimizer, │
[rank24]: │ 1334 │ │ │ │ │ optim_state_dict=optim_state_dict, │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/trainer/mosaic_fsdp_utils.py:719 │
[rank24]: │ in set_optimizer_state_dict │
[rank24]: │ │
[rank24]: │ 716 │ │ │ info = _verify_options(model, optimizers, optim_only=True, │
[rank24]: │ 717 │ │ │ │
[rank24]: │ 718 │ │ │ _verify_state_dict({}, optim_state_dict, info) │
[rank24]: │ ❱ 719 │ │ │ _load_optim_state_dict(model, optimizers, optim_state_dict │
[rank24]: │ 720 │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict.py │
[rank24]: │ :616 in _load_optim_state_dict │
[rank24]: │ │
[rank24]: │ 613 │ │ │ │ │ │ osd_state[k.replace(fqn, fqn_with_compiler)] │
[rank24]: │ 614 │ │ │ │
[rank24]: │ 615 │ │ │ with info.fsdp_context(): │
[rank24]: │ ❱ 616 │ │ │ │ optim_state_dict = FSDP.optim_state_dict_to_load( │
[rank24]: │ 617 │ │ │ │ │ model, optim, optim_state_dict │
[rank24]: │ 618 │ │ │ │ ) │
[rank24]: │ 619 │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_ │
[rank24]: │ parallel.py:1928 in optim_state_dict_to_load │
[rank24]: │ │
[rank24]: │ 1925 │ │ │ │ Default: ``None``) │
[rank24]: │ 1926 │ │ """ │
[rank24]: │ 1927 │ │ state_dict_settings = FullyShardedDataParallel.get_state_dict │
[rank24]: │ ❱ 1928 │ │ result = FullyShardedDataParallel._optim_state_dict_to_load_i │
[rank24]: │ 1929 │ │ │ optim_state_dict=optim_state_dict, │
[rank24]: │ 1930 │ │ │ model=model, │
[rank24]: │ 1931 │ │ │ optim_input=None, │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_ │
[rank24]: │ parallel.py:1319 in _optim_state_dict_to_load_impl │
[rank24]: │ │
[rank24]: │ 1316 │ │ │
[rank24]: │ 1317 │ │ if rank0_only and dist.get_rank(group) > 0: │
[rank24]: │ 1318 │ │ │ optim_state_dict = {} │
[rank24]: │ ❱ 1319 │ │ sharded_osd = _flatten_optim_state_dict( │
[rank24]: │ 1320 │ │ │ optim_state_dict, │
[rank24]: │ 1321 │ │ │ model=model, │
[rank24]: │ 1322 │ │ │ use_orig_params=use_orig_params, │
[rank24]: │ │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py:461 │
[rank24]: │ in _flatten_optim_state_dict │
[rank24]: │ │
[rank24]: │ 458 │ │
[rank24]: │ 459 │ # Construct the "state" part │
[rank24]: │ 460 │ flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {} │
[rank24]: │ ❱ 461 │ unflat_osd_state = unflat_osd["state"] │
[rank24]: │ 462 │ all_state_keys = set(unflat_osd_state.keys()) │
[rank24]: │ 463 │ │
[rank24]: │ 464 │ for param, fqns in param_to_fqns.items(): │
[rank24]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank24]: KeyError: 'state'
With a prior version of llm-foundry I didn't have this issue (albeit I was using the FULL_SHARD strategy), so I tried switching back to the old Composer code for resuming the optimizer (from version <0.22.0), but it doesn't work and I get a different error:
[rank118]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank118]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank118]: │ ../../train/train.py:786 in <module> │
[rank118]: │ │
[rank118]: │ 783 │ cfg = om.merge(yaml_cfg, cli_cfg) │
[rank118]: │ 784 │ om.resolve(cfg) │
[rank118]: │ 785 │ assert isinstance(cfg, DictConfig) │
[rank118]: │ ❱ 786 │ main(cfg) │
[rank118]: │ 787 │
[rank118]: │ │
[rank118]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank118]: │ ../../train/train.py:717 in main │
[rank118]: │ │
[rank118]: │ 714 │ │
[rank118]: │ 715 │ # Build the Trainer │
[rank118]: │ 716 │ log.info('Building trainer...') │
[rank118]: │ ❱ 717 │ trainer = Trainer( │
[rank118]: │ 718 │ │ run_name=run_name, │
[rank118]: │ 719 │ │ seed=seed, │
[rank118]: │ 720 │ │ model=model, │
[rank118]: │ │
[rank118]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank118]: │ venv/lib/python3.11/site-packages/composer/trainer/trainer.py:1715 in │
[rank118]: │ __init__ │
[rank118]: │ │
[rank118]: │ 1712 │ │ │ │ if wandb.run is None: │
[rank118]: │ 1713 │ │ │ │ │ load_object_store.init(self.state, self.logger) │
[rank118]: │ 1714 │ │ │ _, _, parsed_load_path = parse_uri(load_path) │
[rank118]: │ ❱ 1715 │ │ │ self._rng_state = checkpoint.load_checkpoint( │
[rank118]: │ 1716 │ │ │ │ state=self.state, │
[rank118]: │ 1717 │ │ │ │ logger=self.logger, │
[rank118]: │ 1718 │ │ │ │ path=parsed_load_path, │
[rank118]: │ │
[rank118]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank118]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:558 in │
[rank118]: │ load_checkpoint │
[rank118]: │ │
[rank118]: │ 555 │ dist.all_reduce(max_step_to_resume_from, reduce_operation='MAX') │
[rank118]: │ 556 │ dist.all_reduce(min_step_to_resume_from, reduce_operation='MIN') │
[rank118]: │ 557 │ if max_step_to_resume_from.data != min_step_to_resume_from.data: │
[rank118]: │ ❱ 558 │ │ raise RuntimeError( │
[rank118]: │ 559 │ │ │ textwrap.dedent( │
[rank118]: │ 560 │ │ │ │ f'Timestamp mismatch error: batch to resume from {ste │
[rank118]: │ 561 │ │ │ │ 'This usually occurs when at least one rank fails to │
[rank118]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank118]: RuntimeError: Timestamp mismatch error: batch to resume from 10000 is not the
[rank118]: same on all ranks. This usually occurs when at least one rank fails to save the
[rank118]: last checkpoint while using sharded checkpointing + autoresume. Please manually
[rank118]: resume by disabling autoresume and explicitly setting load_path to the most
[rank118]: recent checkpoints that all ranks have saved. E.g. for the 10th batch: trainer =
[rank118]: Trainer(autoresume=False, load_path="/path/to/checkpoint/ba10-rank{rank}.pt",
[rank118]: ...). Remember to keep the {rank} placeholder!
Are you experiencing a similar issue? Or do you have any hints?
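For reference, the manual resume that the Timestamp mismatch error message suggests looks roughly like this (the path and the remaining Trainer arguments are placeholders):

from composer import Trainer

# Turn off autoresume and point load_path at the most recent checkpoint that
# every rank managed to save; keep the literal {rank} placeholder in the path.
trainer = Trainer(
    autoresume=False,
    load_path='/path/to/checkpoint/ba10-rank{rank}.pt',
    # ... model, optimizers, dataloaders, etc., as in the normal setup
)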
I also have 0.8.0 and HYBRID_SHARD.
@Riccorl, it seems like your problem is separate from @germanjke's. Can you file a new issue with some more information, like:
@germanjke, can you try freezing layers as part of the optimizer, rather than via the Composer layer-freezing algorithm? Freezing via the optimizer is better tested. For example:
optimizer:
  lr: <your_learning_rate>
  name: decoupled_adamw
  disable_grad: ^((?!(Wqkv|out_proj)).)*$  # Regex that disables gradients for every parameter except Wqkv (attention) and out_proj
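A quick way to sanity-check which parameters that regex freezes (the parameter names below are only illustrative):

import re

pattern = re.compile(r'^((?!(Wqkv|out_proj)).)*$')

# The tempered pattern matches any name that does NOT contain Wqkv or
# out_proj; matched parameters have their gradients disabled, so only the
# attention (Wqkv) and out_proj weights stay trainable.
for name in [
    'model.transformer.blocks.0.attn.Wqkv.weight',
    'model.transformer.blocks.0.attn.out_proj.weight',
    'model.transformer.blocks.0.ffn.up_proj.weight',
]:
    frozen = pattern.match(name) is not None
    print(name, '-> frozen' if frozen else '-> trainable')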
@Riccorl, I have identified this as a PyTorch issue and opened a bug report on their end, plus a PR to fix it.
Hello,
I'm currently training LLaMA PRO. Initially, I expanded the model from 32 layers to 40 layers and proceeded to train only the newly added 8 layers (every fifth layer). Therefore, I froze 32 out of the 40 layers.
The training is going well and only the layers I need are trained.
But after a hardware failure, I attempted to resume training using load_path, and I encountered an error. My ep0-ba4500/.metadata looks like this:
Have you experienced similar issues?
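For context, freezing 32 of 40 layers as described might look something like the following sketch (the block-naming scheme and indices are assumptions about the model, not llm-foundry API):

import re

def freeze_all_but_new_blocks(model, new_block_indices):
    # Hypothetical helper: assumes parameter names contain 'blocks.<i>.'
    # and disables gradients everywhere except the newly added blocks.
    for name, param in model.named_parameters():
        match = re.search(r'blocks\.(\d+)\.', name)
        is_new = match is not None and int(match.group(1)) in new_block_indices
        param.requires_grad = is_new

# If a new block was inserted after every four original ones, the new
# indices would be 4, 9, 14, ..., 39:
# freeze_all_but_new_blocks(model, set(range(4, 40, 5)))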