huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Explicit option to disable deepspeed when loading a model #28106

Closed chiragjn closed 9 months ago

chiragjn commented 11 months ago

Feature request

Option to disable deepspeed explicitly on a per-model basis

Motivation

So I have a little bit of an odd setup. In my QLoRA/LoRA fine-tuning script, I launch with accelerate launch --mixed_precision bf16 --use_deepspeed train.py --deepspeed deepspeed_zero3.json ... and I use the TrainingArguments class to accept this config.

In that script, before I start training, I want to load the model with empty weights, without deepspeed involved.

But once a deepspeed ZeRO-3 config is set, it is stored as a global: https://github.com/huggingface/transformers/blob/e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2/src/transformers/integrations/deepspeed.py#L239

And then all models try to use deepspeed zero.Init or do special handling for ZeRO-3 sharding: https://github.com/huggingface/transformers/blob/e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2/src/transformers/modeling_utils.py#L1823
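
For context, a minimal sketch of that global state; output_dir is a placeholder and deepspeed_zero3.json is assumed to set zero_optimization stage 3:

from transformers import TrainingArguments
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

# Building TrainingArguments with a ZeRO-3 config registers that config in a
# module-level global inside transformers.integrations.deepspeed...
args = TrainingArguments(output_dir="out", deepspeed="deepspeed_zero3.json")

# ...so every later from_config()/from_pretrained() call in the same process
# sees ZeRO-3 as enabled, no matter which model is being constructed:
assert is_deepspeed_zero3_enabled()  # True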

This results in an error with meta tensors, because deepspeed's zero.Init post-init hook tries to copy parameter data that meta tensors do not have:

    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
  File "/data/v/ft/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 441, in from_config
    return model_class._from_config(config, **kwargs)
  File "/data/v/ft/lib/python3.10/site-packages/transformers/modeling_utils.py", line 1247, in _from_config
    model = cls(config, **kwargs)
  File "/data/v/ft/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 459, in wrapper
    f(module, *args, **kwargs)
  File "/data/v/ft/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1141, in __init__
    self.model = MixtralModel(config)
  File "/data/v/ft/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 459, in wrapper
    f(module, *args, **kwargs)
  File "/data/v/ft/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 964, in __init__
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  File "/data/v/ft/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 466, in wrapper
    self._post_init_method(module)
  File "/data/v/ft/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 995, in _post_init_method
    param.data = param.data.to(self.local_device)
NotImplementedError: Cannot copy out of meta tensor; no data!

While I can work around my issue, I thought it might be good to have a context manager to disable deepspeed ZeRO in certain sections of the code.


Additional context on why I load my model separately

Before I start training, I run a check to ensure the base model can fit entirely within the available GPUs in bf16. This is to ensure that after tuning I will be able to merge the adapters correctly, because currently merge_and_unload cannot save offloaded modules correctly (a fix for that is in progress; see https://github.com/huggingface/peft/pull/1190).

The code for this check looks like this:

# Check if the model can fit just with GPUs
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
# Instantiate on the meta device so no real memory is allocated
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
device_map = infer_auto_device_map(model, dtype=torch.bfloat16)
logger.info(f"Inferred device_map for auto settings: {device_map}")
if any(not isinstance(v, int) for v in device_map.values()):
    raise RuntimeError(...)
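
(Design note on the check: infer_auto_device_map maps each module either to a GPU index, an int, or to "cpu" / "disk" for offloading, so any non-int value means the model would not fit on the GPUs alone.)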

Your contribution


amyeroberts commented 10 months ago

cc @pacman100

pacman100 commented 10 months ago

Hello @chiragjn,

Can you try the below and let us know if it solves the issue? We already have the context manager zero3_init_context_manager, which controls the ZeRO init:

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, TrainingArguments

def main():
    trainer_args = TrainingArguments(<fill this>)
    # Temporarily turn off zero.Init while building the meta model
    with trainer_args.deepspeed_plugin.zero3_init_context_manager(enable=False):
        # Check if model can fit just with gpus
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        device_map = infer_auto_device_map(model, dtype=torch.bfloat16)
        logger.info(f"Inferred device_map for auto settings: {device_map}")
        if any(not isinstance(v, int) for v in device_map.values()):
            raise RuntimeError(...)

chiragjn commented 10 months ago

Ah, nice to know this exists. I just checked, and my problem still occurs; it is not only zero-init related, because is_deepspeed_zero3_enabled() does not care whether zero init is enabled or not.

I was able to work around my issue, thanks for pointing me in the right direction:

import contextlib

from transformers import TrainingArguments
from transformers.integrations.deepspeed import (
    is_deepspeed_zero3_enabled,
    set_hf_deepspeed_config,
    unset_hf_deepspeed_config,
)

@contextlib.contextmanager
def temporarily_disable_deepspeed_zero3(training_arguments: TrainingArguments):
    if training_arguments.deepspeed and is_deepspeed_zero3_enabled():
        # Drop the global deepspeed config so model loading skips zero.Init
        unset_hf_deepspeed_config()
        try:
            yield
        finally:
            # Restore the global config even if the body raises
            set_hf_deepspeed_config(training_arguments.hf_deepspeed_config)
    else:
        yield
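
A hypothetical usage of this workaround, reusing the fit-check from the issue description (model_id and logger come from the surrounding script):

with temporarily_disable_deepspeed_zero3(training_arguments):
    # With the global config unset, this takes the plain meta-device path
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    device_map = infer_auto_device_map(model, dtype=torch.bfloat16)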

Note for readers: the workaround above works only for accelerate launch script.py --deepspeed ... If you use accelerate launch --deepspeed_config_file ... script.py ..., the handling has to be slightly different:

set_hf_deepspeed_config(training_arguments.hf_deepspeed_config) would change to set_hf_deepspeed_config(training_arguments.deepspeed_plugin.dschf)
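
A sketch of that variant, untested and with a hypothetical helper name, using the deepspeed_plugin.dschf attribute exactly as noted above:

@contextlib.contextmanager
def temporarily_disable_deepspeed_zero3_plugin(training_arguments: TrainingArguments):
    # Variant for accelerate launch --deepspeed_config_file ... script.py ...
    if training_arguments.deepspeed_plugin is not None and is_deepspeed_zero3_enabled():
        unset_hf_deepspeed_config()
        try:
            yield
        finally:
            # dschf holds the HfDeepSpeedConfig kept alive by accelerate's DeepSpeedPlugin
            set_hf_deepspeed_config(training_arguments.deepspeed_plugin.dschf)
    else:
        yield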


It would be nice to have something like this in the library.

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

meng-zha commented 6 months ago

Hi, is there an elegant way to avoid zero3 init for the from_pretrained() method now? If I set ZeRO-3 in the training arguments, I cannot control whether zero3 init is used or not.