huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers

[resume optimization] skip loading pretrained weights on resume #11465

Open stas00 opened 3 years ago

stas00 commented 3 years ago

This is similar to what was discussed in https://github.com/huggingface/transformers/issues/9205, which proposed not to randomly initialize weights in from_pretrained, but this time it's about resume: currently we load the pretrained weights and then immediately drop them when resuming from a checkpoint in Trainer.

To solve this we could, for example:

  1. figure out the checkpoint in the examples immediately after we init TrainingArguments and just before the model is created.
  2. then change the from_pretrained() API to keep everything as is, except skip loading the weights from state_dict when, say, skip_weights_load=True is passed:

So the code becomes:


    if training_args.do_train:
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None

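    # skip_weights_load is the proposed (not yet existing) flag: keep the usual init,
    # but don't load weights that the checkpoint is about to overwrite anyway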
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path, 
        [...],
        skip_weights_load=checkpoint is not None,
    )

    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

Any flaws in my thinking?

@patrickvonplaten, @sgugger

sgugger commented 3 years ago

This can be achieved by just doing model = AutoModelForSeq2SeqLM.from_config(config) when the checkpoint is not None. I don't believe it will be much faster, however, as your analysis in #9205 pointed to the random initialization as being the bottleneck.
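
Roughly, following your snippet above (a sketch; variable names come from that example):

    from transformers import AutoConfig, AutoModelForSeq2SeqLM

    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    if checkpoint is not None:
        # architecture only; trainer.train(resume_from_checkpoint=checkpoint)
        # will restore the weights from the checkpoint afterwards
        model = AutoModelForSeq2SeqLM.from_config(config)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path, config=config
        )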

stas00 commented 3 years ago

> This can be achieved by just doing model = AutoModelForSeq2SeqLM.from_config(config) when the checkpoint is not None.

From here, right? https://github.com/huggingface/transformers/blob/88ac60f7b5f6d4b62245dc21653ea3d5db7d4935/src/transformers/models/auto/auto_factory.py#L362

Great idea!

Then this important part would be missed:

            with deepspeed.zero.Init():
                model = cls(config, *model_args, **model_kwargs)

I guess I need to add it to from_config anyway, which would solve this part.

Also, this won't be done:

   model.eval()

but the latter is probably redundant anyway.
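
A rough sketch of what that patched path could look like (not actual transformers code; zero3_enabled just stands in for however the deepspeed integration detects ZeRO-3):

    import deepspeed

    def create_model(cls, config, zero3_enabled, *model_args, **model_kwargs):
        if zero3_enabled:
            # partition the freshly created parameters across GPUs right away
            with deepspeed.zero.Init():
                model = cls(config, *model_args, **model_kwargs)
        else:
            model = cls(config, *model_args, **model_kwargs)
        model.eval()  # mirror from_pretrained(); Trainer switches back to train mode anyway
        return model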

> I don't believe it will be much faster, however, as your analysis in #9205 pointed to the random initialization as being the bottleneck.

For huge models every saving counts! Once you start working with models like t5-11b, it's excruciatingly slow to wait for things to start.

Should I try one example and re-shuffle the order of the code?

sgugger commented 3 years ago

Yes, we should try on one example first! Though the first step is to fix the from_config method of AutoModel :-)

stevemadere commented 8 months ago

I had this same problem and created a fork with a workaround. https://github.com/stevemadere/transformers/commit/3561eefe6162f6e768dd303f574aa7cef5976ba4

It adds a new parameter, without_checkpoint_model: bool = False, to the Trainer.train() method. If that parameter is set to True, resume_from_checkpoint will skip over the model re-loading step and raise an error if no model was provided in the Trainer constructor.
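
Usage looks roughly like this (without_checkpoint_model only exists in my fork; the checkpoint path is illustrative):

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
    trainer.train(
        resume_from_checkpoint="output/checkpoint-500",  # illustrative path
        without_checkpoint_model=True,  # keep the model passed above, skip Trainer's reload
    )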

I've integration tested it pretty thoroughly, having resumed training of a QLoRA model many, many times at this point.

Basically, I needed to resume training of a QLoRA model and did not have the patience to wait for all the kinks to be worked out in re-loading QLoRA models from within Trainer. It looks like support for resuming with QLoRA was supposedly added very recently, but some folks are complaining that it does not actually work in practice.

In general, whenever there is a new model type that the Trainer does not explicitly know how to load, resume_from_checkpoint will be impossible for that kind of model until loading support for it is added to Trainer.

We might do well to just provide folks a way to work around that without having to use my fork. :)
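
For example, something along these lines (a rough sketch that overrides the private _load_from_checkpoint hook, so it may break across transformers versions):

    from transformers import Trainer

    class NoReloadTrainer(Trainer):
        def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
            # keep the already-constructed model; optimizer/scheduler/RNG state
            # are still restored elsewhere in train()
            pass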

Let me know if you want me to turn that diff into a PR on this repo.

stas00 commented 8 months ago

I'm no longer at HF to make decisions, and this thread is really, really old. These days the solution is to create the model on the meta device and then replace its weights with the pretrained checkpoint.
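
Roughly (a sketch, assuming a recent torch that supports the meta-device context manager and load_state_dict(assign=True); paths are illustrative):

    import torch
    from transformers import AutoConfig, AutoModelForSeq2SeqLM

    config = AutoConfig.from_pretrained("t5-11b")
    with torch.device("meta"):  # no real storage is allocated for the weights
        model = AutoModelForSeq2SeqLM.from_config(config)

    state_dict = torch.load("checkpoint/pytorch_model.bin", map_location="cpu")
    model.load_state_dict(state_dict, assign=True)  # materialize params from the checkpoint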

stevemadere commented 8 months ago

I'll create a new issue for my problem, which is apparently a bit different: I'm not concerned so much about efficiency as about the ability to continue training at all with a model that Trainer does not yet support reloading.