microsoft / DeepSpeedExamples

Example models using DeepSpeed
Apache License 2.0

DeepSpeed-Chat: prefetch of layers during reward model forward pass leads to error during sample generation #337

Open adammoody opened 1 year ago

adammoody commented 1 year ago

When running step 3 with ZeRO stage 3 enabled for both the actor and critic models, I get the following error (line numbers may be offset due to debug statements I've added):

File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
  main()
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 438, in main
  out = trainer.generate_experience(prompts)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 103, in generate_experience
    seq = self._generate_sequence(prompts)
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
  seq = self.actor_model.module.generate(prompts,
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
  self.fuse_lora_weight()
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 128, in _fuse_lora
    weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048) at non-singleton dimension 1

This happens because the shape of weight.data does not match the shape of the tensor produced by the LoRA matmul.
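The failing line adds the LoRA product into weight.data in place, so the product has to broadcast to the weight's full shape. The shape-only sketch below models that check without PyTorch. It assumes DeepSpeed-Chat's LoRA layout (weight is (rows, columns), lora_right_weight is (columns, lora_dim), lora_left_weight is (lora_dim, rows)), and the helper names are hypothetical. It shows one combination that yields the same message as the traceback: a (2048, 8192) weight paired with the LoRA factors of a (2048, 2048) weight.

```python
# Shape-only sketch of the fusion step
#   weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
# ASSUMPTION: weight is (rows, columns), lora_right_weight is (columns, lora_dim),
# lora_left_weight is (lora_dim, rows). Helper names are hypothetical.
from typing import Tuple

Shape = Tuple[int, ...]


def lora_product_shape(left: Shape, right: Shape) -> Shape:
    """Shape of matmul(left.t(), right.t()) for 2-D left and right factors."""
    lora_dim_left, rows = left
    columns, lora_dim_right = right
    if lora_dim_left != lora_dim_right:
        raise ValueError(f"LoRA ranks differ: {lora_dim_left} vs {lora_dim_right}")
    return (rows, columns)


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Broadcast two shapes, comparing trailing dimensions first."""
    ndim = max(len(a), len(b))
    a_pad = (1,) * (ndim - len(a)) + tuple(a)
    b_pad = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for i in range(ndim - 1, -1, -1):
        sa, sb = a_pad[i], b_pad[i]
        if sa != sb and sa != 1 and sb != 1:
            raise ValueError(
                f"The size of tensor a ({sa}) must match the size of tensor b ({sb}) "
                f"at non-singleton dimension {i}")
        out.append(sa if sb == 1 else sb)
    return tuple(reversed(out))


def check_inplace_add(weight: Shape, update: Shape) -> None:
    """In-place add: the update must broadcast to exactly the weight's shape."""
    if broadcast_shape(weight, update) != tuple(weight):
        raise ValueError(f"cannot add {update} into {weight} in place")


if __name__ == "__main__":
    fc2_weight = (2048, 8192)
    # Matching factors (lora_dim=128, as in --actor_lora_dim 128): product is (2048, 8192).
    check_inplace_add(fc2_weight, lora_product_shape((128, 2048), (8192, 128)))
    # Factors belonging to a (2048, 2048) weight: product is (2048, 2048).
    try:
        check_inplace_add(fc2_weight, lora_product_shape((128, 2048), (2048, 128)))
    except ValueError as err:
        print(err)
```

A zero-sized (still partitioned) weight of shape (0,) fails the same check, since 0 cannot broadcast against 8192.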

I am using a system with 4x 16GB V100 GPUs per node with DeepSpeed 0.9.1. I trained a 1.3b-param model in step 1 and a 350m-param model in step 2.

My step 3 run command launches 4 processes on one node, binding one process per GPU:

cd training/step3_rlhf_finetuning
OUTPUT=${OUTPUTDIR}/step3-models/1.3b
mkdir -p $OUTPUT
ACTOR_MODEL_PATH=${OUTPUTDIR}/actor-models/1.3b
CRITIC_MODEL_PATH=${OUTPUTDIR}/reward-models/1.3b
ACTOR_ZERO_STAGE=3
CRITIC_ZERO_STAGE=3
jsrun -r 1 --tasks_per_rs 4 -c ALL_CPUS -g ALL_GPUS python3 main.py \
   --per_device_train_batch_size 4 \
   --per_device_mini_train_batch_size 4 \
   --inference_tp_size 1 \
   --max_answer_seq_len 256 \
   --max_prompt_seq_len 256 \
   --actor_model_name_or_path $ACTOR_MODEL_PATH \
   --critic_model_name_or_path $CRITIC_MODEL_PATH \
   --actor_zero_stage $ACTOR_ZERO_STAGE \
   --critic_zero_stage $CRITIC_ZERO_STAGE \
   --num_padding_at_beginning 1 \
   --gradient_accumulation_steps 1 \
   --deepspeed \
   --actor_lora_dim 128 \
   --enable_hybrid_engine \
   --actor_gradient_checkpointing \
   --critic_gradient_checkpointing \
   --output_dir $OUTPUT

After some debugging, I found that the above error arises because the GatheredParameters context does not gather all layers. If I print the tensor shape for each parameter of each layer immediately after GatheredParameters like so:

https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/hybrid_engine.py#L238

                with GatheredParameters(non_active_layers):
                    if rank == 0:
                        for layer_id in range(len(self.layer_params)):
                            for p_id, p in enumerate(self.layer_params[layer_id]):
                                print("after gather layer_id", layer_id, p_id, p.shape, flush=True)
                    self._gather_latency = time.time() - self._t0

then I see the following output on the step just before the error:

nonactive all layers 931
after gather layer_id 0 0 torch.Size([0])
after gather layer_id 0 1 torch.Size([0])
after gather layer_id 0 2 torch.Size([0])
after gather layer_id 0 3 torch.Size([0])
after gather layer_id 0 4 torch.Size([0])
after gather layer_id 0 5 torch.Size([8192])
after gather layer_id 0 6 torch.Size([2048, 8192])
after gather layer_id 0 7 torch.Size([2048])
after gather layer_id 0 8 torch.Size([0])
after gather layer_id 0 9 torch.Size([0])
after gather layer_id 0 10 torch.Size([0])
after gather layer_id 0 11 torch.Size([0])
after gather layer_id 0 12 torch.Size([0])
after gather layer_id 0 13 torch.Size([0])
after gather layer_id 0 14 torch.Size([0])
after gather layer_id 0 15 torch.Size([0])
after gather layer_id 1 0 torch.Size([2048])
after gather layer_id 1 1 torch.Size([2048])
after gather layer_id 1 2 torch.Size([2048])
after gather layer_id 1 3 torch.Size([2048])
after gather layer_id 1 4 torch.Size([8192, 2048])
after gather layer_id 1 5 torch.Size([8192])
after gather layer_id 1 6 torch.Size([2048, 8192])
after gather layer_id 1 7 torch.Size([2048])
after gather layer_id 1 8 torch.Size([2048, 2048])
after gather layer_id 1 9 torch.Size([2048])
after gather layer_id 1 10 torch.Size([2048, 2048])
after gather layer_id 1 11 torch.Size([2048])
after gather layer_id 1 12 torch.Size([2048, 2048])
after gather layer_id 1 13 torch.Size([2048])
after gather layer_id 1 14 torch.Size([2048, 2048])
after gather layer_id 1 15 torch.Size([2048])

Note that the parameters in layer_id=0 are mostly zero-sized. On the steps that complete without an error, those parameters have non-zero shapes, as shown below. Note also that the count of non_active_layers is 962 below vs. 931 above.

nonactive all layers 962
after gather layer_id 0 0 torch.Size([2048])
after gather layer_id 0 1 torch.Size([2048])
after gather layer_id 0 2 torch.Size([2048])
after gather layer_id 0 3 torch.Size([2048])
after gather layer_id 0 4 torch.Size([8192, 2048])
after gather layer_id 0 5 torch.Size([8192])
after gather layer_id 0 6 torch.Size([2048, 8192])
after gather layer_id 0 7 torch.Size([2048])
after gather layer_id 0 8 torch.Size([2048, 2048])
after gather layer_id 0 9 torch.Size([2048])
after gather layer_id 0 10 torch.Size([2048, 2048])
after gather layer_id 0 11 torch.Size([2048])
after gather layer_id 0 12 torch.Size([2048, 2048])
after gather layer_id 0 13 torch.Size([2048])
after gather layer_id 0 14 torch.Size([2048, 2048])
after gather layer_id 0 15 torch.Size([2048])
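Rather than eyeballing the printed shapes, the check can be automated. A minimal sketch, assuming that inside the GatheredParameters block each ZeRO-3 param carries its full size as ds_numel (as the ds_summary() dicts in this thread suggest), so a param that is still partitioned has a local numel smaller than ds_numel. The sketch works on plain dicts so it runs without DeepSpeed; ungathered() is a hypothetical helper.

```python
# Reports (layer_id, p_id) pairs that are still partitioned, using the 'numel' and
# 'ds_numel' fields seen in ds_summary() output. The dicts are hand-written
# examples, not real DeepSpeed objects; ungathered() is a hypothetical helper.
from typing import Dict, List, Tuple


def ungathered(layer_params: List[List[Dict[str, int]]]) -> List[Tuple[int, int]]:
    """Return (layer_id, p_id) for every param whose local numel < full ds_numel."""
    return [(layer_id, p_id)
            for layer_id, params in enumerate(layer_params)
            for p_id, summary in enumerate(params)
            if summary["numel"] < summary["ds_numel"]]


if __name__ == "__main__":
    layer_params = [
        [{"numel": 0, "ds_numel": 2048}, {"numel": 8192, "ds_numel": 8192}],  # layer 0
        [{"numel": 2048, "ds_numel": 2048}],                                   # layer 1
    ]
    print(ungathered(layer_params))  # [(0, 0)]
```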

To get further details, I added the following lines:

https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/hybrid_engine.py#L234-L238

            else:
                from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
                rank = dist.get_rank(group=self.mp_group)

                non_active_layers = get_inactive_params(self.all_layers_params)
                if rank == 0:
                    print("nonactive layers", len(non_active_layers))
                    for lay_id, lay in enumerate(self.all_layers_params):
                        print("all layers", lay_id, hasattr(lay, 'ds_id'), lay.ds_status == ZeroParamStatus.NOT_AVAILABLE, lay.ds_status)

                non_active_lora_params = get_inactive_params(self.all_lora_params)
                if rank == 0:
                    print("nonactive lora layers", len(non_active_lora_params))
                    for lay_id, lay in enumerate(self.all_lora_params):
                        print("lora layers", lay_id, hasattr(lay, 'ds_id'), lay.ds_status == ZeroParamStatus.NOT_AVAILABLE, lay.ds_status)

                non_active_layers.extend(non_active_lora_params)

It seems that the zero-sized parameters already have ds_status == ZeroParamStatus.INFLIGHT before GatheredParameters is called:

[2023-04-17 15:33:56,759] [INFO] [loss_scaler.py:181:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 32768, reducing to 16384
epoch: 0|step: 2|ppo_ep: 1|act_loss: nan|cri_loss: nan|unsuper_loss: 0.0
average reward score: 3.267578125
-------------------------------------------------------------------------------------
|E2E latency=17.17s |Gather latency=0.46s (2.70%) |Generate time=7.04s (41.02%) |Training time=6.82s (39.71%) |Others=3.31 (19.27%)|CurSamplesPerSec=0.93 |AvgSamplesPerSec=0.60
nonactive layers 651
all layers 0 True False ZeroParamStatus.INFLIGHT
all layers 1 True False ZeroParamStatus.INFLIGHT
all layers 2 True False ZeroParamStatus.INFLIGHT
all layers 3 True False ZeroParamStatus.INFLIGHT
all layers 4 True False ZeroParamStatus.INFLIGHT
all layers 5 True False ZeroParamStatus.INFLIGHT
all layers 6 True False ZeroParamStatus.INFLIGHT
all layers 7 True False ZeroParamStatus.INFLIGHT
all layers 8 True False ZeroParamStatus.INFLIGHT
all layers 9 True False ZeroParamStatus.INFLIGHT
all layers 10 True False ZeroParamStatus.INFLIGHT
all layers 11 True False ZeroParamStatus.INFLIGHT
all layers 12 True False ZeroParamStatus.INFLIGHT
all layers 13 True False ZeroParamStatus.INFLIGHT
all layers 14 True False ZeroParamStatus.INFLIGHT
all layers 15 True False ZeroParamStatus.INFLIGHT
all layers 16 True False ZeroParamStatus.INFLIGHT
all layers 17 True False ZeroParamStatus.INFLIGHT
all layers 18 True False ZeroParamStatus.INFLIGHT
all layers 19 True False ZeroParamStatus.INFLIGHT
all layers 20 True False ZeroParamStatus.INFLIGHT
all layers 21 True True ZeroParamStatus.NOT_AVAILABLE
all layers 22 True True ZeroParamStatus.NOT_AVAILABLE
all layers 23 True True ZeroParamStatus.NOT_AVAILABLE
all layers 24 True True ZeroParamStatus.NOT_AVAILABLE
all layers 25 True True ZeroParamStatus.NOT_AVAILABLE
all layers 26 True True ZeroParamStatus.NOT_AVAILABLE
all layers 27 True True ZeroParamStatus.NOT_AVAILABLE
all layers 28 True False ZeroParamStatus.INFLIGHT
all layers 29 True False ZeroParamStatus.INFLIGHT
all layers 30 True True ZeroParamStatus.NOT_AVAILABLE
all layers 31 True True ZeroParamStatus.NOT_AVAILABLE

<snip>

nonactive lora layers 280
lora layers 0 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 1 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 2 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 3 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 4 True False ZeroParamStatus.INFLIGHT
lora layers 5 True False ZeroParamStatus.INFLIGHT
lora layers 6 True False ZeroParamStatus.INFLIGHT
lora layers 7 True False ZeroParamStatus.INFLIGHT
lora layers 8 True False ZeroParamStatus.INFLIGHT
lora layers 9 True False ZeroParamStatus.INFLIGHT
lora layers 10 True False ZeroParamStatus.INFLIGHT
lora layers 11 True False ZeroParamStatus.INFLIGHT
lora layers 12 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 13 True True ZeroParamStatus.NOT_AVAILABLE
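To summarize dumps like the one above without reading every line, a small parser can group parameter indices by status. This is a hypothetical helper that assumes exactly the print format used in the debug snippet; it only reads log text and does not touch DeepSpeed.

```python
# Hypothetical helper that groups the "all layers" / "lora layers" debug lines
# printed above by ZeroParamStatus. It assumes exactly that print format.
import re
from collections import defaultdict
from typing import Dict, List

LINE_RE = re.compile(
    r"^(all|lora) layers (\d+) (True|False) (True|False) ZeroParamStatus\.(\w+)$")


def group_by_status(log_text: str) -> Dict[str, Dict[str, List[int]]]:
    """Map 'all'/'lora' -> status name -> parameter indices, in log order."""
    groups: Dict[str, Dict[str, List[int]]] = defaultdict(lambda: defaultdict(list))
    for line in log_text.splitlines():
        match = LINE_RE.match(line.strip())
        if match is None:
            continue  # skip unrelated log lines such as "nonactive layers 651"
        kind, idx, _has_ds_id, _not_available, status = match.groups()
        groups[kind][status].append(int(idx))
    return {kind: dict(by_status) for kind, by_status in groups.items()}


if __name__ == "__main__":
    sample = """nonactive layers 651
all layers 20 True False ZeroParamStatus.INFLIGHT
all layers 21 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 4 True False ZeroParamStatus.INFLIGHT"""
    print(group_by_status(sample))
```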

I think those parameters are marked as INFLIGHT because they have been prefetched. To confirm, I added some more debugging lines that print the stack at the point where the status is set to INFLIGHT:

https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/zero/partition_parameters.py#L873-L885

        def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = True) -> AllGatherCoalescedHandle:

            # fetches from nvme if the partition is not available and in nvme
            self._ensure_availability_of_partitioned_params(params)

            if self.world_size == 1:
                return _no_gather_coalesced(params)

            #for param in params:
            for p_id, param in enumerate(params):
                if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                    raise RuntimeError(param.ds_summary())
                param.ds_status = ZeroParamStatus.INFLIGHT
                if dist.get_rank() == 0:
                    print(p_id, "INFLIGHT2")
                    if p_id > 20:
                        print(traceback.print_stack(file=sys.stdout))

I can see those layers are set to INFLIGHT here:

File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
  main()
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 452, in main
  actor_loss, critic_loss = trainer.train_rlhf(exp_data)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 180, in train_rlhf
  value = self.critic_model.forward_value(**batch,
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/model/reward_model.py", line 125, in forward_value
  transformer_outputs = self.rwtranrsformer(
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
  result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 759, in forward
  decoder_outputs = self.decoder(
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 665, in forward
  layer_outputs = torch.utils.checkpoint.checkpoint(
File "/path/site-packages/torch/utils/checkpoint.py", line 235, in checkpoint
  return CheckpointFunction.apply(function, preserve, *args)
File "/path/site-packages/torch/utils/checkpoint.py", line 96, in forward
  outputs = run_function(*args)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 661, in custom_forward
  return module(*inputs, output_attentions, None)
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
  result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 337, in forward
  hidden_states = self.activation_fn(hidden_states)
File "/path/site-packages/torch/nn/modules/module.py", line 1137, in _call_impl
  result = hook(self, input)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 366, in _pre_forward_module_hook
  self.pre_sub_module_forward_function(module)
File "/path/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
  return func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 478, in pre_sub_module_forward_function
  param_coordinator.fetch_sub_module(sub_module)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/path/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
  return func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 333, in fetch_sub_module
  self.__all_gather_params(params_to_prefetch)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 381, in __all_gather_params
  handle = partitioned_params[0].all_gather_coalesced(partitioned_params)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 878, in all_gather_coalesced
  print(traceback.print_stack(file=sys.stdout))

It seems that the layers are being prefetched during the critic model's forward pass:

https://github.com/microsoft/DeepSpeedExamples/blob/2aa7a31b8fdcb34b8ccdc554021a1f5789752ab3/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L174

They are still INFLIGHT when the next sample is generated, and the get_inactive_params function only includes params marked as NOT_AVAILABLE:

https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/utils.py#L972-L975

Later, GatheredParameters may only consider params whose state is NOT_AVAILABLE:

https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/zero/partition_parameters.py#L1058
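As I read the linked code, the filter keeps only ZeRO-managed params whose status is NOT_AVAILABLE. The sketch below re-creates that filter with a simplified ZeroParamStatus enum and a FakeParam dataclass (both stand-ins, not DeepSpeed's classes) to show how a param that was prefetched earlier drops out of the list handed to GatheredParameters.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class ZeroParamStatus(Enum):
    # Simplified stand-in for DeepSpeed's ZeroParamStatus.
    AVAILABLE = auto()
    NOT_AVAILABLE = auto()
    INFLIGHT = auto()


@dataclass
class FakeParam:
    name: str
    ds_id: int
    ds_status: ZeroParamStatus


def get_inactive_params(param_list: List[FakeParam]) -> List[FakeParam]:
    """Keep only ZeRO-managed params (those with ds_id) that are NOT_AVAILABLE."""
    return [p for p in param_list
            if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]


if __name__ == "__main__":
    params = [
        FakeParam("embed_tokens.weight", 0, ZeroParamStatus.INFLIGHT),  # prefetched earlier
        FakeParam("layers.0.fc1.weight", 1, ZeroParamStatus.NOT_AVAILABLE),
    ]
    # The INFLIGHT embedding is not selected, so it is never gathered here.
    print([p.name for p in get_inactive_params(params)])  # ['layers.0.fc1.weight']
```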

Assuming that diagnosis is correct, I'm not sure what the recommended fix would be. Should get_inactive_params include INFLIGHT params?

adammoody commented 1 year ago

A second question came up while I was looking at this: it seems like the if conditions here might always be true:

https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/hybrid_engine.py#L123

https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/hybrid_engine.py#L136

Should it be lora_param instead of lora_params? Maybe change this to:

if len(lora_param) == 3:
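To illustrate the concern, assume each entry of lora_params is either a [right, left, scaling] triple or an empty list for a weight without LoRA (a hypothetical layout chosen only for illustration). Testing len(lora_params) inside the loop then gives the same answer on every iteration, while len(lora_param) checks the current entry.

```python
# Illustrates why testing the outer list's length inside the loop is suspicious.
# The layout (one [right, left, scaling] triple per LoRA-wrapped weight, an empty
# list otherwise) is an assumption; the function names are hypothetical.
from typing import List


def fuse_count_outer(lora_params: List[list]) -> int:
    fused = 0
    for lora_param in lora_params:
        if len(lora_params) == 3:   # checks the outer list: same answer every iteration
            fused += 1
    return fused


def fuse_count_entry(lora_params: List[list]) -> int:
    fused = 0
    for lora_param in lora_params:
        if len(lora_param) == 3:    # checks this entry: [right, left, scaling]
            fused += 1
    return fused


if __name__ == "__main__":
    layer = [["right", "left", 0.5], [], ["right", "left", 0.5]]
    print(fuse_count_outer(layer), fuse_count_entry(layer))  # 3 2
```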

tjruwase commented 1 year ago

> Assuming that diagnosis is correct, I'm not sure what the recommended fix would be. Should get_inactive_params include INFLIGHT params?

@adammoody, thanks for the detailed analysis of this bug. To answer your question, no, INFLIGHT params should not be gathered again. Param gathering is asynchronous for performance reasons, and INFLIGHT params are part of ongoing gather operations.

This problem occurs because in RLHF we context-switch five models within a rank so that they can share GPU memory. The recommended solution is therefore to ensure that there are no INFLIGHT params when a model is context-switched out. This guarantees that the actor model params are all NOT_AVAILABLE before the next generate call. The empty_partition_cache() method can be used for this purpose. I will share a PR asap.
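A toy version of that rule: a model may be switched out only when none of its params belong to an ongoing async gather. The status enum and switch_out() below are simplified, hypothetical stand-ins, not DeepSpeed's classes or empty_partition_cache() itself.

```python
# Toy check for the invariant: refuse to switch a model out while any of its
# params are INFLIGHT; otherwise release everything to NOT_AVAILABLE.
from enum import Enum, auto
from typing import Dict


class Status(Enum):
    AVAILABLE = auto()
    NOT_AVAILABLE = auto()
    INFLIGHT = auto()


def switch_out(statuses: Dict[str, Status]) -> Dict[str, Status]:
    """Release every param of the outgoing model, refusing if any gather is pending."""
    inflight = sorted(name for name, s in statuses.items() if s is Status.INFLIGHT)
    if inflight:
        raise RuntimeError(f"cannot switch out: still in flight: {inflight}")
    # AVAILABLE params are freed; NOT_AVAILABLE params stay as they are.
    return {name: Status.NOT_AVAILABLE for name in statuses}


if __name__ == "__main__":
    print(switch_out({"fc1": Status.AVAILABLE, "fc2": Status.NOT_AVAILABLE}))
    try:
        switch_out({"embed_tokens": Status.INFLIGHT})
    except RuntimeError as err:
        print(err)  # cannot switch out: still in flight: ['embed_tokens']
```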

tjruwase commented 1 year ago

@adammoody, can you please try this PR? Thanks!

adammoody commented 1 year ago

Thanks for the explanation and quick reply, @tjruwase. Unfortunately, I'm still hitting the same problem with this PR.

The problematic params seem to be from the first few layers of the actor_model, which have been prefetched due to a forward step of the critic_model. I thought maybe we could move those empty_partition_cache() calls to the end, to clear any INFLIGHT actor params that the critic started to prefetch:

        self.actor_model.backward(actor_loss)
        self.actor_model.step()
        #self.actor_model.empty_partition_cache()

        value = self.critic_model.forward_value(**batch,
                                                return_value_only=True,
                                                use_cache=False)[:, :-1]
        critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,
                                                                       start:],
                                          returns, action_mask[:, start:])
        self.critic_model.backward(critic_loss)
        self.critic_model.step()
        #self.critic_model.empty_partition_cache()

        self.actor_model.empty_partition_cache()
        self.critic_model.empty_partition_cache()

However, with that change I get the following error on the self.actor_model.empty_partition_cache() call:

  File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 358, in release_and_reset_all
    raise RuntimeError(f"param {param.ds_summary()} still in flight")
RuntimeError: param {'id': 0, 'status': 'INFLIGHT', 'numel': 102957056, 'ds_numel': 102957056, 'shape': (50272, 2048), 'ds_shape': (50272, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set()} still in flight
tjruwase commented 1 year ago

> The problematic params seem to be from the first few layers of the actor_model, which have been prefetched due to a forward step of the critic_model. I thought maybe we could move those empty calls to the end to try to clear any INFLIGHT actor params that the critic started to prefetch:

This is confusing to me. Each model has independent prefetchers, so the critic_model should not affect the actor_model. Can you share a stack trace that you get with the PR?

tjruwase commented 1 year ago

@adammoody, by the way, I was not able to repro your error on my 4xV100-16GB setup. This makes it harder to resolve.

tjruwase commented 1 year ago

> However, with that change I get the following error on the self.actor_model.empty_partition_cache() call:
>
>   File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 358, in release_and_reset_all
>     raise RuntimeError(f"param {param.ds_summary()} still in flight")
> RuntimeError: param {'id': 0, 'status': 'INFLIGHT', 'numel': 102957056, 'ds_numel': 102957056, 'shape': (50272, 2048), 'ds

By the way, this is a separate bug that needs to be addressed.

tjruwase commented 1 year ago

> Should it be lora_param instead of lora_params? Maybe change this to:
>
> if len(lora_param) == 3:

I think you have found a bug here. Do you mind opening a separate ticket for this?

adammoody commented 1 year ago

> > Should it be lora_param instead of lora_params? Maybe change this to:
> >
> > if len(lora_param) == 3:
>
> I think you have found a bug here. Do you mind opening a separate ticket for this?

Sure. I'll post that one to the main DeepSpeed repo.

adammoody commented 1 year ago

As another clue, it seems like the following changes work around the problem. I defined a "wait on inflight" function in deepspeed/runtime/zero/partitioned_param_coordinator.py:

    @instrument_w_nvtx
    @torch.no_grad()
    def wait_on_inflight_params(self, current_submodule):
        params = frozenset(iter_params(current_submodule, recurse=True))
        for param in params:
            if param in self.__inflight_param_registry:
                print(param.ds_summary())
                with get_accelerator().stream(self.__allgather_stream):
                    while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query():
                        self.__ongoing_fetch_events.popleft()
                    if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
                        self.__ongoing_fetch_events.popleft().synchronize()

                    self.__inflight_param_registry.pop(param).wait()

                    event = get_accelerator().Event()
                    event.record()
                    self.__ongoing_fetch_events.append(event)

                assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
        get_accelerator().current_stream().wait_stream(self.__allgather_stream)

I then call that from partition_all_parameters in deepspeed/runtime/zero/parameter_offload.py, which is called from empty_partition_cache:

    def partition_all_parameters(self):
        """Partition parameters that were not partitioned, usually because modules
        whose input parameters do not require grad computation do not trigger the
        post call and will therefore remain unpartitioned"""
        self.get_param_coordinator(training=self.module.training).wait_on_inflight_params(self.module)
        self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
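The workaround follows a wait-then-release pattern. Below is a toy model of it, where a thread pool stands in for the asynchronous all-gather and a dict stands in for the in-flight registry. ToyCoordinator and its methods are hypothetical and only mirror the names used above; they are not DeepSpeed's coordinator.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict


class Status(Enum):
    AVAILABLE = auto()
    NOT_AVAILABLE = auto()
    INFLIGHT = auto()


class ToyCoordinator:
    """Gathers run asynchronously and stay in an in-flight registry until waited on."""

    def __init__(self) -> None:
        self.status: Dict[str, Status] = {}
        self.inflight: Dict[str, Future] = {}
        self.pool = ThreadPoolExecutor(max_workers=2)

    def prefetch(self, name: str) -> None:
        self.status[name] = Status.INFLIGHT
        self.inflight[name] = self.pool.submit(time.sleep, 0.01)  # stands in for an all-gather

    def wait_on_inflight_params(self) -> None:
        # Finish every pending gather before anything is released.
        for name, handle in list(self.inflight.items()):
            handle.result()
            self.status[name] = Status.AVAILABLE
            del self.inflight[name]

    def release_and_reset_all(self) -> None:
        for name, s in self.status.items():
            if s is Status.INFLIGHT:
                raise RuntimeError(f"param {name} still in flight")
            self.status[name] = Status.NOT_AVAILABLE


if __name__ == "__main__":
    coord = ToyCoordinator()
    coord.prefetch("embed_tokens")
    try:
        coord.release_and_reset_all()
    except RuntimeError as err:
        print(err)  # param embed_tokens still in flight
    coord.wait_on_inflight_params()
    coord.release_and_reset_all()  # succeeds once the gather has been waited on
    coord.pool.shutdown()
```

Releasing before waiting raises the same kind of "still in flight" error reported above; waiting first lets the release succeed.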

And then I call empty_partition_cache on both models after training on both:

        ### process the new outputs
        batch = {'input_ids': seq, "attention_mask": attention_mask}
        actor_prob = self.actor_model(**batch, use_cache=False).logits
        actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],
                                          inputs['input_ids'][:, 1:])
        actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                        log_probs[:, start:], advantages,
                                        action_mask[:, start:])
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        value = self.critic_model.forward_value(**batch,
                                                return_value_only=True,
                                                use_cache=False)[:, :-1]
        critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,
                                                                       start:],
                                          returns, action_mask[:, start:])
        self.critic_model.backward(critic_loss)
        self.critic_model.step()

        # call empty_partition_cache after stepping both actor and critic models
        self.actor_model.empty_partition_cache()
        self.critic_model.empty_partition_cache()

If I drop the recurse=True option to change:

params = frozenset(iter_params(current_submodule, recurse=True))

to:

params = frozenset(iter_params(current_submodule))

then I still get this error:

  File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 400, in release_and_reset_all
    raise RuntimeError(f"param {param.ds_summary()} still in flight")
RuntimeError: param {'id': 0, 'status': 'INFLIGHT', 'numel': 102957056, 'ds_numel': 102957056, 'shape': (50272, 2048), 'ds_shape': (50272, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set()} still in flight

It seems that it misses the necessary parameters without the recursion.
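In PyTorch, named_parameters(recurse=False) yields only a module's direct parameters. Assuming iter_params follows the same convention (as I read DeepSpeed's code, its default is recurse=False), calling it on the top-level module skips nested weights such as the vocab embedding. ToyModule and this iter_params are hypothetical stand-ins for illustration.

```python
# Toy module tree showing why a non-recursive parameter walk can miss nested
# weights. Modeled on torch.nn.Module.named_parameters(recurse=...).
from typing import Iterator, List, Sequence


class ToyModule:
    def __init__(self, params: Sequence[str] = (), submodules: Sequence["ToyModule"] = ()):
        self.own_params: List[str] = list(params)
        self.submodules: List["ToyModule"] = list(submodules)


def iter_params(module: ToyModule, recurse: bool = False) -> Iterator[str]:
    """Yield the module's own params, plus all descendants' params when recurse=True."""
    yield from module.own_params
    if recurse:
        for child in module.submodules:
            yield from iter_params(child, recurse=True)


if __name__ == "__main__":
    embed = ToyModule(["decoder.embed_tokens.weight"])
    layer0 = ToyModule(["decoder.layers.0.fc1.weight"])
    actor = ToyModule(submodules=[ToyModule(submodules=[embed, layer0])])
    print(list(iter_params(actor)))                # [] : top level owns no params
    print(list(iter_params(actor, recurse=True)))  # both nested weights
```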

adammoody commented 1 year ago

And if I keep recurse=True, but call empty_partition_cache after each train step individually rather than at the end of both steps like so:

        ### process the new outputs
        batch = {'input_ids': seq, "attention_mask": attention_mask}
        actor_prob = self.actor_model(**batch, use_cache=False).logits
        actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],
                                          inputs['input_ids'][:, 1:])
        actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                        log_probs[:, start:], advantages,
                                        action_mask[:, start:])
        self.actor_model.backward(actor_loss)
        self.actor_model.step()
        self.actor_model.empty_partition_cache()

        value = self.critic_model.forward_value(**batch,
                                                return_value_only=True,
                                                use_cache=False)[:, :-1]
        critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,
                                                                       start:],
                                          returns, action_mask[:, start:])
        self.critic_model.backward(critic_loss)
        self.critic_model.step()
        self.critic_model.empty_partition_cache()

        #self.actor_model.empty_partition_cache()
        #self.critic_model.empty_partition_cache()

I still get the original error noted at the top of the issue:

    weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048) at non-singleton dimension 1

In summary, it seems that I need to do all three of these: 1) call empty_partition_cache after both steps, 2) wait on any inflight params, and 3) call iter_params(..., recurse=True) when getting the parameter list.

tjruwase commented 1 year ago

Thanks for sharing these details. I agree that empty_partition_cache needs a wait_on_inflight_params logic like you discovered. However, I would like to take a step back to understand a few things.

First, empty_partition_cache should guarantee that all params are NOT_AVAILABLE. But you mentioned that you were hitting the original problem with my PR, which calls empty_partition_cache after actor_model.step(). I don't understand how that is possible. So, could you please share the stack trace you get with my PR applied?

adammoody commented 1 year ago

The error message and stack trace when using the changes in the PR are the same as the original error report. Right, I haven't given up on figuring out this problem either. I have some more ideas to try to debug things. I'll keep posting updates.

adammoody commented 1 year ago

@tjruwase, I still haven't cracked it, but here are some more clues...

The first problematic layer corresponds to the vocab embedding layer of the actor model. I did verify that layer actually belongs to the actor model, and that it is not shared with the critic model or any other model.

The stack trace for the prefetch of that layer is shown below. The line numbers will vary because I've added lots of debug statements.

  File "/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
    main()
  File "/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 452, in main
    actor_loss, critic_loss = trainer.train_rlhf(exp_data)
  File "/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 191, in train_rlhf
    value = self.critic_model.forward_value(**batch,
  File "/DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/model/reward_model.py", line 125, in forward_value
    transformer_outputs = self.rwtranrsformer(
  File "/path/python3.9/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/path/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 759, in forward
    decoder_outputs = self.decoder(
  File "/path/python3.9/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/path/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 665, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/path/python3.9/site-packages/torch/utils/checkpoint.py", line 235, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/path/python3.9/site-packages/torch/utils/checkpoint.py", line 96, in forward
    outputs = run_function(*args)
  File "/path/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 661, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/path/python3.9/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/path/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 337, in forward
    hidden_states = self.activation_fn(hidden_states)
  File "/path/python3.9/site-packages/torch/nn/modules/module.py", line 1137, in _call_impl
    result = hook(self, input)
  File "/path/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/path/python3.9/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 378, in _pre_forward_module_hook
    self.pre_sub_module_forward_function(module)
  File "/path/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/path/python3.9/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 492, in pre_sub_module_forward_function
    param_coordinator.fetch_sub_module(sub_module)
  File "/path/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/path/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/path/python3.9/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 368, in fetch_sub_module
    print(traceback.print_stack(file=sys.stdout))

Note that this happens from a param_coordinator.fetch_sub_module(sub_module) call.

That trace was kicked off by a call to:

value = self.critic_model.forward_value()

The ds_summary output for the param being prefetched is:

-prefetch: {'id': 0, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 102957056, 'shape': (0,), 'ds_shape': (50272, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set()}

I did verify that this layer belongs to the actor model by matching its Python id(param) value and similarly verified that it's not shared with the critic model.

I can see that it fails for me on what I think is the third training step. This appears to be the first step where it has completed a trace and thus has enabled prefetching for the model.

def fetch_sub_module(self, current_submodule: Module) -> None:
<snip>
        # kick off parameter prefetches for upcoming modules
        # don't prefetch if we dont have a completed model trace
        if self.is_complete_trace():

At that point self._param_ids == 674, while it shows 0 for the two previous steps.

Since there seems to be some "model mixing" in this case, one area that caught my eye is the global FWD_MODULE_STACK in deepspeed/runtime/zero/parameter_offload.py.

    @torch.no_grad()
    def pre_sub_module_forward_function(self, sub_module):
        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(sub_module)
        if dist.get_rank() == 0:
            print("FWD_MODULE_STACK length", len(FWD_MODULE_STACK), "id(module)", id(sub_module))

        param_coordinator = self.get_param_coordinator(training=sub_module.training)
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module)

You can see I've added a print above. The vocab embedding layer is the 8th or 9th element at the time the problem occurs; I'd have to double-check which if you need an exact value.

That list is seeded with the top-level (base) module of each model when its hooks are set up:

    def setup_zero_stage3_hooks(self):
        self.hierarchy = 0

        #reset step if in inference mode
        @instrument_w_nvtx
        def _end_of_forward_hook(module, *args):

            if not torch._C.is_grad_enabled():
                self.get_param_coordinator(training=False).reset_step()

        #likely one of them should be enough but just to be safe
        self._register_hooks_recursively(self.module)
        self.module.register_forward_hook(_end_of_forward_hook)

        # Add top module to stack trace
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(self.module)
        if dist.get_rank() == 0:
            print("FWD_MODULE_STACK length", len(FWD_MODULE_STACK), "id(module)", id(self.module))
            print(str(self.module))

In this case, I can see that the four base models correspond to the first four elements of that list. I'm not sure which 4 or so modules sit between the base models and the vocab embedding layer in that list at the point where I see the problem.

I still haven't tracked down why invoking a function on the critic model could end up fetching params for the actor model, but I wondered if there might be some linkage here.

tjruwase commented 1 year ago

@adammoody, kudos on the intensive debugging. I think I know what might be wrong, but I need your help to confirm. I have updated my PR with some asserts to verify that empty_partition_cache() is behaving as expected. Can you please try the PR again?

adammoody commented 1 year ago

I added those new changes by hand, so my source file line numbers will be different. I have been editing DeepSpeed files in place within my Python environment, so it would take some effort to set up a clean environment at this point. Anyway, assuming my hand edits match your PR, I hit the tensor dimension mismatch here:

https://github.com/microsoft/DeepSpeedExamples/blob/ce049bee82bd4594209beb2bc0676a44af2b5758/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L77

6: 0: Traceback (most recent call last):
6: 0:   File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
6: 0:     main()
6: 0:   File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 438, in main
6: 0:     out = trainer.generate_experience(prompts)
6: 0:   File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 106, in generate_experience
6: 0:     seq = self._generate_sequence(prompts)
6: 0:   File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 79, in _generate_sequence
6: 0:     seq = self.actor_model.module.generate(prompts,
6: 0:   File "/path/python3.9/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
6: 0:     self.fuse_lora_weight()
6: 0:   File "/path/python3.9/site-packages/deepspeed/runtime/hybrid_engine.py", line 139, in fuse_lora_weight
6: 0:     self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
6: 0:   File "/path/python3.9/site-packages/deepspeed/runtime/hybrid_engine.py", line 128, in _fuse_lora
6: 0:     weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
6: 0: RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048) at non-singleton dimension 1

If I add a second call to assert_empty_partition_cache for the actor_model immediately after the check for the critic model:

        ### process the new outputs
        batch = {'input_ids': seq, "attention_mask": attention_mask}
        actor_prob = self.actor_model(**batch, use_cache=False).logits
        actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],
                                          inputs['input_ids'][:, 1:])
        actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                        log_probs[:, start:], advantages,
                                        action_mask[:, start:])
        self.actor_model.backward(actor_loss)
        self.actor_model.step()
        self.actor_model.empty_partition_cache()
        assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf step')

        value = self.critic_model.forward_value(**batch,
                                                return_value_only=True,
                                                use_cache=False)[:, :-1]
        critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,
                                                                       start:],
                                          returns, action_mask[:, start:])
        self.critic_model.backward(critic_loss)
        self.critic_model.step()
        self.critic_model.empty_partition_cache()
        assert_empty_partition_cache(self.critic_model, 'critic_model after rlhf step')
        assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf critic step')  # <-- added

then the assertion triggers:

  File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 204, in train_rlhf
    actor_loss, critic_loss = trainer.train_rlhf(exp_data)
  File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 204, in train_rlhf
    assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf critic step')
  File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 290, in assert_empty_partition_cache
    assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf critic step')
  File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 290, in assert_empty_partition_cache
    assert len(avail_or_inflight_params) == 0, \
AssertionError: actor_model after rlhf critic step empty_partition_cache failed to evict all params: remaining = [0, 1, 2, 3, 387, 388, 4, 5, 389, 390, 6, 7, 391, 392, 8, 9, 393, 394, 10, 11, 12, 16, 17]
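The assert_empty_partition_cache helper comes from tjruwase's PR, and its full source is not shown in this thread. A minimal sketch of an equivalent check, assuming ZeRO-3 parameters expose `ds_id` and a `ds_status` enum whose member name is `NOT_AVAILABLE` once released (as the ds_summary output earlier suggests), might look like this. `FakeStatus` and `FakeEngine` are hypothetical stand-ins so the sketch runs without DeepSpeed:

```python
from enum import Enum
from types import SimpleNamespace


def assert_empty_partition_cache(model, tag):
    """Assert that no ZeRO-3 parameter of `model` is still gathered or in flight.

    Hypothetical re-creation for illustration only. Assumes partitioned params
    carry `ds_id` and a `ds_status` enum named 'NOT_AVAILABLE' when released.
    """
    avail_or_inflight_params = [
        p.ds_id
        for p in model.parameters()
        if hasattr(p, "ds_id")
        and getattr(p.ds_status, "name", p.ds_status) != "NOT_AVAILABLE"
    ]
    assert len(avail_or_inflight_params) == 0, (
        f"{tag} empty_partition_cache failed to evict all params: "
        f"remaining = {avail_or_inflight_params}")


class FakeStatus(Enum):  # stand-in for DeepSpeed's ZeroParamStatus
    NOT_AVAILABLE = 1
    AVAILABLE = 3


class FakeEngine:  # stand-in exposing only .parameters()
    def __init__(self, params):
        self._params = params

    def parameters(self):
        return iter(self._params)


released = FakeEngine([SimpleNamespace(ds_id=0, ds_status=FakeStatus.NOT_AVAILABLE)])
assert_empty_partition_cache(released, "released")  # passes silently

leaked = FakeEngine([
    SimpleNamespace(ds_id=0, ds_status=FakeStatus.NOT_AVAILABLE),
    SimpleNamespace(ds_id=387, ds_status=FakeStatus.AVAILABLE),  # still gathered
])
try:
    assert_empty_partition_cache(leaked, "actor_model after rlhf critic step")
    leak_message = None
except AssertionError as err:
    leak_message = str(err)
print(leak_message)
```

Checking the actor model right after the critic's step, as above, is what turns a later shape mismatch into an immediate, attributable failure.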

tjruwase commented 1 year ago

Thanks for sharing these updates. Adding the second assert for the actor model cache is a really good idea. It is a mystery why it fails, but this supports your suspicion of leakage between the parameter partitioning of the actor and critic models.

Can you confirm whether your critic model is 1.3b or 350m?

tjruwase commented 1 year ago

Also, can you try dropping --enable_hybrid_engine from your command line?

adammoody commented 1 year ago

Yes, I'm actually using a 350m model for the critic. I had a cut-and-paste typo in the path name when I wrote out the checkpoint, so the path suggests it's a 1.3b param model, but it is really 350m.

I tried dropping the --enable_hybrid_engine option. I do still hit the assertion. I think it triggered one step earlier than before.

AssertionError: actor_model after rlhf critic step empty_partition_cache failed to evict all params: remaining = [0, 1, 2, 3, 387, 388, 4, 5, 389, 390, 6, 7, 391, 392, 8, 9, 393, 394, 10, 11, 12, 16, 17]

Here are some other workarounds that I found earlier but hadn't listed yet:

tjruwase commented 1 year ago

Thanks for the update.

  1. Hitting the assertion is good since it stops at the earliest violation of the invariant of empty_partition_cache().
  2. It is good to know that --enable_hybrid_engine is not the cause.
  3. Yes, stage 2 should not cause this since it is stage 3 specific.
  4. Disabling prefetching (can be done through ds_config) should avoid this problem.
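Item 4 refers to a ZeRO-3 setting in the DeepSpeed config, but the thread does not name the key. A minimal sketch, assuming that setting `stage3_prefetch_bucket_size` to 0 leaves ZeRO-3 with no prefetch budget, could patch the config dict before it is passed to `deepspeed.initialize()`. `disable_zero3_prefetch` and the example config are hypothetical:

```python
import copy


def disable_zero3_prefetch(ds_config):
    """Return a copy of a DeepSpeed config dict with ZeRO-3 prefetching turned off.

    Assumption: a stage3_prefetch_bucket_size of 0 means params are fetched
    only when their own module runs, never ahead of time.
    """
    cfg = copy.deepcopy(ds_config)  # leave the caller's dict untouched
    zero = cfg.setdefault("zero_optimization", {})
    if zero.get("stage") != 3:
        raise ValueError("prefetch settings only apply to ZeRO stage 3")
    zero["stage3_prefetch_bucket_size"] = 0
    return cfg


# Hypothetical config for illustration; real DeepSpeed-Chat configs carry more keys.
base = {
    "train_batch_size": 32,
    "zero_optimization": {"stage": 3, "stage3_prefetch_bucket_size": 3e7},
}
actor_cfg = disable_zero3_prefetch(base)
print(actor_cfg["zero_optimization"])  # {'stage': 3, 'stage3_prefetch_bucket_size': 0}
```

This trades away the communication overlap that prefetching provides, so it is a stopgap rather than a fix.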

tjruwase commented 1 year ago

@adammoody, FYI I think this DeepSpeed PR from my colleague @HeyangQin might be relevant here. Please give him a bit more time to get it ready.

adammoody commented 1 year ago

@tjruwase , I think I found the cause.

I believe the problem is that all four models share the same ReLU module object. Each model registers a forward hook on that module in setup_zero_stage3_hooks(). When invoking the forward pass on the ReLU module from the critic model, the hook from the actor model is invoked, which leads to the prefetch of the actor layers.
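The mechanism can be reproduced in plain PyTorch without DeepSpeed. In this sketch, `shared_relu`, the two small `nn.Sequential` models and the `calls` log are illustrative assumptions; the forward pre-hooks stand in for the per-engine hooks registered in setup_zero_stage3_hooks(). Running only the critic still fires the actor's hook on the shared activation module:

```python
import torch
from torch import nn

# One activation instance reused by both models, mimicking a module-level cache.
shared_relu = nn.ReLU()

actor = nn.Sequential(nn.Linear(4, 4), shared_relu)
critic = nn.Sequential(nn.Linear(4, 4), shared_relu)

calls = []


def make_hook(owner):
    # Stand-in for a per-engine ZeRO-3 pre-forward hook: just log the owner.
    def hook(module, inputs):
        calls.append(owner)
    return hook


# Each "engine" hooks every submodule of its own model, so the shared ReLU
# ends up carrying hooks from both engines.
for owner, model in (("actor", actor), ("critic", critic)):
    for sub in model.modules():
        sub.register_forward_pre_hook(make_hook(owner))

with torch.no_grad():
    critic(torch.randn(2, 4))  # only the critic runs forward

print(calls)  # ['critic', 'critic', 'actor', 'critic']
```

The third entry is the actor's hook firing during the critic's forward pass; in ZeRO-3 that is the point where the actor's coordinator starts prefetching actor parameters.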

I found this by adding the following code in deepspeed/runtime/zero/parameter_offload.py to print object addresses of all child modules of each model:

def print_children(module, indent):
  for name, m in module.named_children():
    spaces = " " * indent
    print(spaces, name, id(m))
    print_children(m, indent + 2)

<snip>

    def setup_zero_stage3_hooks(self):
        self.hierarchy = 0

        #reset step if in inference mode
        @instrument_w_nvtx
        def _end_of_forward_hook(module, *args):

            if not torch._C.is_grad_enabled():
                self.get_param_coordinator(training=False).reset_step()

        #likely one of them should be enough but just to be safe
        self._register_hooks_recursively(self.module)
        self.module.register_forward_hook(_end_of_forward_hook)

        # Add top module to stack trace
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(self.module)
        if dist.get_rank() == 0:
            print("FWD_MODULE_STACK SETUP length", len(FWD_MODULE_STACK), "id(module)", id(self.module), type(self.module))
            print(str(self.module))
            for p_id, param in enumerate(iter_params(self.module, recurse=True)):
              key = id(param) if hasattr(param, 'ds_id') else id(param.ds_param_alias)
              print("  ", p_id, id(param), type(param), key)
            print_children(self.module, 2)

With that, I get the following example output for the actor and reward models. You can see that the activation_fn module has the same address for all layers in all models.
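The same aliasing can be checked without adding prints inside DeepSpeed by comparing id() values across models before they are wrapped. `find_shared_modules` is a hypothetical helper, not part of DeepSpeed or transformers; it walks named_children() the same way print_children does. Note that named_children() skips a repeated child within a single parent, which is fine for OPT since each layer holds its activation once:

```python
from collections import defaultdict

from torch import nn


def find_shared_modules(models):
    """Return {id(module): [locations]} for modules reachable from more than
    one place across the given {name: model} mapping."""
    locations = defaultdict(list)

    def walk(module, prefix, model_name):
        for name, child in module.named_children():
            path = f"{prefix}.{name}" if prefix else name
            locations[id(child)].append(f"{model_name}:{path}")
            walk(child, path, model_name)

    for model_name, model in models.items():
        walk(model, "", model_name)

    return {mod_id: paths for mod_id, paths in locations.items() if len(paths) > 1}


# Toy models: two "layers" in the actor and one in the critic share one ReLU.
relu = nn.ReLU()
actor = nn.Sequential(nn.Sequential(nn.Linear(4, 4), relu),
                      nn.Sequential(nn.Linear(4, 4), relu))
critic = nn.Sequential(nn.Linear(4, 2), relu)

shared = find_shared_modules({"actor": actor, "critic": critic})
for paths in shared.values():
    print(paths)  # ['actor:0.1', 'actor:1.1', 'critic:1']
```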

FWD_MODULE_STACK SETUP length 1 id(module) 35188069303104 <class 'transformers.models.opt.modeling_opt.OPTForCausalLM'>
3: 0:    model 35188109644992
3: 0:      decoder 35188109647776
3: 0:        embed_tokens 35188109647248
3: 0:        embed_positions 35188109647440
3: 0:        layers 35188109646144
3: 0:          0 35188109647392
3: 0:            self_attn 35188109644800
3: 0:              k_proj 35188109646480
3: 0:                lora_dropout 35188109646576
3: 0:              v_proj 35188109644320
3: 0:                lora_dropout 35188092098880
3: 0:              q_proj 35188109644848
3: 0:                lora_dropout 35188092098832
3: 0:              out_proj 35188109647296
3: 0:                lora_dropout 35188092096912
3: 0: -->        activation_fn 35186791027472  <--- same address for all layers, in each model
3: 0:            self_attn_layer_norm 35188109645040
3: 0:            fc1 35188109645616
3: 0:              lora_dropout 35188092098736
3: 0:            fc2 35188109644656
3: 0:              lora_dropout 35188092100560
3: 0:            final_layer_norm 35188108844336
3: 0:          1 35188108843472
3: 0:            self_attn 35188108843328
3: 0:              k_proj 35188109643984
3: 0:                lora_dropout 35188107776880
3: 0:              v_proj 35188108844432
3: 0:                lora_dropout 35188107776592
3: 0:              q_proj 35188108842800
3: 0:                lora_dropout 35188107777408
3: 0:              out_proj 35188108841264
3: 0:                lora_dropout 35188107777456
3: 0: -->        activation_fn 35186791027472  <--- same address for all layers, in each model
3: 0:            self_attn_layer_norm 35188108841456
3: 0:            fc1 35188108842512
3: 0:              lora_dropout 35188107777696
3: 0:            fc2 35188108841888
3: 0:              lora_dropout 35188107776160
3: 0:            final_layer_norm 35188104685504

<snip>

FWD_MODULE_STACK SETUP length 4 id(module) 35188109548272 <class 'utils.model.reward_model.RewardModel'>
3: 0:    v_head 35188109667440
3: 0:    rwtranrsformer 35188109548224
3: 0:      decoder 35188109548320
3: 0:        embed_tokens 35188109548416
3: 0:        embed_positions 35188109548368
3: 0:        project_out 35188109548512
3: 0:        project_in 35188109548560
3: 0:        layers 35188109548656
3: 0:          0 35188109548608
3: 0:            self_attn 35188109548704
3: 0:              k_proj 35188109548800
3: 0:              v_proj 35188109548848
3: 0:              q_proj 35188109548992
3: 0:              out_proj 35188109549040
3: 0: -->        activation_fn 35186791027472 <--- same address for all layers, in each model
3: 0:            self_attn_layer_norm 35188109548752
3: 0:            fc1 35188109549136
3: 0:            fc2 35188109549184
3: 0:            final_layer_norm 35188109549232
3: 0:          1 35188109549328
3: 0:            self_attn 35188109549376
3: 0:              k_proj 35188109549472
3: 0:              v_proj 35188109549520
3: 0:              q_proj 35188109664320
3: 0:              out_proj 35188109664368
3: 0: -->        activation_fn 35186791027472 <--- same address for all layers, in each model
3: 0:            self_attn_layer_norm 35188109549424
3: 0:            fc1 35188109547552
3: 0:            fc2 35188109547792
3: 0:            final_layer_norm 35188109547840

As a test, I then found that I could work around the problem by modifying the OPT model to instantiate a unique ReLU object for each layer in transformers/models/opt/modeling_opt.py:

class OPTDecoderLayer(nn.Module):
    def __init__(self, config: OPTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
        # self.activation_fn = ACT2FN[config.activation_function]  # <-- original line, commented out
        self.activation_fn = nn.ReLU()  # <-- new: a separate ReLU instance per layer
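As an alternative to patching modeling_opt.py, one could give every model its own copy of any stateless submodule before the models reach deepspeed.initialize(). `unshare_stateless_modules` is a hypothetical helper, not an API of DeepSpeed or transformers. It assumes it runs before any hooks are registered (copy.deepcopy would also copy existing hooks), and it deliberately leaves modules that own parameters or buffers shared, so intentionally tied weights are not split:

```python
import copy

from torch import nn


def _has_state(module):
    # True if the module (or any descendant) owns parameters or buffers.
    return (next(module.parameters(), None) is not None
            or next(module.buffers(), None) is not None)


def unshare_stateless_modules(*models):
    """Deep-copy every stateless submodule already seen in this or an earlier
    model, so each occurrence is a distinct object. Returns the copy count."""
    seen = set()
    replaced = 0

    def visit(parent):
        nonlocal replaced
        for name, child in list(parent.named_children()):
            if id(child) in seen:
                if _has_state(child):
                    continue  # intentionally shared (e.g. tied weights): keep it
                child = copy.deepcopy(child)
                setattr(parent, name, child)  # re-register under the same name
                replaced += 1
            seen.add(id(child))
            visit(child)

    for model in models:
        visit(model)
    return replaced


relu = nn.ReLU()
actor = nn.Sequential(nn.Sequential(nn.Linear(4, 4), relu),
                      nn.Sequential(nn.Linear(4, 4), relu))
critic = nn.Sequential(nn.Linear(4, 2), relu)
n_replaced = unshare_stateless_modules(actor, critic)
print(n_replaced)  # 2

tied = nn.Linear(2, 2)
kept = unshare_stateless_modules(nn.Sequential(tied), nn.Sequential(tied))
print(kept)  # 0: modules that own parameters stay shared
```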

tjruwase commented 1 year ago

Amazing debugging, @adammoody. Truly outstanding!

@stas00, FYI. It seems ReLU objects are shared across models within the same process when using transformers. Do you have context for this behavior?

stas00 commented 1 year ago

yes, it's the same object, it's like a cache:

creation: https://github.com/huggingface/transformers/blob/b6865b9befad33f99adee0a6ef6361f72fcc8b42/src/transformers/activations.py#L206-L233

use: https://github.com/huggingface/transformers/blob/b6865b9befad33f99adee0a6ef6361f72fcc8b42/src/transformers/models/opt/modeling_opt.py#L288

The paradigm is shifting. Clearly there was no need to create a new object before, because DeepSpeed didn't support more than one model. And there is no issue with reusing the same object across multiple models outside of the DeepSpeed world.

We should probably file a feature request to create these on the fly rather than pre-create them, so that each instance is unique.
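The difference between pre-creating activation modules and creating them on the fly can be sketched in plain Python. `PREBUILT`, `FactoryDict` and `ON_THE_FLY` are hypothetical names and are not the transformers code linked above; a dummy `ReLU` class stands in for `torch.nn.ReLU`:

```python
class ReLU:
    """Stand-in for torch.nn.ReLU so the example runs without PyTorch."""


# Pattern 1: the mapping holds ready-made instances, so every lookup
# returns the very same object (the "cache" behavior described above).
PREBUILT = {"relu": ReLU()}


# Pattern 2: the mapping holds (class, kwargs) pairs and builds a new
# instance on every lookup.
class FactoryDict(dict):
    def __getitem__(self, key):
        cls, kwargs = super().__getitem__(key)
        return cls(**kwargs)


ON_THE_FLY = FactoryDict({"relu": (ReLU, {})})

layer_a_act, layer_b_act = PREBUILT["relu"], PREBUILT["relu"]
print(layer_a_act is layer_b_act)  # True: shared across layers and models

layer_a_act, layer_b_act = ON_THE_FLY["relu"], ON_THE_FLY["relu"]
print(layer_a_act is layer_b_act)  # False: each layer gets its own object
```

With the second pattern, a hook registered on one layer's activation is attached only to that instance, which avoids the cross-model hook sharing found in this thread.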

There are quite a few changes that need to be made to support the multiple-DeepSpeed-models paradigm.

Some possible workarounds:

liuaiting commented 1 year ago

This error is still reported when I run step 3 using bloomz + ZeRO-3:

raise RuntimeError(f"still have inflight params "f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}")

I encountered this bug after the previous bug (https://github.com/microsoft/DeepSpeed/issues/3528) was solved. @HeyangQin

kisseternity commented 1 year ago

I have the same issue. Have you resolved the problem?

shyustc commented 1 year ago

Encountering the same error here. The issue persists even after updating DeepSpeed and PyTorch Lightning to the latest versions.