mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

LLaMA PRO training resume problem #1231

Open germanjke opened 1 month ago

germanjke commented 1 month ago

Hello,

I'm currently training LLaMA PRO. Initially, I expanded the model from 32 layers to 40 layers and proceeded to train only the newly added 8 layers (every fifth layer). Therefore, I froze 32 out of the 40 layers.

layer_freezing: 
    layer_names: [ 
    'model._fsdp_wrapped_module.model.layers.36._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.16._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.18._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.27._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.32._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.35._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.10._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.3._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.37._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.28._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.22._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.12._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.2._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.5._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.8._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.20._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.17._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.25._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.30._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.38._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.7._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.33._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.6._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.31._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.13._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.15._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.11._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.21._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.26._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.23._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param'
    ]
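
For reference, below is a minimal sketch of how a freeze list like this could be generated instead of written out by hand. It is a sketch under assumptions: 40 layers after expansion, the newly inserted blocks at indices 4, 9, ..., 39 (every fifth layer), and the FSDP/activation-checkpointing flat-parameter naming shown above; the names NUM_LAYERS, TRAINED, and FLAT_PARAM_TEMPLATE are illustrative, not llm-foundry config keys.

# Sketch: build the frozen flat-param name list for a 40-layer LLaMA PRO
# expansion where the inserted (trainable) blocks sit at indices 4, 9, ..., 39.
NUM_LAYERS = 40
TRAINED = {i for i in range(NUM_LAYERS) if (i + 1) % 5 == 0}  # {4, 9, ..., 39}

FLAT_PARAM_TEMPLATE = (
    'model._fsdp_wrapped_module.model.layers.{idx}'
    '._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param'
)

frozen_layer_names = [
    FLAT_PARAM_TEMPLATE.format(idx=i)
    for i in range(NUM_LAYERS)
    if i not in TRAINED
]
# The root flat param (which typically holds the parameters not wrapped in a
# transformer block, e.g. the embeddings) is frozen as well, as in the config above.
frozen_layer_names.append('model._fsdp_wrapped_module._flat_param')

print(len(frozen_layer_names))  # 33 entries, matching the list above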

Training is going well and only the layers I need are being trained.

But after a hardware failure, I attempted to resume training using load_path and encountered the following error:

[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'
[rank6]: Traceback (most recent call last): (RANK 14)
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/utils.py", line 163, in reduce_scatter
[rank6]:     local_data = map_fun()
[rank6]:                  ^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'
[rank6]: Traceback (most recent call last): (RANK 15)
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/utils.py", line 163, in reduce_scatter
[rank6]:     local_data = map_fun()
[rank6]:                  ^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'

My ep0-ba4500/.metadata looks like this:

[binary pickle of a torch.distributed.checkpoint.metadata.Metadata object; the human-readable state_dict_metadata keys visible in it include, for example:]
state.model.model.model.embed_tokens.weight (TensorStorageMetadata, torch.float32, 8 chunks)
state.model.model.model.layers.2.mlp.up_proj.weight
state.model.model.model.layers.2.input_layernorm.weight
state.model.model.model.layers.2.post_attention_layernorm.weight
state.model.model.model.layers.3.self_attn.q_proj.weight
state.model.model.model.layers.3.self_attn.o_proj.weight
state.model.model.model.layers.3.mlp.up_proj.weight
state.model.model.model.layers.3.input_layernorm.weight
etc...
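
The raw .metadata file is a pickled torch.distributed.checkpoint Metadata object, so it is easier to inspect programmatically than as bytes. Below is a minimal sketch (the checkpoint path 'ep0-ba4500' is an assumption taken from the directory name above) that lists the saved keys, which makes it easy to see whether the state.optimizers.* entries the load planner is asking for were ever written:

# Sketch: list the keys stored in a torch.distributed.checkpoint directory.
from torch.distributed.checkpoint import FileSystemReader

metadata = FileSystemReader('ep0-ba4500').read_metadata()

for key in sorted(metadata.state_dict_metadata):
    print(key)

# The resume fails on keys like
# 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg',
# so check whether any 'state.optimizers.' keys show up in this listing at all.
has_optim = any(k.startswith('state.optimizers.') for k in metadata.state_dict_metadata)
print('optimizer state present' if has_optim else 'no optimizer state saved')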

Have you experienced similar issues?

dakinggg commented 1 month ago

We haven't thoroughly tested layer freezing + FSDP, so it's possible something complicated is going on here. However, we have seen this error when trying to load a checkpoint that doesn't contain optimizer state, so it's possible that loading a checkpoint that only contains optimizer state for some of the parameters does not work properly.
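
One way to check that hypothesis from the checkpoint alone is to compare which parameters have weights saved against which have DecoupledAdamW optimizer state saved. The sketch below is illustrative, not a verified diagnostic; the key prefixes are taken from the metadata and error above, and the checkpoint path is again an assumption.

# Sketch: find parameters that have saved weights but no saved optimizer state.
from torch.distributed.checkpoint import FileSystemReader

keys = FileSystemReader('ep0-ba4500').read_metadata().state_dict_metadata.keys()

model_params = {
    k.removeprefix('state.model.')
    for k in keys if k.startswith('state.model.')
}
optim_params = {
    # strip the optimizer prefix and the trailing '.exp_avg' / '.exp_avg_sq' / '.step'
    k.removeprefix('state.optimizers.DecoupledAdamW.state.').rsplit('.', 1)[0]
    for k in keys if k.startswith('state.optimizers.DecoupledAdamW.state.')
}

print('parameters with weights but no optimizer state:')
for name in sorted(model_params - optim_params):
    print('  ', name)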

Riccorl commented 1 month ago

I'm facing a similar issue with the latest release (0.8.0). When resuming from a monolithic checkpoint with HYBRID_SHARD I get the following error (KeyError: 'state'):

[rank24]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank24]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank24]: │ ../../train/train.py:786 in <module>                                         │
[rank24]: │                                                                              │
[rank24]: │   783 │   cfg = om.merge(yaml_cfg, cli_cfg)                                  │
[rank24]: │   784 │   om.resolve(cfg)                                                    │
[rank24]: │   785 │   assert isinstance(cfg, DictConfig)                                 │
[rank24]: │ ❱ 786 │   main(cfg)                                                          │
[rank24]: │   787                                                                        │
[rank24]: │                                                                              │
[rank24]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank24]: │ ../../train/train.py:717 in main                                             │
[rank24]: │                                                                              │
[rank24]: │   714 │                                                                      │
[rank24]: │   715 │   # Build the Trainer                                                │
[rank24]: │   716 │   log.info('Building trainer...')                                    │
[rank24]: │ ❱ 717 │   trainer = Trainer(                                                 │
[rank24]: │   718 │   │   run_name=run_name,                                             │
[rank24]: │   719 │   │   seed=seed,                                                     │
[rank24]: │   720 │   │   model=model,                                                   │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/trainer/trainer.py:1715 in        │
[rank24]: │ __init__                                                                     │
[rank24]: │                                                                              │
[rank24]: │   1712 │   │   │   │   if wandb.run is None:                                 │
[rank24]: │   1713 │   │   │   │   │   load_object_store.init(self.state, self.logger)   │
[rank24]: │   1714 │   │   │   _, _, parsed_load_path = parse_uri(load_path)             │
[rank24]: │ ❱ 1715 │   │   │   self._rng_state = checkpoint.load_checkpoint(             │
[rank24]: │   1716 │   │   │   │   state=self.state,                                     │
[rank24]: │   1717 │   │   │   │   logger=self.logger,                                   │
[rank24]: │   1718 │   │   │   │   path=parsed_load_path,                                │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:531 in        │
[rank24]: │ load_checkpoint                                                              │
[rank24]: │                                                                              │
[rank24]: │    528 │   │   │   │   │   fsdp_sharded_state_dict_enabled=state.fsdp_sharde │
[rank24]: │    529 │   │   │   │   │   deepspeed_sharded_checkpoint=is_model_deepspeed(s │
[rank24]: │    530 │   │   │   │   )                                                     │
[rank24]: │ ❱  531 │   │   │   │   rng_state_dicts = _restore_checkpoint(                │
[rank24]: │    532 │   │   │   │   │   state,                                            │
[rank24]: │    533 │   │   │   │   │   logger,                                           │
[rank24]: │    534 │   │   │   │   │   composer_states_filepath,                         │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:999 in        │
[rank24]: │ _restore_checkpoint                                                          │
[rank24]: │                                                                              │
[rank24]: │    996 │   │   │   algorithm_passes=algorithm_passes,                        │
[rank24]: │    997 │   │   )                                                             │
[rank24]: │    998 │   if not load_weights_only:                                         │
[rank24]: │ ❱  999 │   │   state.load_state_dict(                                        │
[rank24]: │   1000 │   │   │   state_dict['state'],                                      │
[rank24]: │   1001 │   │   │   logger,                                                   │
[rank24]: │   1002 │   │   │   exclude_algorithms=exclude_algorithms,                    │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/core/state.py:1418 in             │
[rank24]: │ load_state_dict                                                              │
[rank24]: │                                                                              │
[rank24]: │   1415 │   │   │   if attribute_name == 'dataset_state':                     │
[rank24]: │   1416 │   │   │   │   self._load_dataset_state(serialized_value)            │
[rank24]: │   1417 │   │   │   elif attribute_name == 'optimizers':                      │
[rank24]: │ ❱ 1418 │   │   │   │   self.load_optim_state(state)                          │
[rank24]: │   1419 │   │   │   elif attribute_name == 'train_metrics':                   │
[rank24]: │   1420 │   │   │   │   # Get current metrics object and populate each metric │
[rank24]: │   1421 │   │   │   │   # in serialization with serialized data via load_stat │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/core/state.py:1331 in             │
[rank24]: │ load_optim_state                                                             │
[rank24]: │                                                                              │
[rank24]: │   1328 │   │   │   │   # errors) before discarding the output. Accordingly,  │
[rank24]: │   1329 │   │   │   │   # See: https://github.com/pytorch/pytorch/issues/1251 │
[rank24]: │   1330 │   │   │   │   optim_state_dict = MagicMock() if optim_state_dict is │
[rank24]: │ ❱ 1331 │   │   │   │   set_optimizer_state_dict(                             │
[rank24]: │   1332 │   │   │   │   │   model=self.model,                                 │
[rank24]: │   1333 │   │   │   │   │   optimizers=optimizer,                             │
[rank24]: │   1334 │   │   │   │   │   optim_state_dict=optim_state_dict,                │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/trainer/mosaic_fsdp_utils.py:719  │
[rank24]: │ in set_optimizer_state_dict                                                  │
[rank24]: │                                                                              │
[rank24]: │   716 │   │   │   info = _verify_options(model, optimizers, optim_only=True, │
[rank24]: │   717 │   │   │                                                              │
[rank24]: │   718 │   │   │   _verify_state_dict({}, optim_state_dict, info)             │
[rank24]: │ ❱ 719 │   │   │   _load_optim_state_dict(model, optimizers, optim_state_dict │
[rank24]: │   720                                                                        │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict.py │
[rank24]: │ :616 in _load_optim_state_dict                                               │
[rank24]: │                                                                              │
[rank24]: │    613 │   │   │   │   │   │   osd_state[k.replace(fqn, fqn_with_compiler)]  │
[rank24]: │    614 │   │   │                                                             │
[rank24]: │    615 │   │   │   with info.fsdp_context():                                 │
[rank24]: │ ❱  616 │   │   │   │   optim_state_dict = FSDP.optim_state_dict_to_load(     │
[rank24]: │    617 │   │   │   │   │   model, optim, optim_state_dict                    │
[rank24]: │    618 │   │   │   │   )                                                     │
[rank24]: │    619                                                                       │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_ │
[rank24]: │ parallel.py:1928 in optim_state_dict_to_load                                 │
[rank24]: │                                                                              │
[rank24]: │   1925 │   │   │   │   Default: ``None``)                                    │
[rank24]: │   1926 │   │   """                                                           │
[rank24]: │   1927 │   │   state_dict_settings = FullyShardedDataParallel.get_state_dict │
[rank24]: │ ❱ 1928 │   │   result = FullyShardedDataParallel._optim_state_dict_to_load_i │
[rank24]: │   1929 │   │   │   optim_state_dict=optim_state_dict,                        │
[rank24]: │   1930 │   │   │   model=model,                                              │
[rank24]: │   1931 │   │   │   optim_input=None,                                         │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_ │
[rank24]: │ parallel.py:1319 in _optim_state_dict_to_load_impl                           │
[rank24]: │                                                                              │
[rank24]: │   1316 │   │                                                                 │
[rank24]: │   1317 │   │   if rank0_only and dist.get_rank(group) > 0:                   │
[rank24]: │   1318 │   │   │   optim_state_dict = {}                                     │
[rank24]: │ ❱ 1319 │   │   sharded_osd = _flatten_optim_state_dict(                      │
[rank24]: │   1320 │   │   │   optim_state_dict,                                         │
[rank24]: │   1321 │   │   │   model=model,                                              │
[rank24]: │   1322 │   │   │   use_orig_params=use_orig_params,                          │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py:461 │
[rank24]: │  in _flatten_optim_state_dict                                                │
[rank24]: │                                                                              │
[rank24]: │    458 │                                                                     │
[rank24]: │    459 │   # Construct the "state" part                                      │
[rank24]: │    460 │   flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}        │
[rank24]: │ ❱  461 │   unflat_osd_state = unflat_osd["state"]                            │
[rank24]: │    462 │   all_state_keys = set(unflat_osd_state.keys())                     │
[rank24]: │    463 │                                                                     │
[rank24]: │    464 │   for param, fqns in param_to_fqns.items():                         │
[rank24]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank24]: KeyError: 'state'

With a prior version of llm-foundry I didn't have this issue (albeit I was using the FULL_SHARD strategy), so I tried switching back to the old Composer code for resuming the optimizer (from versions < 0.22.0), but that doesn't work either and I get a different error:

[rank118]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank118]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank118]: │ ../../train/train.py:786 in <module>                                         │
[rank118]: │                                                                              │
[rank118]: │   783 │   cfg = om.merge(yaml_cfg, cli_cfg)                                  │
[rank118]: │   784 │   om.resolve(cfg)                                                    │
[rank118]: │   785 │   assert isinstance(cfg, DictConfig)                                 │
[rank118]: │ ❱ 786 │   main(cfg)                                                          │
[rank118]: │   787                                                                        │
[rank118]: │                                                                              │
[rank118]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank118]: │ ../../train/train.py:717 in main                                             │
[rank118]: │                                                                              │
[rank118]: │   714 │                                                                      │
[rank118]: │   715 │   # Build the Trainer                                                │
[rank118]: │   716 │   log.info('Building trainer...')                                    │
[rank118]: │ ❱ 717 │   trainer = Trainer(                                                 │
[rank118]: │   718 │   │   run_name=run_name,                                             │
[rank118]: │   719 │   │   seed=seed,                                                     │
[rank118]: │   720 │   │   model=model,                                                   │
[rank118]: │                                                                              │
[rank118]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank118]: │ venv/lib/python3.11/site-packages/composer/trainer/trainer.py:1715 in        │
[rank118]: │ __init__                                                                     │
[rank118]: │                                                                              │
[rank118]: │   1712 │   │   │   │   if wandb.run is None:                                 │
[rank118]: │   1713 │   │   │   │   │   load_object_store.init(self.state, self.logger)   │
[rank118]: │   1714 │   │   │   _, _, parsed_load_path = parse_uri(load_path)             │
[rank118]: │ ❱ 1715 │   │   │   self._rng_state = checkpoint.load_checkpoint(             │
[rank118]: │   1716 │   │   │   │   state=self.state,                                     │
[rank118]: │   1717 │   │   │   │   logger=self.logger,                                   │
[rank118]: │   1718 │   │   │   │   path=parsed_load_path,                                │
[rank118]: │                                                                              │
[rank118]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank118]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:558 in        │
[rank118]: │ load_checkpoint                                                              │
[rank118]: │                                                                              │
[rank118]: │    555 │   dist.all_reduce(max_step_to_resume_from, reduce_operation='MAX')  │
[rank118]: │    556 │   dist.all_reduce(min_step_to_resume_from, reduce_operation='MIN')  │
[rank118]: │    557 │   if max_step_to_resume_from.data != min_step_to_resume_from.data:  │
[rank118]: │ ❱  558 │   │   raise RuntimeError(                                           │
[rank118]: │    559 │   │   │   textwrap.dedent(                                          │
[rank118]: │    560 │   │   │   │   f'Timestamp mismatch error: batch to resume from {ste │
[rank118]: │    561 │   │   │   │   'This usually occurs when at least one rank fails to  │
[rank118]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank118]: RuntimeError: Timestamp mismatch error: batch to resume from 10000 is not the 
[rank118]: same on all ranks. This usually occurs when at least one rank fails to save the 
[rank118]: last checkpoint while using sharded checkpointing + autoresume. Please manually 
[rank118]: resume by disabling autoresume and explicitly setting load_path to the most 
[rank118]: recent checkpoints that all ranks have saved. E.g. for the 10th batch: trainer =
[rank118]: Trainer(autoresume=False, load_path="/path/to/checkpoint/ba10-rank{rank}.pt", 
[rank118]: ...). Remember to keep the {rank} placeholder!

Are you experiencing a similar issue? Or do you have any hints?

germanjke commented 1 month ago

I'm also on 0.8.0 with HYBRID_SHARD.

eracah commented 1 month ago

@Riccorl, it seems like your problem is separate from @germanjke's. Can you file a new issue with some more information like:

nik-mosaic commented 1 month ago

@germanjke can you try freezing layers via the optimizer, rather than with the Composer layer freezing algorithm? Freezing via the optimizer is better tested. For example:

optimizer:
    lr: <your_learning_rate>
    name: decoupled_adamw
    disable_grad: ^((?!(Wqkv|out_proj)).)*$ # Regex that disables gradients on all parameters except the attention Wqkv and out_proj weights
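
To make the regex semantics concrete, here is a small illustrative sketch (not llm-foundry's internal code; the parameter names and the adapted pattern are hypothetical) of which parameter names a disable_grad pattern would freeze, including one possible pattern for the every-fifth-layer LLaMA PRO case in this issue. Print your model's named_parameters() first to confirm the exact names before relying on a pattern like this.

# Illustration: parameters whose names match disable_grad get requires_grad=False.
import re

# Hypothetical pattern for the LLaMA PRO case above: freeze everything except
# the newly inserted blocks at indices 4, 9, 14, ..., 39.
disable_grad = r'^(?!.*layers\.(4|9|14|19|24|29|34|39)\.).*$'

example_names = [
    'model.model.embed_tokens.weight',                # frozen
    'model.model.layers.3.mlp.up_proj.weight',        # frozen (original layer)
    'model.model.layers.4.mlp.up_proj.weight',        # trainable (inserted layer)
    'model.model.layers.39.self_attn.q_proj.weight',  # trainable (inserted layer)
]

for name in example_names:
    frozen = re.fullmatch(disable_grad, name) is not None
    print(f"{name}: {'requires_grad=False' if frozen else 'trainable'}")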

mvpatel2000 commented 3 weeks ago

I'm facing a similar issue with the latest release (0.8.0). When resuming from a monolithic checkpoint with HYBRID_SHARD I get the following error (KeyError: 'state'):

@Riccorl I have identified this as a PyTorch issue and opened a bug report on their end + a PR to fix it