microsoft / DeepSpeedExamples

Example models using DeepSpeed
Apache License 2.0

ds_eval_config vs. ds_config #653

[Open] SenZHANG-GitHub opened this issue 1 year ago

SenZHANG-GitHub commented 1 year ago

When initializing the reward and reference models in step 3 of DeepSpeed-Chat, two different DeepSpeed configs are used: ds_config and ds_eval_config. Why do we need two configs here, and is there a safe way to remove ds_eval_config? For example:

```python
def _init_reward(self, critic_model_name_or_path):
    stime = log_init("Reward")
    # DS Config
    zero_stage = self.args.critic_zero_stage
    if zero_stage != 3:
        # If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
        zero_stage = 0

    ds_config = get_eval_ds_config(offload=self.args.offload,
                                   stage=zero_stage)
    ds_config['train_micro_batch_size_per_gpu'] = (
        self.args.per_device_mini_train_batch_size)
    ds_config['train_batch_size'] = (
        self.args.per_device_mini_train_batch_size *
        torch.distributed.get_world_size() *
        self.args.gradient_accumulation_steps)

    #TODO(jeff): should not be needed, we should be able to use ds_config above
    #TODO(jeff): it means we never create the critic w. zero.init context if we are using ZeRO-3
    ds_eval_config = get_eval_ds_config(offload=False, stage=0)

    # Model
    reward_model = create_critic_model(
        model_name_or_path=critic_model_name_or_path,
        tokenizer=self.tokenizer,
        ds_config=ds_eval_config,
        num_padding_at_beginning=self.args.num_padding_at_beginning,
        rlhf_training=True)

    reward_engine, *_ = deepspeed.initialize(model=reward_model,
                                             config=ds_config)
```
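To make the split concrete, this is a minimal, self-contained sketch of the configuration logic in the `_init_reward` snippet, with no DeepSpeed or torch dependency. `eval_ds_config_stub` and `build_reward_configs` are hypothetical names: the stub stands in for DeepSpeed-Chat's `get_eval_ds_config` and fills only the keys the snippet touches (the real helper sets more), and the `world_size` argument replaces `torch.distributed.get_world_size()`. The batch-size line follows DeepSpeed's rule that `train_batch_size` equals the micro batch per GPU times the gradient accumulation steps times the number of GPUs.

```python
from typing import Tuple


def eval_ds_config_stub(offload: bool, stage: int) -> dict:
    """Hypothetical stand-in for get_eval_ds_config (only the keys used here)."""
    return {
        "zero_optimization": {
            "stage": stage,
            "offload_param": {"device": "cpu" if offload else "none"},
        },
    }


def build_reward_configs(critic_zero_stage: int, offload: bool,
                         micro_batch: int, world_size: int,
                         grad_accum: int) -> Tuple[dict, dict]:
    """Return (ds_config, ds_eval_config) as built in _init_reward."""
    # Keep ZeRO-3 if the critic uses it; otherwise fall back to stage 0.
    zero_stage = critic_zero_stage if critic_zero_stage == 3 else 0

    # ds_config: passed to deepspeed.initialize (engine runtime settings).
    ds_config = eval_ds_config_stub(offload=offload, stage=zero_stage)
    ds_config["train_micro_batch_size_per_gpu"] = micro_batch
    # train_batch_size = micro batch per GPU * world size * grad accumulation.
    ds_config["train_batch_size"] = micro_batch * world_size * grad_accum

    # ds_eval_config: passed to create_critic_model; always stage 0, no offload.
    ds_eval_config = eval_ds_config_stub(offload=False, stage=0)
    return ds_config, ds_eval_config


if __name__ == "__main__":
    engine_cfg, build_cfg = build_reward_configs(
        critic_zero_stage=3, offload=True,
        micro_batch=4, world_size=8, grad_accum=2)
    print(engine_cfg["zero_optimization"]["stage"])  # stage used by the engine
    print(build_cfg["zero_optimization"]["stage"])   # stage used at construction
    print(engine_cfg["train_batch_size"])
```

With `critic_zero_stage=3` the engine config is stage 3 while the construction config stays at stage 0. That is the situation the second TODO comment describes: the model is never created inside a zero.init context even though the engine itself runs ZeRO-3.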
liziniu commented 1 year ago

I have the same concern. The same issue arises when initializing the critic. I guess we should use ds_config for the critic model and ds_eval_config for the reward model.
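The suggestion above can be sketched as a per-role choice of construction config. Both functions are hypothetical, and `built_under_zero3` encodes an assumption drawn from the second TODO comment in the original snippet: only a stage-3 config passed at construction time would create the model inside a zero.init context. The sketch does not show that the proposal is safe; in particular, it says nothing about whether `create_critic_model` can load weights into a model built under ZeRO-3.

```python
def choose_construction_config(role: str, ds_config: dict,
                               ds_eval_config: dict) -> dict:
    """Hypothetical: pick the config passed to create_critic_model per role."""
    if role == "critic":
        return ds_config       # trained model: build with its ZeRO settings
    if role == "reward":
        return ds_eval_config  # frozen scorer: keep the stage-0 build config
    raise ValueError(f"unknown role: {role!r}")


def built_under_zero3(construction_config: dict) -> bool:
    # Assumption (from the TODO comment): only a stage-3 config at
    # construction time creates the model inside a zero.init context.
    return construction_config["zero_optimization"]["stage"] == 3


if __name__ == "__main__":
    ds_config = {"zero_optimization": {"stage": 3}}
    ds_eval_config = {"zero_optimization": {"stage": 0}}
    for role in ("critic", "reward"):
        cfg = choose_construction_config(role, ds_config, ds_eval_config)
        print(role, built_under_zero3(cfg))
```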

HeZez commented 1 year ago

I have the same confusion. Is there any progress on this issue?