huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Flan-T5-XXL generates non-sensical text when load_in_8bit=True #20287

Closed jimmy-marmalade closed 1 year ago

jimmy-marmalade commented 1 year ago

System Info

Who can help?

@patrickvonplaten

Information

Tasks

Reproduction

Running the English to German example:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xxl", device_map="auto")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

produces expected output:

<pad> Wie alt sind Sie?</s>

Loading in 8-bit and running:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xxl", device_map="auto", load_in_8bit=True)

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

produces output that is far more nonsensical than I'd expect:

<pad> How old are</s>

Expected behavior

I expected the 8-bit output to closely approximate the original output. Since this is the provided INT8 code snippet, I expected the output to be sensible for the task.

sgugger commented 1 year ago

cc @younesbelkada

younesbelkada commented 1 year ago

Hi @jimmy-marmalade, thanks a lot for raising this point. Note that int8 quantization is done in two stages: it first converts the model to float16 and then quantizes the fp16 model to 8-bit. If you try to load and run the model in fp16, you also get gibberish output:

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("./flan-t5-xxl")
model = T5ForConditionalGeneration.from_pretrained("./flan-t5-xxl", torch_dtype=torch.float16, device_map="auto")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model.generate(input_ids, max_length=512)
print(tokenizer.decode(outputs[0]))
>>> <pad> How old are die Sie? Ihr Mutter?tat?ztlich, rezult Interesse: restriction = = = = ...

I suspect there is something wrong with the bf16 to fp16 conversion for this specific model, and for the xl model too. @stas00 do you have any intuition on why the int8 conversion (i.e. the underlying fp16 conversion) worked well for bloom-176 but not here? 🙏 Thanks!

stas00 commented 1 year ago

Did the bf16 model weights have large values, resulting in overflow when used under fp16? bf16 to fp16 conversion is a huge issue with almost every model we have seen, e.g. all large t5 and derivative models. Seeing that this is a T5 derivative it's almost certainly related; you can see the rich discussion and possible workarounds to try here: https://github.com/huggingface/transformers/pull/10956

Probably talk to @TimDettmers and ask whether there could be a bf16/int8 variant for bf16 models in BNB?

sgugger commented 1 year ago

T5 doesn't work in FP16 because the softmaxes in the attention layers are not upcast to float32. @younesbelkada if you remember the fixes done in BLOOM/OPT I suspect similar ones would fix inference in FP16 for T5 :-)
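(For context, the kind of upcast referred to looks roughly like the following; this is an illustrative sketch rather than the exact BLOOM/OPT or T5 code.)

import torch

def attention_softmax_fp32(scores: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # Compute the attention softmax in float32 so large fp16 logits don't
    # overflow or saturate, then cast the probabilities back to the compute dtype.
    return torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(out_dtype)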

stas00 commented 1 year ago

But why are we even going through FP16 to do quantization of a bf16 model? Why can't this be done directly in the original dtype?

Note: interestingly, deepspeed-inference also converts the model to fp16 to do quantization.

younesbelkada commented 1 year ago

Thank you very much for all the pointers

T5 doesn't work in FP16 because the softmaxes in the attention layers are not upcast to float32. @younesbelkada if you remember the fixes done in BLOOM/OPT I suspect similar ones would fix inference in FP16 for T5 :-)

I think that T5 already upcasts the softmax to fp32. I suspected that the overflow might come from the addition of the positional bias in the line before, but fixing that did not help. I also tried to upcast the lm_logits and the hidden_states before the lm_head to fp32, but that did not help either.

In addition, I also printed the hidden states at every stage, checking whether they contain any nan or inf; the check always came back False.

I will investigate more by reading @stas00's PR in detail.

I think the most efficient solution is to find where the overflow comes from and force that particular operation to be done in fp32.

younesbelkada commented 1 year ago

@stas00 using your detect_overflow tool here, I am flagging an overflow starting from a certain layer (layer 6, because of inf values). (Btw, I also tried the autocast solution but it does not seem to work for inference; I have seen on the issue that it might work in some inference situations, but sadly not in this case :/ ) From there the hidden states constantly overflow and get clamped at each layer. I guess the accumulation of clamping at various stages introduces these errors.
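(For readers following along, a rough sketch of how detect_overflow might be wired up to locate the first offending encoder layer; the import path transformers.debug_utils is an assumption and may differ by version.)

import torch
from transformers.debug_utils import detect_overflow  # assumed location of the helper

def report_encoder_overflows(model, input_ids, decoder_input_ids):
    # Attach forward hooks that check each encoder block's output for inf/nan.
    handles = []
    for i, block in enumerate(model.encoder.block):
        def hook(module, inputs, output, idx=i):
            hidden = output[0] if isinstance(output, tuple) else output
            detect_overflow(hidden, f"encoder block {idx} output")
        handles.append(block.register_forward_hook(hook))
    with torch.no_grad():
        model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    for handle in handles:
        handle.remove()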

I am wondering if we can find a proper workaround for that; is clamping the right solution? My intuition is that the model gets confused at some point, since clamping the hidden states n times will yield completely different results than the bf16 hidden states.

Also, let's say you have flagged the layer responsible for the overflow. You can then do the operation there in fp32. But the under/overflow will still be present, since you need to cast the results back to fp16, right?
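(For readers wondering what the clamping above refers to: the T5 modeling code guards fp16 runs with a pattern roughly like the following; this is paraphrased, see modeling_t5.py in your version for the exact lines.)

import torch

def clamp_fp16_overflow(hidden_states: torch.Tensor) -> torch.Tensor:
    # If any activation has overflowed to inf under fp16, clamp everything
    # just below the fp16 maximum so the forward pass can continue.
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states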

stas00 commented 1 year ago

There was one more scaling hack posted here if you want to try it. https://github.com/huggingface/transformers/issues/14189#issuecomment-961571628

In general bf16-pretrained models ought to run under bf16 or fp32 regimes, as fp16 and bf16 are very incompatible dtypes. It's not as bad if you were to go from fp16 to bf16 as you'd only lose precision, and it'd only impact quality, but not the other way around (overflow).

So we should raise this question with the BNB and perhaps deepspeed-inference developers, at least to have an understanding of why both require fp16 and won't support bf16.

@TimDettmers, @RezaYazdaniAminabadi - is there a way for your libraries to work with bf16 dtype, so that the bf16-pretrained models won't overflow during inference? Thank you.

jimmy-marmalade commented 1 year ago

Thanks everyone for the discussion and work!

Are there any possible workarounds that I could implement as an end user?

RezaYazdaniAminabadi commented 1 year ago

There was one more scaling hack posted here if you want to try it. #14189 (comment)

In general bf16-pretrained models ought to run under bf16 or fp32 regimes, as fp16 and bf16 are very incompatible dtypes. It's not as bad if you were to go from fp16 to bf16 as you'd only lose precision, and it'd only impact quality, but not the other way around (overflow).

So we should raise this question with the BNB and perhaps deepspeed-inference developers, at least to have an understanding of why both require fp16 and won't support bf16.

@TimDettmers, @RezaYazdaniAminabadi - is there a way for your libraries to work with bf16 dtype, so that the bf16-pretrained models won't overflow during inference? Thank you.

The main reason the weights are converted to half precision on the DeepSpeed side is that the kernels only work with fp16 values. However, we are looking into some of these limitations and will resolve them soon. The other part is that we can quantize from the original bf16 checkpoint and resolve some of the overflow issues caused by the different data precision of fp16 vs bf16.

stas00 commented 1 year ago

Wonderful!

So it looks like @jimmy-marmalade can try out your solution (Deepspeed-Inference) once you have something working in bf16/int8, Reza, and hopefully this will unblock them.

jimmy-marmalade commented 1 year ago

Thanks @RezaYazdaniAminabadi, is there an issue I can watch to keep track of progress?

larsmennen commented 1 year ago

@younesbelkada (cc @thomwolf, who gave the inspiration for the workaround; cc @stas00, @sgugger):

@navjotts and I had a look at this and found a workaround.

As already concluded in the ticket above, the bf16->fp16 conversion is generally incompatible. We ran detect_overflow as well (great tool, thanks @stas00!) and found that we generally got overflows in the dense part of layer 7 in the encoder, specifically in the wo operation of T5DenseGatedActDense.

We implemented a hacky workaround to keep wo in fp32, cast its input to fp32, and then leave it in fp32 until after the T5LayerNorm. At the end of the norm we cast back to fp16. All fp16 linear modules (i.e. everything except wo) can then use the 8-bit quantization. The cast back to fp16 is not lossless of course, but we've generally found it to perform equivalently; we haven't spotted any difference in output so far.

We made 3 changes:

  1. T5DenseGatedActDense.forward:

    hidden_gelu = self.act(self.wi_0(hidden_states))
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.wo(
        hidden_states.to(torch.float32)
    )  # PATCH: Cast to float32, as self.wo is cast to float32 w/ patch 3 below
    return hidden_states
  2. T5LayerNorm.forward

    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    
    # convert into half-precision if necessary
    if self.weight.dtype in [torch.float16, torch.bfloat16]:
        hidden_states = hidden_states.to(self.weight.dtype)
    
    return (self.weight * hidden_states).to(
        torch.float16
    )  # PATCH: Cast back to float16 for compatibility w/ next layer. This is not lossless.
  3. _load_state_dict_into_meta_model

        if param_name.endswith(
                '.wo.weight'
        ):  # PATCH: For the wo weights of the dense layers, keep them in float32, others will get converted to float16 as this is a requirement for the LLM 8-bit quantization.
            param = param.to(torch.float32)  # PATCH
        else:  # PATCH
            # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
            # in int/uint/bool and not cast them.
            if dtype is not None and torch.is_floating_point(param):
                param = param.to(dtype)
            # For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
            if is_safetensors and dtype is None and torch.is_floating_point(param):
                param = param.to(torch.float32)

and then when instantiating the model from pretrained we set:

load_in_8bit_skip_modules=['decoder', 'lm_head', 'wo']
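(For context, a minimal sketch of what the full instantiation might look like under this patched setup; it assumes load_in_8bit_skip_modules is accepted as a from_pretrained keyword argument in the version being discussed, which may differ in other releases.)

from transformers import T5ForConditionalGeneration

# Sketch only: combines the standard 8-bit loading call with the skip list described above.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xxl",
    device_map="auto",
    load_in_8bit=True,
    load_in_8bit_skip_modules=["decoder", "lm_head", "wo"],
)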

I'm not sure what a good way would be to get this into transformers, though, or whether that would even be a good idea given how hacky it is; curious for your thoughts. For patch 3, if we could add an option to specify an exclude_list for the conversion to float16, that would remove the need for that patch. Then the layers could be adapted at the model level.

younesbelkada commented 1 year ago

Hi @larsmennen (and cc @thomwolf), thanks for this great investigation and amazing fix; I also believe this approach is the best fix so far for this problem. Thanks so much for working on this, as it will enable using these models in a more accessible way. I see 2 workarounds for that:

1- The fix is applied for 8-bit models only. In this case, I think we can simply have an additional flag load_in_8bit_fp32_modules = ["wo"] and apply a patch similar to your point 3. For points 1 and 2 we can probably have a hotfix as you suggested, but I would wait to hear what @sgugger and/or @stas00 think about that.

2- We add an additional flag regardless of whether the model is loaded in 8-bit or not, since this could fix the issue with T5 in fp16 too: a flag similar to the above, keep_in_fp32_modules=["wo"], that is activated only for half-precision models (and 8-bit models too). But again we'll probably need the hotfixes from 1&2.
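(Purely as an illustration of proposal 2-, usage might look like the sketch below; keep_in_fp32_modules is a proposed name in this discussion, not an existing from_pretrained argument at the time of writing.)

import torch
from transformers import T5ForConditionalGeneration

# Hypothetical: load in fp16 while forcing the wo modules to stay in fp32.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xxl",
    device_map="auto",
    torch_dtype=torch.float16,
    keep_in_fp32_modules=["wo"],  # proposed flag, not yet implemented here
)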

Few questions:

sgugger commented 1 year ago
1. and 2. are easy patches to integrate; I don't anticipate any difficulty getting them merged as is. For 3 we need a solution that is not too ad hoc in the code of modeling_utils. I like the proposed 2- in @younesbelkada's comment, since I believe this should also fix the T5 evaluation problem in FP16.

Thanks a lot for the clear explanations and the deep dive!

stas00 commented 1 year ago

what about users on bf16 hardware that don't want to waste memory casting to fp32 since the modeling code works just fine when using bf16 mixed precision?

I think if this is done it should only be done for fp16 mixed precision.


Also please note that we automatically use apex's faster layernorm when it's found, so the T5LayerNorm.forward workaround will apply only if it's not found; i.e., you may want to disable the swap-in of the faster version.

sgugger commented 1 year ago

I think if this is done it should only be done for fp16 mixed precision.

Yes indeed that's a very good point! (cc @younesbelkada)

stas00 commented 1 year ago

Also, what about users pre-training their own model in fp16? The proposed change will negatively impact them as well, as the current modeling code should work just fine for them.

IMHO, the safest approach would be to leave the current code alone and have a flag that activates workaround solutions for those who need them.

Additionally, I remind you that other workarounds were proposed that don't use any additional memory and instead use a scaling factor that moves the weights into a safe-to-fp16 numerical range. https://github.com/huggingface/transformers/pull/10956#issuecomment-961030728
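(As an illustration only, a hypothetical sketch of such a scaling workaround for a bias-free linear layer like T5's wo, not the exact patch from the linked discussion: store the weights divided by a constant so the cast to fp16 stays in range, and multiply the output back by the same constant.)

import torch

def rescale_linear_for_fp16(linear: torch.nn.Linear, scale: float = 8.0) -> None:
    # Hypothetical sketch: shrink the stored weights by `scale` so they fit
    # comfortably in fp16, and compensate by scaling the output back up.
    # Only mathematically equivalent for bias-free layers (T5's wo has no bias).
    assert linear.bias is None
    with torch.no_grad():
        linear.weight.div_(scale)
    original_forward = linear.forward
    linear.forward = lambda x: original_forward(x) * scale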

sgugger commented 1 year ago

Also, what about users pre-training their own model in fp16? The proposed change will negatively impact them as well, as the current modeling code should work just fine for them.

The main goal of having T5 in the library is to support the corresponding pretrained models as well as possible. All of the T5, FlanT5 and T0 checkpoints have been trained in bfloat16, so changing the code to support fp16 inference is a net win for the larger community. If this slows down an edge case, the user can just adapt the line of code in the modeling file to suit their needs (that's why the library is not modular and has a strict one-file-per-model policy, after all :-) ).

stas00 commented 1 year ago

One could argue that this breaks backward compatibility since suddenly the model requires more memory to operate than when it was originally released.

If the belief is that the majority X% of users will benefit from such a BC-breaking change, I think it'd at least be nice to have a flag for the small Y% to retain what they have been using without needing to manually change the code.

It might be much simpler to clone this to models/t5-bf162fp16, apply all the proposed patches, and thus have 2 versions: one for normal use and one for the originally unintended bf162fp16 use.

larsmennen commented 1 year ago

Thanks for the quick replies all!

@younesbelkada to answer your questions:

When using load_in_8bit_skip_modules=['decoder', 'lm_head', 'wo'], note that with your fix decoder and lm_head will be kept in their native dtype. In the case you are calling load_in_8bit=True, we first cast all the weights to fp16, therefore load_in_8bit_skip_modules=['decoder', 'lm_head', 'wo'] is probably not needed, as wo is "force-cast" to fp32 and lm_head is always detected and kept in its original precision. Can you double check that? 🙏

Regarding the need for wo: if we don't pass it in, it is not excluded from the conversion of the linear layers to 8-bit, and an autocast is applied:

/root/venv/lib/python3.7/site-packages/bitsandbytes/autograd/_functions.py:231: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

and the resulting model is:

          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
          --> (wo): Linear8bitLt(in_features=10240, out_features=4096, bias=False) <--
              (dropout): Dropout(p=0.1, inplace=False)
              (act): NewGELUActivation()
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )

So that one is required.

For decoder and lm_head, I included those because of this line:

https://github.com/huggingface/transformers/blob/9e56aff58a742b48fc8edea8d28d5b80330efbcc/src/transformers/modeling_utils.py#L2319

For this model get_keys_to_not_convert returns ['decoder', 'lm_head']. So I didn't want to change this behavior.

Note that the decoder entry doesn't actually seem to do anything, because in replace_8bit_linear:

https://github.com/huggingface/transformers/blob/9e56aff58a742b48fc8edea8d28d5b80330efbcc/src/transformers/utils/bitsandbytes.py#L113-L126

this actually only checks the last part of the module name (e.g. wo), and decoder itself is not a linear layer. I'm not sure whether this behavior is intended, or whether this is a separate bug and replace_8bit_linear should check the full module name.
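(A small illustration, not the library code itself, of the distinction being described: matching only the last component of the module name versus matching against the full dotted path.)

modules_to_not_convert = ["decoder", "lm_head", "wo"]
name = "decoder.block.0.layer.1.DenseReluDense.wi_0"

# Matching only the last component: "wi_0" is not in the skip list, so this
# submodule would still be converted even though it lives under "decoder".
skipped_by_last_component = name.split(".")[-1] in modules_to_not_convert  # False

# Matching against the full dotted path: "decoder" appears in the path, so the
# submodule would be skipped.
skipped_by_full_path = any(key in name for key in modules_to_not_convert)  # True

print(skipped_by_last_component, skipped_by_full_path)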

Does the same fix apply to T5-XL?

Yes. I ran the same test internally; can confirm fp32 quality == 8bit-with-fix quality != 8bit-without-fix quality for XL.

Thanks!

younesbelkada commented 1 year ago

Hi @larsmennen, thanks so much for your detailed answer; everything is clear on my side now. Regarding your point about get_keys_to_not_convert, it is probably a bug; let's fix this in a separate PR later.

#20683 is in good shape IMO. Can you check out that branch, apply the patches mentioned in 1&2, and let us know if it works as expected? 🙏

larsmennen commented 1 year ago

@younesbelkada moving back to this thread:

Would you mind opening a PR addressing your suggestions (patch 1 & 2 from the discussion at https://github.com/huggingface/transformers/issues/20287 )?

Yes, happy to. Will get that in today or tomorrow.