Resizing HF token embeddings with PipelineModule #1010

Open g-karthik opened 3 years ago

g-karthik commented 3 years ago

With an HF model class, one can resize the token embeddings to account for any number of special tokens - there's no upper limit. In the usual scenario it looks something like this (this isn't necessarily working code, I may have gotten the tokenizer APIs incorrect):

from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
config_class = GPT2Config
model_class = GPT2LMHeadModel
tokenizer_class = GPT2Tokenizer

config = config_class.from_pretrained("gpt2-xl")   # let's say we want to use the XL config for now, has its own vocab size
tokenizer = tokenizer_class.from_pretrained("gpt2-xl")  # default XL vocab

tokenizer.add_special_tokens("<speaker1>")
tokenizer.add_special_tokens("<speaker2>")

model = model_class(config)
model.resize_token_embeddings(len(tokenizer))

The last line essentially allocates 2 new rows in the input embedding matrix for the newly added special tokens and initializes them with random weights.
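
For example, a quick sanity check would show the grown embedding matrix (a hedged illustration, assuming the snippet above ran as written and gpt2-xl's default vocabulary of 50257 tokens):

# gpt2-xl: vocab_size 50257, n_embd 1600; after adding 2 special tokens and resizing,
# the input embedding matrix should have 50259 rows
print(len(tokenizer))                              # 50259
print(model.get_input_embeddings().weight.shape)   # torch.Size([50259, 1600])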

Now in the pipeline regime, one cannot just resize the token embeddings after initialization of the PipelineModule, since the module would have already split the model across pipeline stages. Is it possible to provide a callback/mechanism with PipelineModule that can allow for resizing and fresh initialization of newly added special token embeddings for downstream users?
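
To make the constraint concrete, here is a hedged sketch (reusing config and tokenizer from the snippet above): the vocab change has to be baked in before PipelineModule partitions the layers across stages. EmbeddingPipe/BlockPipe are toy stand-ins rather than actual HF or DeepSpeed layers, and a distributed launch (e.g. via the deepspeed launcher) is assumed.

import torch
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec

class EmbeddingPipe(nn.Module):          # toy stand-in, not an HF/DeepSpeed class
    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
    def forward(self, input_ids):
        return self.wte(input_ids)

class BlockPipe(nn.Module):              # toy stand-in for a transformer block
    def __init__(self, config):
        super().__init__()
        self.ff = nn.Linear(config.n_embd, config.n_embd)
    def forward(self, hidden):
        return torch.relu(self.ff(hidden))

config.vocab_size = len(tokenizer)       # bake the resized vocab in *before* building the pipe
layers = [LayerSpec(EmbeddingPipe, config)] + \
         [LayerSpec(BlockPipe, config) for _ in range(config.n_layer)]

deepspeed.init_distributed()
model = PipelineModule(layers=layers, num_stages=2)
# from here on the embedding weights live only on the first stage, so a later
# resize_token_embeddings()-style call can no longer see the whole model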

Also, shouldn't this be a problem with the implementation of pipeline (and more generally 3D) parallelism in the DeepSpeedExamples repo too? A user of a model that's been pre-trained with pipeline parallelism would certainly have some basic downstream needs such as addition of special tokens for fine-tuning.

@ShadenSmith @stas00

stas00 commented 3 years ago

I don't know much about DS's version of pipeline - I've worked only with pytorch's native version of it.

I know I had to add a gather and re-partition with ZeRO-3 for that exact situation of resizing embeddings - twice in this code: https://github.com/huggingface/transformers/blob/8d43c71a1ca3ad322cc45008eb66a5611f1e017e/src/transformers/modeling_utils.py#L643
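
For readers following along, the pattern that code implements looks roughly like the following - a minimal sketch, assuming ZeRO-3 is active, not the exact transformers code (the function name is mine):

import torch.nn as nn
import deepspeed

def resize_embeddings_zero3(old_emb: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    # allocate the bigger embedding locally (randomly initialized rows at the end)
    new_emb = nn.Embedding(new_num_tokens, old_emb.embedding_dim)
    # temporarily gather the full, unsharded old weight on every rank (read-only)
    with deepspeed.zero.GatheredParameters(old_emb.weight, modifier_rank=None):
        n = min(old_emb.num_embeddings, new_num_tokens)
        new_emb.weight.data[:n, :] = old_emb.weight.data[:n, :]
    return new_emb

The real transformers code additionally handles the case where the new embedding itself is created partitioned (under deepspeed.zero.Init), which this sketch skips.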

Does pipeline have a similar feature? I just don't know this side of DS

g-karthik commented 3 years ago

@stas00 have you already migrated HF towards native pipeline parallelism with pytorch 1.8? If so, can you point me to that?

For pipe, I don't think there's an equivalent feature to the one you listed @stas00, although I have a few questions about that:

  1. Does that method you linked work if DeepSpeed ZeRO-3/Infinity is enabled outside of the Trainer class? I do not use the Trainer, but I do use DeepSpeed with HF model classes. Does the with deepspeed.zero.GatheredParameters context make any assumptions about pre-initialization, such as deepspeed.zero.Init? (See the sketch after this list.)
  2. Have you run basic performance tests for ZeRO-3/Infinity with HF classes?
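
(For reference, the pre-initialization question 1 refers to is the pattern below - a hedged sketch; ds_config is a minimal stand-in config and a distributed launch is assumed. Under zero.Init, parameters are sharded as they are created, which is what GatheredParameters later temporarily undoes.)

import deepspeed
from transformers import GPT2Config, GPT2LMHeadModel

ds_config = {"train_micro_batch_size_per_gpu": 1,
             "zero_optimization": {"stage": 3}}      # minimal stand-in config

config = GPT2Config.from_pretrained("gpt2-xl")
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = GPT2LMHeadModel(config)                  # weights are partitioned at construction time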

Would also like @ShadenSmith's and @tjruwase's thoughts here, because I think the performance in 2 above (i.e., ZeRO perf on HF classes) is incredibly poor, as already seen with ZeRO-2. I have a separate GitHub issue ongoing with @tjruwase about this. I feel like pipeline or even 3D parallelism would be better than ZeRO-3/Infinity because of the reduced communication volume with pipeline parallelism.

Fitting massive models in GPU memory is one thing; being able to train them fast by minimizing communication volume is another. ZeRO-3/Infinity may help with the former, but it looks like pipeline (or more broadly, 3D) parallelism is the better solution because it also allows for the latter.
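
As a very rough back-of-envelope sketch of that argument (the 2x/3x parameter-volume figures for plain data parallelism and ZeRO-3 are the ones cited in the ZeRO paper; the pipeline figure only counts activations crossing stage boundaries, and all the sizes below are made-up but gpt2-xl-ish):

params = 1.5e9                                    # ~gpt2-xl parameter count
hidden, seq, micro_bs, micro_batches, stages = 1600, 1024, 1, 32, 4

zero3_per_step = 3 * params                       # all-gather (fwd + bwd) + reduce-scatter of grads
pipe_per_step = 2 * (stages - 1) * micro_batches * micro_bs * seq * hidden   # fwd + bwd activations

print(f"ZeRO-3   ~{zero3_per_step / 1e9:.1f}B elements moved per step")      # ~4.5B
print(f"pipeline ~{pipe_per_step / 1e9:.2f}B elements moved per step")       # ~0.31B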

stas00 commented 3 years ago

re: Pipeline Parallelism:

Most HF models are too complicated. All pipe approaches that I tried require:

  1. models converted to nn.Sequential
  2. inputs/outputs to be tensors

So after spending weeks on this I gave up (or rather parked the idea). I managed to make a pipeline using 2 pytorch pipelines, because a single pipeline can't handle conditional modules, which encoder/decoder models are. The performance was terrible; I couldn't get above 50% GPU utilization over 2 GPUs.

The pytorch Pipe API has been becoming more user-friendly w.r.t. (2) and will soon handle arbitrary inputs/outputs.

In order to convert HF models to a pipeline, the models have to drop complex features like past-key-values and hidden-state aggregates - this was the most difficult part. I made a workaround using closures, but it doesn't scale well. If you want to see some really crazy code, that experimental PR is full of it.

Bottom line - to make pipelines work, the models:

  1. need to be designed to be easily convertible to nn.Sequential
  2. must not use complex inputs/outputs or bool control variables if we want the DeepSpeed pipeline (though the latter can be worked around) - see the toy sketch after this list.
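
A toy illustration of what (1) and (2) mean in practice - not an HF or DeepSpeed API, just a minimal sketch of tensor-in/tensor-out stages that nn.Sequential (and hence a pipeline splitter) can handle:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):                     # tensor in, tensor out: no past_key_values,
        return x + self.mlp(self.norm(x))     # no dicts/tuples, no bool control flags

stack = nn.Sequential(*[Block(1600) for _ in range(4)])
out = stack(torch.randn(2, 128, 1600))        # the stack is now trivially splittable into stages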

re: Performance testing

I hope to start doing that in the next few days, now that ZeRO-Inf has been merged. As usual I will make an Issue on HF transformers and start sharing the results.

We plan to do extensive benchmarking including sagemaker, jax, megatron-lm and deepspeed, of course. I'm not sure if fairscale will be included - the last time one of us looked it was not complete, but perhaps they have caught up. I was too busy with the deepspeed integration and with bf16-pretrained models getting NaNs under fp16/mixed precision/deepspeed to have time to look.

One other approach I hope to include is FlexFlow https://github.com/flexflow/flexflow - I hope we will now be able to convert our models to a torch.fx trace, which is a prerequisite for flexflow. I highly recommend you check it out - the paper looks very interesting - but I haven't had a chance to see it in action yet and hope this will change soon. @michaelbenayoun has been making awesome progress proxying the symbolic tracing via https://github.com/huggingface/transformers/pull/11475, which should enable flexflow usage with HF transformers.
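
For context, a tiny example of the torch.fx symbolic trace mentioned above (a toy module, not an HF model - HF models need the proxying work from the linked PR):

import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

traced = torch.fx.symbolic_trace(Toy())
print(traced.graph)      # the captured dataflow graph that tools like FlexFlow would consume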