Closed jimmy-marmalade closed 1 year ago
cc @younesbelkada
Hi @jimmy-marmalade
Thanks a lot for raising this point. Note that int8 quantization is done in 2 stages: it first converts the model to float16 and then uses the fp16 model to quantize it to 8bit. If you try to load and run the model in fp16, you also get gibberish output:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("./flan-t5-xxl")
model = T5ForConditionalGeneration.from_pretrained("./flan-t5-xxl", torch_dtype=torch.float16, device_map="auto")
input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
outputs = model.generate(input_ids, max_length=512)
print(tokenizer.decode(outputs[0]))
>>> <pad> How old are die Sie? Ihr Mutter?tat?ztlich, rezult Interesse: restriction = = = = ...
I suspect there is something wrong with the bf16 to fp16 conversion for this specific model, and for the xl model too.
@stas00 do you have any intuition on why the int8 conversion (so the underlying fp16 conversion) worked well for bloom-176 and not here? 🙏
Thanks!
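To make the two stages concrete, here is a minimal pure-Python sketch: stage 1 rounds the weights to fp16, stage 2 applies a simple absmax int8 quantization to one row. The helper names (`to_fp16`, `quantize_absmax`, `dequantize`) are made up for illustration, and this is a simplified stand-in, not the bitsandbytes kernels.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to float16; out-of-range values become +/-inf."""
    try:
        return struct.unpack(">e", struct.pack(">e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def quantize_absmax(row):
    """Map a row of floats to int8 codes using the row's absolute maximum as scale."""
    absmax = max(abs(v) for v in row)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    return [round(v / scale) for v in row], scale


def dequantize(codes, scale):
    """Recover approximate float values from int8 codes."""
    return [c * scale for c in codes]


weights = [0.5, -1.5, 0.25, 2.0]

# Stage 1: cast to fp16 (these toy values are exactly representable).
fp16_weights = [to_fp16(w) for w in weights]

# Stage 2: quantize the fp16 values to int8.
codes, scale = quantize_absmax(fp16_weights)
restored = dequantize(codes, scale)

print(codes)
print(restored)
```

Anything that goes wrong in stage 1 is inherited by stage 2, which is why the fp16 run above is a useful diagnostic.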
Did the bf16 model weights have large values resulting in overflow when used under fp16? bf16 to fp16 conversion is a huge issue with almost every model we have seen, e.g. all large T5 and derivative models. Seeing that this is a T5 derivative it's almost certainly related. You can see the rich discussion here: https://github.com/huggingface/transformers/pull/10956 and possible workarounds to try.
Probably talk to @TimDettmers and ask if perhaps there could be bf16/int8 variation for bf16 models in BNB?
T5 doesn't work in FP16 because the softmaxes in the attention layers are not upcast to float32. @younesbelkada if you remember the fixes done in BLOOM/OPT I suspect similar ones would fix inference in FP16 for T5 :-)
but why are we even going through FP16 to do quantization of a bf16 model? Why can't this be done directly in the original dtype?
note: interestingly deepspeed-inference also converts the model to fp16 to do quantization.
Thank you very much for all the pointers
> T5 doesn't work in FP16 because the softmaxes in the attention layers are not upcast to float32. @younesbelkada if you remember the fixes done in BLOOM/OPT I suspect similar ones would fix inference in FP16 for T5 :-)
I think that T5 already upcasts the softmax to fp32. I suspected that the overflow might come from the addition of the positional bias in the line before, but changing that did not help. I also tried to upcast the lm_logits and the hidden_states before the lm_head to fp32, but that did not help either.
In addition, I also printed the hidden states at every stage, checking whether they contain any nan or inf. The check always returned False.
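A check like this can be automated with forward hooks. The sketch below is a generic illustration on a toy fp32 model (the helper `attach_nonfinite_hooks` and the toy weights are assumptions for the example, not code from this thread); it records the name of every submodule whose output contains nan or inf.

```python
import torch
from torch import nn


def attach_nonfinite_hooks(model: nn.Module, report: list) -> list:
    """Register hooks that append a submodule's name to `report` when its output has nan/inf."""

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                report.append(name)

        return hook

    return [
        module.register_forward_hook(make_hook(name))
        for name, module in model.named_modules()
        if name  # skip the root module, whose name is ""
    ]


# Toy model: two linear layers whose weights are large enough to overflow fp32.
toy = nn.Sequential(nn.Linear(1, 1, bias=False), nn.Linear(1, 1, bias=False))
with torch.no_grad():
    for layer in toy:
        layer.weight.fill_(1.0e20)

overflowing_modules = []
handles = attach_nonfinite_hooks(toy, overflowing_modules)
with torch.no_grad():
    toy(torch.tensor([[1.0e10]]))  # layer "0" gives 1e30, layer "1" overflows to inf
for handle in handles:
    handle.remove()

print(overflowing_modules)
```

Hooks only see module outputs, so an overflow that is clamped inside a module before it returns would not be reported by this kind of check.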
I will investigate more by reading @stas00's PR in detail.
I think the most efficient solution is to find where the overflow comes from and force that operation to be done in fp32.
@stas00 I used your tool detect_overflow here, and it flags an overflow starting from a certain layer (from layer 6, because of inf). (By the way, I also tried the autocast solution, but it does not seem to work for inference. I have seen on the issue that it might work in some situations for inference, but sadly not in this case :/ )
Then the hidden states get constantly overflowed and clamped at each layer. I guess that the accumulation of clamping at various stages introduces these errors.
I am wondering if we can find a proper workaround for that. Is clamping the right solution? My intuition is that the model gets confused at some point, since clamping the hidden states n times will yield completely different results than the bf16 hidden states.
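A toy calculation shows how repeated clamping can drift away from the unclamped result. The per-layer scale factors below are invented purely for illustration; only the fp16 maximum (65504) is a real constant.

```python
FP16_MAX = 65504.0  # largest finite float16 value


def clamp_to_fp16_range(x: float) -> float:
    """Clamp a value into the finite float16 range."""
    return max(-FP16_MAX, min(FP16_MAX, x))


# Hypothetical per-layer gains: two layers grow the activation, two shrink it back.
layer_scales = [8.0, 8.0, 0.125, 0.125]


def run_layers(x: float, clamp_each_layer: bool) -> float:
    """Apply every layer in turn, optionally clamping after each one."""
    for scale in layer_scales:
        x = x * scale
        if clamp_each_layer:
            x = clamp_to_fp16_range(x)
    return x


reference = run_layers(20000.0, clamp_each_layer=False)
clamped = run_layers(20000.0, clamp_each_layer=True)

print(reference, clamped)
```

The wide-range path returns to the starting value, while the clamped path loses the magnitude it was not allowed to carry and ends up far away.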
Also, let's say you have flagged the layer responsible for the overflow. You can then do the operation there in fp32. But the under/overflow will still be present since you need to cast the results back to fp16, right?
There was one more scaling hack posted here if you want to try it. https://github.com/huggingface/transformers/issues/14189#issuecomment-961571628
In general bf16-pretrained models ought to run under bf16 or fp32 regimes, as fp16 and bf16 are very incompatible dtypes. It's not as bad if you were to go from fp16 to bf16 as you'd only lose precision, and it'd only impact quality, but not the other way around (overflow).
So we should raise this question with the BNB and perhaps deepspeed-inference developers, at least to have an understanding of why both require fp16 and won't support bf16.
@TimDettmers, @RezaYazdaniAminabadi - is there a way for your libraries to work with bf16 dtype, so that the bf16-pretrained models won't overflow during inference? Thank you.
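The asymmetry between the two conversions can be sketched in pure Python. Here bf16 is simulated by truncating the float32 bit pattern to its top 16 bits and fp16 by the `struct` half-precision format; both helpers are illustrative assumptions, not how any of the libraries mentioned here perform the cast.

```python
import struct


def to_bf16(x: float) -> float:
    """Simulate bfloat16 by keeping the top 16 bits of the float32 encoding (truncation)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


def to_fp16(x: float) -> float:
    """Round to float16; values beyond the finite range become +/-inf."""
    try:
        return struct.unpack(">e", struct.pack(">e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


# fp16 -> bf16: the value stays finite but loses mantissa bits (precision only).
fp16_value = to_fp16(1.0 + 2.0 ** -10)
after_bf16 = to_bf16(fp16_value)

# bf16 -> fp16: a value that bf16 holds comfortably exceeds the fp16 range (overflow).
bf16_value = to_bf16(1.0e5)
after_fp16 = to_fp16(bf16_value)

print(fp16_value, after_bf16)
print(bf16_value, after_fp16)
```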
Thanks everyone for the discussion and work!
Are there any possible workarounds that I could implement as an end user?
The main reason the weights are converted to half on DeepSpeed-side is that the kernels are only working with fp16 values. However, we are looking into some of these limitations and will resolve them soon. The other part is that we can quantize from the original bf16 checkpoint and resolve some of the overflow issue due to different data-precision of fp16 vs bf16.
Wonderful!
So it looks like @jimmy-marmalade can try out your solution (Deepspeed-Inference) once you have something working in bf16/int8, Reza, and hopefully this will unblock them.
Thanks @RezaYazdaniAminabadi, is there an issue I can watch to keep track of progress?
@younesbelkada (cc @thomwolf who gave the inspiration for the workaround; cc @stas00, @sgugger):
@navjotts and I had a look at this and found a workaround.
As already concluded in the ticket above, the bf16->fp16 conversion is generally incompatible. We ran detect_overflow as well (great tool, thanks @stas00!) and found that we generally got overflows in the dense part of layer 7 in the encoder, specifically in the wo operation of T5DenseGatedActDense.
We implemented a hacky workaround to keep wo in fp32: cast its input to fp32 and then leave the result in fp32 until after the T5LayerNorm. At the end of the norm we cast back to fp16. All fp16 linear modules (i.e. everything except the wo) can then use the 8-bit quantization. The cast back to fp16 is not lossless of course, but we've generally found it to perform equivalently. We haven't spotted any difference in output so far.
We made 3 changes:
1. T5DenseGatedActDense.forward:
    hidden_gelu = self.act(self.wi_0(hidden_states))
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.dropout(hidden_states)
    # PATCH: Cast to float32, as self.wo is cast to float32 w/ patch 3 below
    hidden_states = self.wo(hidden_states.to(torch.float32))
    return hidden_states
2. T5LayerNorm.forward:
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    # convert into half-precision if necessary
    if self.weight.dtype in [torch.float16, torch.bfloat16]:
        hidden_states = hidden_states.to(self.weight.dtype)
    # PATCH: Cast back to float16 for compatibility w/ next layer. This is not lossless.
    return (self.weight * hidden_states).to(torch.float16)
3. _load_state_dict_into_meta_model:
    # PATCH: For the wo weights of the dense layers, keep them in float32, others will get
    # converted to float16 as this is a requirement for the LLM 8-bit quantization.
    if param_name.endswith('.wo.weight'):
        param = param.to(torch.float32)  # PATCH
    else:  # PATCH
        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
        # in int/uint/bool and not cast them.
        if dtype is not None and torch.is_floating_point(param):
            param = param.to(dtype)
        # For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
        if is_safetensors and dtype is None and torch.is_floating_point(param):
            param = param.to(torch.float32)
and then when instantiating the model from pretrained we set:
load_in_8bit_skip_modules=['decoder', 'lm_head', 'wo']
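The idea behind the patches can be reproduced on a toy tensor. The sketch below is only an illustration under made-up weights and shapes, not the patched transformers code: a linear layer kept in fp32 produces values beyond the fp16 range, casting them to fp16 right away gives inf, while normalizing first (RMS-style, as in T5LayerNorm) and casting afterwards stays finite.

```python
import torch
from torch import nn

# Stand-in for the `wo` projection, kept in float32 (weights are invented for the demo).
wo = nn.Linear(4, 4, bias=False)
with torch.no_grad():
    wo.weight.fill_(1.0e3)


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMS-style normalization computed in float32, without a learned weight."""
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps)


hidden_fp16 = torch.full((1, 4), 100.0, dtype=torch.float16)

with torch.no_grad():
    # Cast the input up, run the projection in float32: each output is 4 * 100 * 1000.
    out_fp32 = wo(hidden_fp16.to(torch.float32))

# Casting back immediately overflows, since 400000 exceeds the fp16 maximum of 65504.
naive = out_fp32.to(torch.float16)

# Staying in float32 until after the norm brings the values back into range first.
patched = rms_norm(out_fp32).to(torch.float16)

print(naive)
print(patched)
```

This also shows why the cast back to fp16 is placed at the end of the norm rather than right after the fp32 operation.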
I'm not sure what a good way would be to get this into transformers though / if that would even be a good idea given this is quite hacky, curious for your thoughts. For patch 3, if we could add an option to specify an exclude_list for the conversion to float16, that would remove the need for that patch. Then the layers can be adapted at model-level.
Hi @larsmennen (and cc @thomwolf) Thanks for this great investigation and amazing fix - I also believe that this approach is the best fix so far for this problem. Thanks so much for working on this, as it will enable using these models in a more accessible way. I see 2 possible ways to do that:
1- The fix is applied for 8bit models only. In this case, I think that we can perfectly well have an additional flag load_in_8bit_fp32_modules = ["wo"] and apply a patch similar to your point 3. For points 2 and 1 we can probably have a hot fix as you suggested, but I would wait for @sgugger and/or @stas00 to hear what they think about that.
2- We add an additional flag, regardless of whether the model is loaded in 8bit or not, since this could fix the issue with T5-fp16 too, with a flag similar to the one above, keep_in_fp32_modules=["wo"], that is activated only for half-precision models (and 8bit models too). But again we'll probably need the hotfixes from 1 & 2.
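As a rough sketch of what such a flag could do, the helper below casts every floating-point parameter of a toy module to fp16 except those that live under a module named in the keep list. The function `cast_with_fp32_exceptions` and the class `ToyFeedForward` are hypothetical names for illustration; this is not an existing transformers API.

```python
import torch
from torch import nn


class ToyFeedForward(nn.Module):
    """Toy stand-in with the same submodule names as a T5 feed-forward block."""

    def __init__(self):
        super().__init__()
        self.wi = nn.Linear(4, 8, bias=False)
        self.wo = nn.Linear(8, 4, bias=False)


def cast_with_fp32_exceptions(model, dtype, keep_in_fp32_modules):
    """Cast floating-point params to `dtype`, keeping listed submodules in float32."""
    for name, param in model.named_parameters():
        if not torch.is_floating_point(param):
            continue  # leave int/bool tensors untouched
        keep = any(part in keep_in_fp32_modules for part in name.split("."))
        param.data = param.data.to(torch.float32 if keep else dtype)
    return model


model = cast_with_fp32_exceptions(ToyFeedForward(), torch.float16, ["wo"])
print({name: p.dtype for name, p in model.named_parameters()})
```

The forward pass would still need the input cast from patch 1, since an fp32 `wo` cannot consume fp16 activations directly.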
A few questions:
When using load_in_8bit_skip_modules=['decoder', 'lm_head', 'wo'], note that with your fix decoder and lm_head will be kept in their native dtype. If you are calling load_in_8bit=True, we first cast all the weights to fp16, therefore load_in_8bit_skip_modules=['decoder', 'lm_head', 'wo'] is probably not needed, as wo is "force-cast" to fp32 and lm_head is always detected to be kept in its original precision. Can you double check that? 🙏 Thanks a lot for the clear explanations and the deep dive!
what about users on bf16 hardware that don't want to waste memory casting to fp32 since the modeling code works just fine when using bf16 mixed precision?
I think if this is done it should only be done for fp16 mixed precision.
Also please note that we automatically use apex's faster layernorm when it's found, so the T5LayerNorm.forward workaround will apply only if it's not found, i.e. you may want to disable the swap-in of the faster version.
> I think if this is done it should only be done for fp16 mixed precision.
Yes indeed that's a very good point! (cc @younesbelkada)
Also what about users pre-training their own model in fp16, the proposed change will negatively impact them as well, as the current modeling code should work just fine for them.
IMHO, the safest approach would be to leave the current code alone and have a flag that activates workaround solutions for those who need them.
Additionally, I remind you that there were other workarounds proposed that don't use any additional memory and use a scaling factor instead that moves the weights into a safe-to-fp16 numerical range. https://github.com/huggingface/transformers/pull/10956#issuecomment-961030728
> Also what about users pre-training their own model in fp16, the proposed change will negatively impact them as well, as the current modeling code should work just fine for them.
The main goal of having T5 in the library is to support the corresponding pretrained models as best as possible. All of T5, FlanT5 and T0 checkpoints have been trained in bfloat16, so changing the code to support fp16 inference is for the better for the larger community. If this slows down an edge case, the user can just adapt the line of code in the modeling file to suit their need (that's why the library is not modular and with a strict one file per model policy after all :-) ).
One could argue that this breaks backward compatibility since suddenly the model requires more memory to operate than when it was originally released.
If the belief is that the majority X% of users will benefit from such BC breaking change I think it'd at least be nice to have a flag for the small Y% to be able to retain what they have been using w/o needing to manually change the code.
Might be much simpler to clone this to models/t5-bf162fp16, apply all the proposed patches and thus have 2 versions - one for normal use and one for the originally unintended bf162fp16 use.
Thanks for the quick replies all!
@younesbelkada to answer your questions:
> When using load_in_8bit_skip_modules=['decoder', 'lm_head', 'wo'] note that with your fix decoder and lm_head will be kept to their native dtype. In the case you are calling load_in_8bit=True, we first cast all the weights in fp16 therefore load_in_8bit_skip_modules=['decoder', 'lm_head', 'wo'] is probably not needed as wo is "force-casted" in fp32 and lm_head is always detected to be kept in its original precision. Can you double check that? 🙏
Regarding the need for wo: if we don't pass it in, then it is not excluded from the conversion of the linear layers to 8bit, and an autocast is applied:
/root/venv/lib/python3.7/site-packages/bitsandbytes/autograd/_functions.py:231: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
and then resulting model:
(1): T5LayerFF(
  (DenseReluDense): T5DenseGatedActDense(
    (wi_0): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
    (wi_1): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
--> (wo): Linear8bitLt(in_features=10240, out_features=4096, bias=False) <--
    (dropout): Dropout(p=0.1, inplace=False)
    (act): NewGELUActivation()
  )
  (layer_norm): T5LayerNorm()
  (dropout): Dropout(p=0.1, inplace=False)
)
)
)
So that one is required.
For decoder and lm_head, I included those because of this line:
For this model get_keys_to_not_convert returns ['decoder', 'lm_head']. So I didn't want to change this behavior.
Note that the decoder entry doesn't actually seem to do anything, because in replace_8bit_linear:
this actually only checks the last part of the module name (e.g. wo), but decoder itself is not a linear layer. Not sure if this behavior is intended, or whether this is a separate bug and replace_8bit_linear should check the full module name?
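The difference between the two matching rules can be sketched in plain Python. Both helper functions and the example module paths are illustrative assumptions modeled on the description above, not the actual replace_8bit_linear code.

```python
def skipped_by_leaf_name(module_name: str, skip: list) -> bool:
    """Skip only when the last component of the dotted module name is in the list."""
    return module_name.split(".")[-1] in skip


def skipped_by_any_component(module_name: str, skip: list) -> bool:
    """Skip when any component of the dotted module name is in the list."""
    return any(part in skip for part in module_name.split("."))


skip_modules = ["decoder", "lm_head", "wo"]

# Example dotted paths of linear layers (assumed for illustration).
linear_names = [
    "encoder.block.0.layer.1.DenseReluDense.wo",
    "decoder.block.0.layer.2.DenseReluDense.wi_0",
    "lm_head",
]

leaf_rule = {name: skipped_by_leaf_name(name, skip_modules) for name in linear_names}
full_rule = {name: skipped_by_any_component(name, skip_modules) for name in linear_names}

for name in linear_names:
    print(name, leaf_rule[name], full_rule[name])
```

Under the leaf-name rule the 'decoder' entry never matches a linear layer, which is consistent with the observation that it has no effect; a full-name rule would instead exclude every linear layer inside the decoder.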
Does the same fix apply to T5-XL?
Yes. I ran the same test internally; can confirm fp32 quality == 8bit-with-fix quality != 8bit-without-fix quality for XL.
Thanks!
hi @larsmennen
Thanks so much for your detailed answer, everything is clear on my side now. Regarding your point about get_keys_to_not_convert, it is probably a bug; let's fix this in a separate PR later.
@younesbelkada moving back to this thread:
Would you mind opening a PR addressing your suggestions (patch 1 & 2 from the discussion at https://github.com/huggingface/transformers/issues/20287 )?
Yes, happy to. will get that in today or tomorrow.
System Info
transformers version: 4.25.0.dev0
Who can help?
@patrickvonplaten
Reproduction
Running the English to German example:
produces expected output:
Loading in 8-bit and running:
results in output more nonsensical than I'd expect:
Expected behavior
I expected the 8-bit output to be close to the original output. This was the provided INT8 code snippet, so I expected the output to be sensible for the task.