Optimizer states are checkpointed via load_sharded_state_dict, but the current code tries to load them from state_dict["optimizer"], resulting in a KeyError.
Traceback (most recent call last):
  File "tp_zero1_llama2_7b_hf_pretrain.py", line 838, in <module>
    _mp_fn(0, args)
  File "tp_zero1_llama2_7b_hf_pretrain.py", line 711, in _mp_fn
    train_llama(flags)
  File "tp_zero1_llama2_7b_hf_pretrain.py", line 538, in train_llama
    optimizer.load_state_dict(state_dict["optimizer"])
KeyError: 'optimizer'
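A minimal sketch of the fix, assuming the ZeRO-1 optimizer exposes a load_sharded_state_dict method and the resume_optimizer/checkpoint_dir names are illustrative, not the script's exact API:

```python
# Hypothetical sketch: sharded (ZeRO-1) optimizer state lives in its own
# per-rank checkpoint, so resume must call load_sharded_state_dict on the
# optimizer rather than indexing state_dict["optimizer"], which does not
# exist in the combined checkpoint and raises KeyError.
def resume_optimizer(optimizer, state_dict, checkpoint_dir):
    """Restore optimizer state, preferring the sharded checkpoint."""
    if hasattr(optimizer, "load_sharded_state_dict"):
        # ZeRO-1 shards optimizer state across data-parallel ranks;
        # each rank reloads only its own shard from checkpoint_dir.
        optimizer.load_sharded_state_dict(checkpoint_dir)
    elif "optimizer" in state_dict:
        # Non-sharded optimizers keep state inside the combined checkpoint.
        optimizer.load_state_dict(state_dict["optimizer"])
    else:
        raise KeyError("no optimizer state found in checkpoint")
```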
With this change, I was able to resume training:
............
Compiler status PASS
LOG Thu Oct 12 22:12:27 2023 - (0, 101) step_loss : 1.0938 throughput : 0.66
global step = 101
load_sharded_state_dict