allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

Resuming training on unsharded checkpoint #641

Open · lecifire opened 5 days ago

lecifire commented 5 days ago

πŸ› Describe the bug

I tried resuming training from a previous unsharded checkpoint at step 1k. Training resumed with no initial issues, but when it tried to save a sharded checkpoint I hit the error shown below. What could be causing this? For context, the environment and number of nodes are the same as in the original run.

Traceback (most recent call last):
  File "/mnt/azureml/cr/j/947c8b089dfe4d0484df42634f296716/exe/wd/scripts/train.py", line 345, in <module>
    main(cfg)
  File "/mnt/azureml/cr/j/947c8b089dfe4d0484df42634f296716/exe/wd/scripts/train.py", line 316, in main
    trainer.fit()
  File "/workspace/OLMo/olmo/train.py", line 1153, in fit
    checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 560, in save_checkpoint
    result = self.save_sharded_checkpoint()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 468, in save_sharded_checkpoint
    result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 428, in _save_checkpoint
    checkpointer.save_checkpoint(
  File "/workspace/OLMo/olmo/checkpoint.py", line 1000, in save_checkpoint
    "optim": FSDP.optim_state_dict(dist_model, optim),
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1255, in _optim_state_dict_impl
    return _optim_state_dict(
           ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1972, in _optim_state_dict
    fsdp_osd_state = convert_fn(
                     ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1795, in _convert_state_with_orig_params
    _gather_all_orig_param_state(
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1689, in _gather_all_orig_param_state
    output_states = _allgather_orig_param_states(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1519, in _allgather_orig_param_states
    dtype, state_buffers = _convert_all_state_info(
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1415, in _convert_all_state_info
    assert curr_scalar_tensor_value is None or torch.equal(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Rank 4 has different values for step: 1500.0. Other ranks: 500.0
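
Since the assertion complains about per-rank "step" values in the FSDP optimizer state, here is a minimal diagnostic sketch I used to check whether the inconsistency already exists in the checkpoint being resumed from. This is not OLMo code; the `step1000-unsharded` directory name is a placeholder, and I am assuming the usual unsharded layout with an `optim.pt` file:

```python
# Diagnostic sketch only -- assumes the unsharded checkpoint directory
# contains an optim.pt holding the full optimizer state dict.
from collections import Counter

import torch

optim_state = torch.load("step1000-unsharded/optim.pt", map_location="cpu")

# Collect the distinct "step" values across all per-parameter states.
step_values = Counter()
for param_key, param_state in optim_state["state"].items():
    step = param_state.get("step")
    if step is not None:
        step_values[float(step)] += 1

# A consistent checkpoint should show a single step value here; more than
# one key would mean the optimizer state was already mismatched on disk.
print(step_values)
```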

Versions

.