aws-neuron / aws-neuron-samples

Example code for AWS Neuron SDK developers building inference and training applications

Bug fix: use load_sharded_state_dict for NeuronZero1Optimizer #52

Open ravindra-flip opened 8 months ago

ravindra-flip commented 8 months ago
  1. Optimizer states are checkpointed in sharded form, so they must be restored with load_sharded_state_dict.
  2. The current code tries to load them from the state dictionary (state_dict["optimizer"]), which results in a KeyError:
    Traceback (most recent call last):
      File "tp_zero1_llama2_7b_hf_pretrain.py", line 838, in <module>
        _mp_fn(0, args)
      File "tp_zero1_llama2_7b_hf_pretrain.py", line 711, in _mp_fn
        train_llama(flags)
      File "tp_zero1_llama2_7b_hf_pretrain.py", line 538, in train_llama
        optimizer.load_state_dict(state_dict["optimizer"])
    KeyError: 'optimizer'
  3. With this change, I was able to resume training:
    
    ............

Compiler status PASS
LOG Thu Oct 12 22:12:27 2023 - (0, 101) step_loss : 1.0938 throughput : 0.66 global step = 101
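
For readers who want to see the shape of the fix, here is a minimal sketch of the resume logic described in the list above. It does not import the Neuron libraries. `FakeZero1Optimizer` and `restore_optimizer` are hypothetical names, and the single-argument `load_sharded_state_dict(checkpoint_dir)` call is an assumption made for illustration, not a statement of the library's signature.

```python
class FakeZero1Optimizer:
    """Hypothetical stand-in for a ZeRO-1 style optimizer (not the Neuron class)."""

    def __init__(self):
        self.loaded_from = None

    def load_state_dict(self, state):
        # Regular path: the whole optimizer state arrives as one dict.
        self.loaded_from = "state_dict"

    def load_sharded_state_dict(self, checkpoint_dir):
        # Assumed path: each rank reads its own shard from the checkpoint directory.
        self.loaded_from = f"shards:{checkpoint_dir}"


def restore_optimizer(optimizer, state_dict, checkpoint_dir):
    """Restore optimizer state, preferring sharded loading when it is supported."""
    if hasattr(optimizer, "load_sharded_state_dict"):
        optimizer.load_sharded_state_dict(checkpoint_dir)
    elif "optimizer" in state_dict:
        optimizer.load_state_dict(state_dict["optimizer"])
    else:
        raise KeyError("checkpoint has no 'optimizer' entry")
    return optimizer


# The model checkpoint carries no "optimizer" key, which is what triggered
# the KeyError in the traceback when it was indexed directly.
state_dict = {"model": {}}
opt = restore_optimizer(FakeZero1Optimizer(), state_dict, "ckpt_dir")
print(opt.loaded_from)
```

The `hasattr` check is only one way to pick the path. A script that always uses a sharded optimizer could call the sharded loader unconditionally.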