huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Speed up ZeRO-3 generation with DPO #1543

Closed: sngdng closed this issue 6 months ago

sngdng commented 7 months ago

Hi, a recent PR brought large improvements (10x) to PPO generation with ZeRO-3. @lewtun, you mentioned on that PR that it could be adapted for other trainers. I gave it a quick try, and it seems that naively applying the context manager to trainers like DPO does not work:

in remove_hooks
    if model.optimizer is not None and hasattr(
       ^^^^^^^^^^^^^^^^^^^^
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GPTNeoXForCausalLM' object has no attribute 'optimizer'

There seems to be an inconsistency between the base classes. Is there a reason why DPO is based on Trainer from transformers and PPO on BaseTrainer? What would be the easiest way to add this feature to other trainers? Thanks!
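The AttributeError above comes from the hook helper reading model.optimizer, an attribute that a DeepSpeed engine has but a bare transformers model (here GPTNeoXForCausalLM) does not. The sketch below is a hypothetical guard, not trl's code: find_zero3_hook_owner is an invented name, and it assumes, based on the "model.optimizer is not None and hasattr(" check in the traceback, that the hooks live on optimizer.parameter_offload when that attribute exists and on the optimizer itself otherwise. It uses stand-in objects so it runs without torch or DeepSpeed.

```python
from types import SimpleNamespace


def find_zero3_hook_owner(model):
    """Hypothetical helper: return the object that holds the ZeRO-3 hooks.

    Assumes (from the traceback) that the real helper reads model.optimizer
    and prefers model.optimizer.parameter_offload when it exists.
    """
    # A bare transformers model has no `optimizer`; only the DeepSpeed
    # engine wrapping it does, which is why remove_hooks raised AttributeError.
    if not hasattr(model, "optimizer"):
        raise TypeError(
            f"{type(model).__name__} has no 'optimizer' attribute; pass the "
            "DeepSpeed engine (e.g. Trainer.model_wrapped), not the bare model"
        )
    optimizer = model.optimizer
    if optimizer is None:
        raise ValueError("model.optimizer is None; no ZeRO-3 hook owner found")
    # Fall back to the optimizer itself when it has no parameter_offload.
    return getattr(optimizer, "parameter_offload", optimizer)


# Stand-in objects for illustration only (no torch or DeepSpeed involved).
class BareModel:
    """Plays the role of GPTNeoXForCausalLM: no `optimizer` attribute."""


offload = SimpleNamespace(forward_hooks=[], backward_hooks=[])
engine = SimpleNamespace(optimizer=SimpleNamespace(parameter_offload=offload))

print(find_zero3_hook_owner(engine) is offload)  # True
try:
    find_zero3_hook_owner(BareModel())
except TypeError as err:
    print(err)
```

In a transformers Trainer subclass such as DPOTrainer, the DeepSpeed engine is normally held in self.model_wrapped, which is what the follow-up comment tries; that gets past this particular check, although the thread shows it then fails inside DeepSpeed.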

sngdng commented 7 months ago

Passing self.model_wrapped to unwrap_model_for_generation instead gives:

deepspeed/runtime/zero/partitioned_param_coordinator.py", line 194, in record_parameters
    step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: pop from an empty deque

Is it related to the way hooks are removed from and re-added to the model?
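On the Python side, "pop from an empty deque" is simply what collections.deque.popleft raises when the deque is empty. The toy below is not DeepSpeed's implementation; it only mimics the shape of the code in the traceback, where a recorded module order is replayed and one step id is popped per entry from a per-module deque (__step_id_module_fetched_for). The names ToyTrace, record_module and construct_trace are illustrative. It shows that the error appears whenever the replayed order lists a module more times than step ids remain for it; whether removing and re-adding the hooks causes such a mismatch here is a hypothesis the toy does not prove.

```python
from collections import defaultdict, deque


class ToyTrace:
    """Toy stand-in (NOT DeepSpeed code) for the trace bookkeeping in the
    traceback: module order is recorded, then replayed, popping one
    recorded step id per entry."""

    def __init__(self):
        self.step_id = 0
        self.submodule_order = []  # modules in the order they were used
        self.step_ids_fetched_for = defaultdict(deque)  # module id -> step ids

    def record_module(self, module_id):
        # Record one use of a submodule together with its step id.
        self.submodule_order.append(module_id)
        self.step_ids_fetched_for[module_id].append(self.step_id)
        self.step_id += 1

    def construct_trace(self):
        # Replay the order; mirrors `...[sub_module.id].popleft()` in the traceback.
        return [
            (module_id, self.step_ids_fetched_for[module_id].popleft())
            for module_id in self.submodule_order
        ]


ok = ToyTrace()
for module_id in (1, 2, 1):
    ok.record_module(module_id)
print(ok.construct_trace())  # [(1, 0), (2, 1), (1, 2)]

broken = ToyTrace()
broken.record_module(1)
broken.step_ids_fetched_for.clear()  # bookkeeping lost between record and replay
try:
    broken.construct_trace()
except IndexError as err:
    print(err)  # pop from an empty deque
```

In the toy, any path that clears or skips per-module step ids between recording and replay produces exactly this message; confirming which path does so in the real run would require inspecting DeepSpeed's own trace state.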

lewtun commented 6 months ago

Hey @sngdng, we've just opened a PR to fix the issue. Please let us know if you still get an error!

Shiguang-Guo commented 6 months ago

I just installed trl from source, so I think I have the latest fix, but I still get the same error when running example/scripts/ppo.py with deepspeed_zero3. The first two batches ran fine, but the third batch crashed. The only difference may be that I use llama-2-7b-chat. Do you have any suggestions?

lewtun commented 6 months ago

Can you please share the exact command you're running to trigger the error?

Shiguang-Guo commented 6 months ago

I only run accelerate launch ${ENV_ARGS} --config_file=deepspeed_zero3.yaml ppo.py ${TRAIN_ARGS}. ${ENV_ARGS} contains the node and address settings, and ${TRAIN_ARGS} just tells the script where to load the model from. The DeepSpeed configuration file and ppo.py both come from the examples. By the way, I opened a new issue (#1618) with some extra logs, since the original question here is about DPO and I'm using PPO. Thanks for any suggestions.

sngdng commented 6 months ago

@lewtun I can confirm that the issue still persists even with the fix. Without the context manager it works, but it is super slow; with the context manager it still gives:

... line 3045, in training_step
    self.accelerator.backward(loss)
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/accelerate/accelerator.py", line 1960, in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/accelerate/utils/deepspeed.py", line 167, in backward
    self.engine.backward(loss, **kwargs)
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1974, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/deepspeed/runtime/zero/stage3.py", line 2214, in backward
    self._get_param_coordinator(training=True).reset_step()
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 216, in reset_step
    self.construct_parameter_trace_from_module_trace()
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 202, in construct_parameter_trace_from_module_trace
    self.record_parameters(sub_module)
  File "/gpfslocalsup/pub/anaconda-py3/2023.09/envs/pytorch-gpu-2.2.0+py3.11.7/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 194, in record_parameters
    step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: pop from an empty deque

c3ianwu commented 2 weeks ago

@sngdng @lewtun I'm running into this exact issue (pop from an empty deque), also during the backward pass. Did either of you figure out what's causing it?