microsoft / DeepSpeedExamples

Example models using DeepSpeed
Apache License 2.0

How to save the intermediate model? #654

Open liuaiting opened 1 year ago

liuaiting commented 1 year ago

The current implementation only saves the model after all epochs have finished.

DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

    if args.output_dir is not None:
        print_rank_0("saving model ...", args.global_rank)
        rlhf_engine.actor = convert_lora_to_linear_layer(rlhf_engine.actor)
        rlhf_engine.critic = convert_lora_to_linear_layer(rlhf_engine.critic)
        if args.enable_ema:
            rlhf_engine.actor_ema = convert_lora_to_linear_layer(
                rlhf_engine.actor_ema)

        if torch.distributed.get_rank() == 0:
            save_hf_format(rlhf_engine.actor,
                           tokenizer,
                           args,
                           sub_folder="actor")
            save_hf_format(rlhf_engine.critic,
                           tokenizer,
                           args,
                           sub_folder="critic")
            if args.enable_ema:
                save_hf_format(rlhf_engine.actor_ema,
                               tokenizer,
                               args,
                               sub_folder="actor_ema")

        if args.actor_zero_stage == 3:
            save_zero_three_model(rlhf_engine.actor,
                                  global_rank=args.global_rank,
                                  save_dir=os.path.join(
                                      args.output_dir, "actor"),
                                  zero_stage=args.actor_zero_stage)
            if args.enable_ema:
                save_zero_three_model(rlhf_engine.actor_ema,
                                      global_rank=args.global_rank,
                                      save_dir=os.path.join(
                                          args.output_dir, "actor_ema"),
                                      zero_stage=args.actor_zero_stage)
        if args.critic_zero_stage == 3:
            save_zero_three_model(rlhf_engine.critic,
                                  global_rank=args.global_rank,
                                  save_dir=os.path.join(
                                      args.output_dir, "critic"),
                                  zero_stage=args.critic_zero_stage)

How can we save the model weights at intermediate epochs?
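One possible approach, sketched under assumptions: move the body of the `if args.output_dir is not None:` block into a function that writes to a given directory, then call it from inside the epoch loop every few epochs. The sketch below shows only the scheduling logic in plain Python. `should_save`, `checkpoint_dir`, `train`, `save_every_epochs`, the `epoch_<n>` folder layout and `save_fn` are all hypothetical names, and `save_fn` stands in for the relocated save block. Two details from the snippet above matter. First, `save_zero_three_model` is called outside the rank-0 guard, so the save function should be invoked on every rank. Second, `save_hf_format` receives `args` rather than a path, so the wrapper probably needs a copy of `args` whose `output_dir` points at the intermediate folder.

```python
import os
from typing import Callable, List


def should_save(epoch: int, save_every_epochs: int, num_epochs: int) -> bool:
    """Return True if an intermediate checkpoint is due after `epoch` (0-based).

    The last epoch is skipped because the existing end-of-training block
    already saves the final model.
    """
    if save_every_epochs <= 0:  # 0 or negative disables intermediate saves
        return False
    is_last = epoch == num_epochs - 1
    return (epoch + 1) % save_every_epochs == 0 and not is_last


def checkpoint_dir(output_dir: str, epoch: int) -> str:
    """Hypothetical per-epoch layout: <output_dir>/epoch_<n>, n counted from 1."""
    return os.path.join(output_dir, f"epoch_{epoch + 1}")


def train(num_epochs: int, save_every_epochs: int, output_dir: str,
          save_fn: Callable[[str], None]) -> None:
    for epoch in range(num_epochs):
        # ... the existing generation / PPO update loop for this epoch runs here ...
        if should_save(epoch, save_every_epochs, num_epochs):
            # Call on every rank, not only rank 0: the ZeRO-3 save path in the
            # original snippet is invoked outside the rank-0 guard.
            save_fn(checkpoint_dir(output_dir, epoch))


if __name__ == "__main__":
    saved: List[str] = []
    train(num_epochs=5, save_every_epochs=2, output_dir="out", save_fn=saved.append)
    print(saved)  # on Linux: ['out/epoch_2', 'out/epoch_4']
```

A caveat if LoRA is enabled: the snippet reassigns the actor and critic via `convert_lora_to_linear_layer`. If that conversion merges the LoRA weights into the base layers in place (as the name suggests), running it mid-training would change the model that is still being optimized, so an intermediate save may need to skip that step or undo it afterwards. If the goal is resuming training rather than exporting Hugging Face weights, saving the full DeepSpeed state with the engine's `save_checkpoint` method is another option worth checking.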

https://github.com/microsoft/DeepSpeedExamples/issues/434

liuaiting commented 1 year ago

@yaozhewei