Open Orion-Zheng opened 7 months ago
Now I want to resume from this checkpoint with **1 node with 4 A100 40GB**, and the error below occurred. I guess it may be related to a difference in checkpoint format (e.g., the RNG states and model/optim states). Is there any method to consolidate the checkpoint format?
@Orion-Zheng, thanks for reporting this issue. It is unclear to me that checkpoint consolidation is needed here, since it seems your checkpoint was saved with 4 GPUs and you are resuming with 4 GPUs. In other words, there is no change in the number of GPUs between saving and resumption. Is that correct?
Thank you for the timely reply! I am not very familiar with the components of a DeepSpeed checkpoint, but I think rank_0 and rank_1 mean partitions on different nodes. My guess is that although the total number of GPUs is the same (4 GPUs), the checkpoint from 2 nodes * 2 GPUs is different from 1 node * 4 GPUs. Am I correct?
If I am correct, can I use ds_to_universal.py to convert my previous ZeRO-3 checkpoint to a universal format so I can resume from it with other GPU settings?
I noticed there is an open issue on ds_to_universal.py, https://github.com/microsoft/DeepSpeed/issues/5283, so I am not sure whether it has been fixed and whether it works for a ZeRO-3 checkpoint (that issue is about a ZeRO-2 checkpoint). Any information will be very appreciated!
@Orion-Zheng, a DeepSpeed checkpoint is not aware of node-level information. What matters is the parallel dimensions, such as data parallel, pipeline parallel, and tensor parallel. So the checkpoints from 2 nodes * 2 GPUs should be the same as 1 node * 4 GPUs. I have two questions:
@Orion-Zheng, also can you share the log of the run that saved the checkpoint?
@tjruwase Thank you for the information! Here are my Accelerate and DeepSpeed configs.
Accelerate Config
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_config_file: config/deepspeed_config/stage3_offload.json
deepspeed_multinode_launcher: standard
zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_process_ip: 192.108.1.3
main_process_port: 9999
main_training_function: main
num_machines: 2
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Deepspeed Config
{
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 3,
"stage3_gather_16bit_weights_on_model_save": true,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
}
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"zero_allow_untested_optimizer": true,
"gradient_clipping": "auto",
"wall_clock_breakdown": false
}
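For reference, the "auto" entries above get filled in by the Hugging Face Trainer from TrainingArguments before the config reaches DeepSpeed. Below is a rough sketch of that mapping; the values are illustrative placeholders, not my actual hyperparameters.
# Rough sketch only: illustrative values, not real hyperparameters.
# The Hugging Face Trainer substitutes the "auto" fields in the DeepSpeed
# JSON from TrainingArguments before initializing the DeepSpeed engine.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="experiment_ckpts/example_run",   # placeholder path
    per_device_train_batch_size=4,               # -> train_micro_batch_size_per_gpu
    gradient_accumulation_steps=8,               # -> gradient_accumulation_steps
    bf16=True,                                   # -> bf16.enabled
    max_grad_norm=1.0,                           # -> gradient_clipping
)
# train_batch_size is then derived as:
#   per_device_train_batch_size * gradient_accumulation_steps * world_size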
For the log you mentioned, I am not sure which one you are referring to, because the Hugging Face Trainer seems to only print the loss during training. Maybe I need to set a verbosity level somewhere; I will look into it later.
Given that the DeepSpeed checkpoint is not aware of node-level information, I am curious why there are two shards of optim/model/RNG states in my checkpoint. Considering I used 2 nodes * 2 GPUs, I would expect either 4 shards (equal to the number of GPUs) or 1 shard (completely unaware of node-level and GPU-level information). Is this because of stage3_gather_16bit_weights_on_model_save=True in my DeepSpeed config? Does DeepSpeed perhaps aggregate the optim/model/RNG states on each node and save them to disk?
rng_state_0.pth
rng_state_1.pth
global_step100/
├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
├── bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
├── zero_pp_rank_0_mp_rank_00_model_states.pt
└── zero_pp_rank_1_mp_rank_00_model_states.pt
By the way, I would also like to know: if I want to resume from this ZeRO-3 checkpoint with a different number of GPUs, say 3 GPUs, does ds_to_universal.py help? Does it work for ZeRO-3 checkpoints without the bug from https://github.com/microsoft/DeepSpeed/issues/5283 ?
I find that the 1 node * 4 GPUs checkpoint structure looks like this, which is different from the 2 nodes * 2 GPUs one:
(gpu_dev) bash-4.4$ tree .
.
├── config.json
├── generation_config.json
├── global_step3
│   ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
│   ├── bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
│   ├── bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
│   ├── bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
│   ├── zero_pp_rank_0_mp_rank_00_model_states.pt
│   ├── zero_pp_rank_1_mp_rank_00_model_states.pt
│   ├── zero_pp_rank_2_mp_rank_00_model_states.pt
│   └── zero_pp_rank_3_mp_rank_00_model_states.pt
├── latest
├── model.safetensors
├── rng_state_0.pth
├── rng_state_1.pth
├── rng_state_2.pth
├── rng_state_3.pth
├── scheduler.pt
├── trainer_state.json
├── training_args.bin
└── zero_to_fp32.py
Hello, I tried to use ds_to_universal.py to convert the DeepSpeed checkpoint, but the error below occurred.
python ds_to_universal.py --input_folder experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-15-011900/checkpoint-3/ \
--output_folder experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-15-011900/checkpoint-3_merge \
--num_extract_workers 10 \
--num_merge_workers 10
[2024-04-15 01:43:34,022] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-devel package with yum
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
[WARNING] using untested triton version (2.2.0), only 1.0.0 is known to be compatible
args = Namespace(input_folder='experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-15-011900/checkpoint-3', output_folder='experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-15-011900/checkpoint-3_merge', num_extract_workers=10, num_merge_workers=10, keep_temp_folder=False, strict=True)
Convert DeepSpeed Checkpoint to Universal Checkpoint
Converting DeepSpeed checkpoint in experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-15-011900/checkpoint-3 to Universal checkpoint in experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-15-011900/checkpoint-3_merge
Traceback (most recent call last):
File "/scratch/users/nus/e0792473/EfficientVocabExtend/dist_env_tools/ds_to_universal.py", line 363, in <module>
main(args)
File "/scratch/users/nus/e0792473/EfficientVocabExtend/dist_env_tools/ds_to_universal.py", line 319, in main
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
File "/home/users/nus/e0792473/miniconda3/envs/gpu_dev/lib/python3.10/site-packages/deepspeed/checkpoint/deepspeed_checkpoint.py", line 40, in __init__
self._validate_folder(dir, pipeline_parallel)
File "/home/users/nus/e0792473/miniconda3/envs/gpu_dev/lib/python3.10/site-packages/deepspeed/checkpoint/deepspeed_checkpoint.py", line 292, in _validate_folder
assert len(
AssertionError: experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-15-011900/checkpoint-3 seems a bogus DeepSpeed checkpoint folder: Cannot find mp_rank_* files in there.
This confirms that the 2 * 2 run was saving into a local folder, not a global (e.g., NFS) folder. With ZeRO-3 we have distributed checkpoints, where each rank saves a pair of `zero_pp*.pt` and `bf16_zero*.pt` files corresponding to its rank. This is why you are seeing half of the files when you inspect the first node of your 2 * 2 run.
Can you inspect the checkpoint path on both nodes of your 2 * 2 run?
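As a quick sanity check, here is a small sketch (the path and world size are placeholders) that reports which ranks are missing from a global_step* folder, based on the file names shown above.
# Rough sketch: check whether a ZeRO-3 global_step* folder contains the
# model/optimizer shard pair for every data-parallel rank.
from pathlib import Path

def missing_ranks(step_dir: str, world_size: int):
    # Collect the file names present on this node.
    names = {p.name for p in Path(step_dir).iterdir()}
    missing = []
    for rank in range(world_size):
        model_shard = f"zero_pp_rank_{rank}_mp_rank_00_model_states.pt"
        optim_shard = f"bf16_zero_pp_rank_{rank}_mp_rank_00_optim_states.pt"
        if model_shard not in names or optim_shard not in names:
            missing.append(rank)
    return missing

# Example (placeholder path): ranks whose shards are absent on this node.
print(missing_ranks("checkpoint-100/global_step100", world_size=4))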
Oh, I understand. Yes, you are right! I just found another directory where the shards from ranks 2 and 3 were stored in tmp-checkpoint-* directories. I will investigate further why this happened. It is probably because I wrote a custom Callback in the Trainer to store the checkpoint, but some mechanism kept the saving from working properly.
├── tinyllama_expanded_frez_embed-2024-04-12-221505
│   ├── checkpoint-100
│   │   ├── config.json
│   │   ├── generation_config.json
│   │   ├── global_step100
│   │   │   ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
│   │   │   ├── bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
│   │   │   ├── zero_pp_rank_0_mp_rank_00_model_states.pt
│   │   │   └── zero_pp_rank_1_mp_rank_00_model_states.pt
│   │   ├── latest
│   │   ├── model.safetensors
│   │   ├── rng_state_0.pth
│   │   ├── rng_state_1.pth
│   │   ├── scheduler.pt
│   │   ├── trainer_state.json
│   │   ├── training_args.bin
│   │   └── zero_to_fp32.py
│   └── checkpoint-132
│       ├── config.json
│       ├── generation_config.json
│       ├── global_step132
│       │   ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
│       │   ├── bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
│       │   ├── zero_pp_rank_0_mp_rank_00_model_states.pt
│       │   └── zero_pp_rank_1_mp_rank_00_model_states.pt
│       ├── latest
│       ├── model.safetensors
│       ├── rng_state_0.pth
│       ├── rng_state_1.pth
│       ├── scheduler.pt
│       ├── trainer_state.json
│       ├── training_args.bin
│       └── zero_to_fp32.py
└── tinyllama_expanded_frez_embed-2024-04-12-221513
    ├── tmp-checkpoint-100
    │   ├── global_step100
    │   │   ├── bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
    │   │   ├── bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
    │   │   ├── zero_pp_rank_2_mp_rank_00_model_states.pt
    │   │   └── zero_pp_rank_3_mp_rank_00_model_states.pt
    │   ├── rng_state_2.pth
    │   └── rng_state_3.pth
    └── tmp-checkpoint-132
        ├── global_step132
        │   ├── bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
        │   ├── bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
        │   ├── zero_pp_rank_2_mp_rank_00_model_states.pt
        │   └── zero_pp_rank_3_mp_rank_00_model_states.pt
        ├── rng_state_2.pth
        └── rng_state_3.pth
@tjruwase Thank you! Now I understand what you mean.
May I ask how to convert a 4-GPU checkpoint to a universal checkpoint so I can continue training with a different number of GPUs? This is the Google Drive link to my 4-GPU checkpoint:
https://drive.google.com/drive/folders/1efsIy9CcI_dTL41AtxdpcCJnSF1dn_uX?usp=sharing
I tried ds_to_universal.py, but it does not seem to work (`xxx seems a bogus DeepSpeed checkpoint folder: Cannot find mp_rank_* files in there.`). Could you help me with this?
python ds_to_universal.py --input_folder my_4_gpu_ckpt \
--output_folder my_4_gpu_ckpt_merge \
--num_extract_workers 10 \
--num_merge_workers 10
Oh, I think I know the reason. It was my fault. I start distributed training by manually running the same script on the two nodes. In my script, the checkpoint path uses the real timestamp as the suffix of the checkpoint path. Because the two scripts did not launch at exactly the same time, they saved checkpoints to different directories.
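One possible fix, as a sketch only: assuming torch.distributed is already initialized by the launcher, let rank 0 generate the timestamp and broadcast it so every node writes to the same directory (names below are placeholders).
# Sketch only: make all ranks agree on one checkpoint directory by
# broadcasting the rank-0 timestamp. Assumes torch.distributed has already
# been initialized (e.g., by Accelerate/DeepSpeed).
import datetime
import torch.distributed as dist

def shared_output_dir(base="experiment_ckpts", prefix="tinyllama_expanded_frez_embed"):
    # Every rank builds a candidate timestamp, but rank 0's value overwrites
    # the others during the broadcast, so all ranks return the same path.
    stamp = [datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")]
    dist.broadcast_object_list(stamp, src=0)
    return f"{base}/{prefix}-{stamp[0]}"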
@Orion-Zheng, I am glad the original mystery is now resolved. For your second question, unfortunately zero3 is not yet supported. Is it possible for you to use zero2? I think zero2 should be adequate for your model size.
@minjiaz, @xylian86, @zhangninja FYI
Thank you! Yes, the ZeRO-2 code does run successfully. But in my case, when using a large batch size, I can only use ZeRO-3. I will probably have to wait for ZeRO-3 support.
By the way, may I ask what the difference is between a ZeRO-2 and a ZeRO-3 checkpoint? According to the paper, ZeRO-3 further shards the model parameters compared to ZeRO-2. But in my ZeRO-3 checkpoint, the model parameters have been merged into one model.safetensors shard (I guess due to stage3_gather_16bit_weights_on_model_save = True). So I think my checkpoint should be very similar to a ZeRO-2 checkpoint? Please correct me if I am wrong.
@Orion-Zheng, with zero2 checkpoints there is a single mp_rank* checkpoint file containing the unsharded parameter information (fp16/bf16 and layers), in addition to the 4 bf16_zero* files. Whereas for zero3, there are 4 zero_pp* files corresponding to the sharded parameter information. So, zero2 and zero3 checkpoints are not compatible.
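Based on that description, here is a small sketch for telling the two layouts apart from the file names inside a global_step* folder (the path is a placeholder).
# Rough sketch: distinguish a ZeRO-2 layout (single mp_rank_* model states
# file) from a ZeRO-3 layout (per-rank zero_pp_rank_* model states files).
from pathlib import Path

def guess_zero_stage(step_dir: str) -> str:
    names = [p.name for p in Path(step_dir).iterdir()]
    if any(n.startswith("mp_rank_") and n.endswith("_model_states.pt") for n in names):
        return "looks like ZeRO-1/2: consolidated model states per mp rank"
    if any(n.startswith("zero_pp_rank_") and n.endswith("_model_states.pt") for n in names):
        return "looks like ZeRO-3: model states sharded per rank"
    return "unknown layout"

print(guess_zero_stage("checkpoint-3/global_step3"))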
If zero2 is failing on large batch size, a more appropriate solution is gradient checkpointing.
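A minimal sketch of enabling gradient checkpointing with the Hugging Face Trainer stack (the model id below is a placeholder).
# Minimal sketch: recompute activations in the backward pass to trade
# compute for memory, so ZeRO-2 can fit a larger batch.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder model id
model.gradient_checkpointing_enable()   # enable activation recomputation
model.config.use_cache = False          # the KV cache is not useful in training and conflicts with checkpointing
# Alternatively, set gradient_checkpointing=True in TrainingArguments.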
Oh yes, good idea! I will try ZeRO-2 + Gradient Checkpointing later :) Thank you for your help!
@tjruwase Hi, sorry to bother you again. I tried ZeRO-2 and got a ZeRO-2 checkpoint, but it seems the Accelerate + DeepSpeed checkpoint structure is a bit different from the Universal Checkpoint examples in Megatron-DeepSpeed, so some errors occurred.
args = Namespace(input_folder='experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-16-010251/checkpoint-3', output_folder='experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-16-010251/checkpoint-3_universal', num_extract_workers=10, num_merge_workers=10, keep_temp_folder=False, strict=True)
Convert DeepSpeed Checkpoint to Universal Checkpoint
Converting DeepSpeed checkpoint in experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-16-010251/checkpoint-3 to Universal checkpoint in experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-16-010251/checkpoint-3_universal
Traceback (most recent call last):
File "dist_env_tools/ds_to_universal.py", line 363, in <module>
main(args)
File "dist_env_tools/ds_to_universal.py", line 320, in main
_check_for_required_state(ds_checkpoint)
File "dist_env_tools/ds_to_universal.py", line 311, in _check_for_required_state
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
AssertionError: Required universal_checkpoint_info state is missing in checkpoint. Verify that client creates this state.
This is my ZeRO-2 checkpoint's structure and this is the Google Drive link.
checkpoint-3/
├── config.json
├── generation_config.json
├── global_step3
│   ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
│   ├── bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
│   ├── bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
│   ├── bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
│   └── mp_rank_00_model_states.pt
├── latest
├── model.safetensors
├── rng_state_0.pth
├── rng_state_1.pth
├── rng_state_2.pth
├── rng_state_3.pth
├── scheduler.pt
├── trainer_state.json
├── training_args.bin
└── zero_to_fp32.py
Although I have no idea how to rectify it, I think it would only take you one look to see how to make the conversion compatible with the ZeRO-2 format in my case. Any help will be very appreciated!
Besides, I am also curious about how to deal with the multiple rng_state_*.pth files if I resume training with a different number of GPUs? They seem unmergeable to me.
Yes, I also ran into this, but I use the Transformers Trainer to call DeepSpeed and got the same files as you. Have you resolved this problem?
Besides, I am also curious about how to deal with the multiple rng_state_*.pth files if I resume training with a different number of GPUs? They seem unmergeable to me.
Since you are doing data parallel training, those should be duplicates and you need only one. Things will get interesting with model parallel training.
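A quick sketch for peeking inside one of those files (the path is a placeholder; what it contains is simply whatever the Trainer saved, which is not assumed here).
# Sketch: inspect one rng_state_*.pth file. In a pure data-parallel run the
# per-rank files should hold equivalent state, so one is enough.
import torch

rng = torch.load("checkpoint-3/rng_state_0.pth", map_location="cpu")  # placeholder path
if isinstance(rng, dict):
    print(sorted(rng.keys()))  # list whatever state categories were saved
else:
    print(type(rng))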
@tjruwase Hi, sorry to bother you again. I tried ZeRO-2 and got a ZeRO-2 checkpoint, but it seems the Accelerate + DeepSpeed checkpoint structure is a bit different from the Universal Checkpoint examples in Megatron-DeepSpeed, so some errors occurred.
@Orion-Zheng, you are hitting this error because some work is needed to port Universal Checkpointing to Accelerate. Currently, we have only ported it to Megatron-DeepSpeed. However, in this case, can you confirm that you are planning to change the number of GPUs in your training?
@Orion-Zheng, are you still having this issue?
@tjruwase Is this issue solved? I am trying to continue training the model, switching from 4 nodes to 2 nodes, and I encounter the same problem using the Hugging Face Trainer.
args = Namespace(input_folder='experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-16-010251/checkpoint-3', output_folder='experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-16-010251/checkpoint-3_universal', num_extract_workers=10, num_merge_workers=10, keep_temp_folder=False, strict=True)
Convert DeepSpeed Checkpoint to Universal Checkpoint
Converting DeepSpeed checkpoint in experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-16-010251/checkpoint-3 to Universal checkpoint in experiment_ckpts/tinyllama_expanded_frez_embed-2024-04-16-010251/checkpoint-3_universal
Traceback (most recent call last):
File "dist_env_tools/ds_to_universal.py", line 363, in <module>
main(args)
File "dist_env_tools/ds_to_universal.py", line 320, in main
_check_for_required_state(ds_checkpoint)
File "dist_env_tools/ds_to_universal.py", line 311, in _check_for_required_state
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
AssertionError: Required universal_checkpoint_info state is missing in checkpoint. Verify that client creates this state.
@xylian86, can you please help with this?
@tjruwase Yes, for sure.
@xiyang-aads-lilly Could you please try:
1. Install the latest DeepSpeed version (the v0.14.5 patch release).
2. Add the argument inject_missing_state when you run the conversion.
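For reference, a sketch of what that conversion call might look like with the flag (paths are placeholders, and the exact flag spelling should be verified against ds_to_universal.py --help).
# Sketch only: re-run the conversion with the suggested flag.
import subprocess

subprocess.run(
    [
        "python", "ds_to_universal.py",
        "--input_folder", "checkpoint-3",             # placeholder checkpoint path
        "--output_folder", "checkpoint-3_universal",  # placeholder output path
        "--num_extract_workers", "10",
        "--num_merge_workers", "10",
        "--inject_missing_state",                     # flag name assumed; check --help
    ],
    check=True,
)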
@Orion-Zheng, are you still having this issue?
yes
Describe the bug
Hello, I encountered a problem when trying to resume from a previous checkpoint when I used the Transformers Trainer + ZeRO-3 strategy to train a TinyLlama. My previous run was conducted on **2 nodes with 2 A100 40GB on each node**. The structure of the previous checkpoint is shown below.
Now I want to resume from this checkpoint with **1 node with 4 A100 40GB**, and the error below occurred. I guess it may be related to a difference in checkpoint format (e.g., the RNG states and model/optim states). Is there any method to consolidate the checkpoint format? Any help will be really appreciated! For us students using HPC clusters, it is sometimes hard to always get the same GPU count or specification. Thanks in advance!