OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

Why is dschf defined in function scope? #196

Closed · kajyuuen closed this issue 5 months ago

kajyuuen commented 5 months ago

To enable DeepSpeed ZeRO Stage 3, you need to create and configure a persistent instance of HfDeepSpeedConfig.

https://huggingface.co/docs/transformers/main_classes/deepspeed#non-trainer-deepspeed-integration
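
For reference, the documented non-Trainer integration looks roughly like the sketch below (the config dict here is a trimmed-down assumption, not a complete ZeRO-3 config):

from transformers import AutoModelForCausalLM
from transformers.integrations import HfDeepSpeedConfig

# minimal, illustrative ZeRO-3 config; real configs carry many more fields
ds_config = {
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": 1,
}

# HfDeepSpeedConfig must be created BEFORE the model and kept alive, so that
# from_pretrained detects ZeRO-3 and allocates weights under deepspeed.zero.Init
dschf = HfDeepSpeedConfig(ds_config)  # keep this reference alive
model = AutoModelForCausalLM.from_pretrained("gpt2")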

However, it appears that the current code deliberately defines it within function scope.

https://github.com/OpenLLMAI/OpenRLHF/blob/fad3227afd3124ad41505974b49a24b953eba7ed/openrlhf/models/actor.py#L53-L56

Could you please explain the reason for this?

wuxibin89 commented 5 months ago

Internally, HfDeepSpeedConfig sets a special weakref global object which may affect all transformers behavior. We only want HfDeepSpeedConfig to take effect (enable deepspeed.zero.Init) when initializing the model with AutoModelForCausalLM.from_pretrained, so we define it locally and let the garbage collector reclaim it when the function returns.

https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/deepspeed.py#L75

import weakref


def set_hf_deepspeed_config(hf_deepspeed_config_obj):
    # this is a special weakref global object to allow us to get to Deepspeed config from APIs
    # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain.
    global _hf_deepspeed_config_weak_ref
    # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed)
    _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)
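
Concretely, the pattern described above amounts to something like the following sketch (the function and variable names here are illustrative, not the exact OpenRLHF code):

import torch
from transformers import AutoModelForCausalLM
from transformers.integrations import HfDeepSpeedConfig

def load_model(pretrain: str, ds_config: dict):
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        # constructing HfDeepSpeedConfig installs the global weakref, so the
        # from_pretrained call below runs under deepspeed.zero.Init
        dschf = HfDeepSpeedConfig(ds_config)  # kept alive on purpose until return
    else:
        dschf = None
    model = AutoModelForCausalLM.from_pretrained(pretrain, torch_dtype=torch.bfloat16)
    # when this function returns, dschf becomes unreachable; once it is
    # garbage-collected the weakref dies, so later from_pretrained calls
    # elsewhere are no longer affected by this DeepSpeed config
    return model

Note that the weakref only dies after dschf is actually collected, so this scoping trick relies on the local reference being freed promptly at function exit (which CPython's reference counting does).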