huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

torch.arange use should not use dtype=float for integer ranges, conflicts w/ DS `zero.Init()` #28685

Closed rwightman closed 4 months ago

rwightman commented 7 months ago

System Info

Affects many versions of transformers, up to and including the current one.

Who can help?

@ArthurZucker @amyeroberts

Reproduction

Use any of the transformers models that call arange to enumerate integer positions when computing position embeddings, together with DeepSpeed zero.Init() and a low-precision dtype (float16, bfloat16). The generated embeddings will differ significantly from the intended values.

Using Llama as an example:

    t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

Here inv_freq.dtype == float32. Single-precision float can represent every integer in the required enumeration range exactly (I believe it's in the 2k-8k range for Llama?).

However, when DeepSpeed zero.Init() is used, its patching of the tensor-creation functions overrides the float dtype passed in with a low-precision float dtype, i.e. float32 -> bfloat16 or float16. The range of integers that can be represented without loss then drops to 256 for bfloat16 (8-bit significand) or 2048 for float16 (11-bit significand). DeepSpeed's patching makes an exception for integer dtypes: it will not cast the output of arange to the low-precision float dtype if the arange dtype is an int type.
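
To make those limits concrete, here is a small stdlib-only sketch (no PyTorch or DeepSpeed needed) that rounds integer positions to float16 and bfloat16 and finds the first integer each format can no longer hold exactly. The helper names `to_fp16`, `to_bf16`, `to_fp32` and `first_inexact_integer` are illustrative only; bfloat16 is emulated here by round-to-nearest-even on the upper 16 bits of a float32, with NaN/overflow handling omitted.

```python
import struct

def to_fp16(x: float) -> float:
    # struct's 'e' format is IEEE 754 binary16 (float16), rounded to nearest-even.
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    # Emulated bfloat16: keep the upper 16 bits of the float32 bit pattern,
    # rounding to nearest-even on the 16 bits that are dropped.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def first_inexact_integer(round_fn, limit: int = 1 << 16) -> int:
    # Smallest non-negative integer that does not survive the round trip.
    for n in range(limit):
        if round_fn(float(n)) != n:
            return n
    raise ValueError("no inexact integer below limit")

print(first_inexact_integer(to_bf16))  # 257 (position 257 collapses onto 256)
print(first_inexact_integer(to_fp16))  # 2049
print(to_fp32(2.0**24 + 1))            # 16777216.0, far beyond any context length

# Distinct values left when 4096 positions are stored in (emulated) bfloat16.
positions = {to_bf16(float(n)) for n in range(4096)}
print(len(positions))
```

Many distinct positions end up sharing the same bfloat16 value, so the corresponding rotary angles are identical for tokens that should be far apart.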

https://github.com/microsoft/DeepSpeed/blob/0dd0c615f8e6c7947ba81a4b0993284da5ec3209/deepspeed/runtime/zero/partition_parameters.py#L245-L246

def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:

    def wrapped_fn(*args, **kwargs) -> Tensor:
        if kwargs.get("device", None) is None:
            kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        tensor: Tensor = fn(*args, **kwargs)
        if tensor.is_floating_point():
            tensor.data = tensor.data.to(target_fp_dtype)

        return tensor

    return wrapped_fn

torch.arange already infers an integer dtype (int64) when start/end/step are all ints. In this case, though, it's best to be explicit to make the intent clear: we should explicitly set dtype=torch.long (torch.int64 is the same dtype, pick whichever you prefer). Casting to float should be done after the arange. Additionally, in many position-embedding calculations it's best to keep the math in float32 as long as possible and convert to the low-precision type only at the very end (if that's the dtype used for inference or training).
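
The effect of the wrapper quoted above, and why an integer dtype sidesteps it, can be reproduced without DeepSpeed. This sketch assumes only PyTorch; `cast_float_outputs` is a hypothetical, simplified stand-in for `zero_wrapper_for_fp_tensor_constructor` that keeps the dtype override and drops the device override.

```python
import torch

def cast_float_outputs(fn, target_fp_dtype):
    # Hypothetical stand-in for DeepSpeed's zero_wrapper_for_fp_tensor_constructor:
    # only the dtype override is reproduced; the device override is left out.
    def wrapped_fn(*args, **kwargs):
        tensor = fn(*args, **kwargs)
        if tensor.is_floating_point():
            tensor = tensor.to(target_fp_dtype)
        return tensor
    return wrapped_fn

patched_arange = cast_float_outputs(torch.arange, torch.bfloat16)

# Current pattern: a float dtype is requested, so the output is silently downcast.
t_float = patched_arange(4096, dtype=torch.float32)
# Proposed pattern: integer dtype passes through untouched, cast to float afterwards.
t_int = patched_arange(4096, dtype=torch.long).float()

print(t_float.dtype, t_float[257].item())  # torch.bfloat16 256.0
print(t_int.dtype, t_int[257].item())      # torch.float32 257.0
```

`.float()` is a tensor method rather than a constructor, so in this sketch it is not affected by the patch.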

Expected behavior

Use of torch.arange should explicitly set dtype=torch.long (or int64).

For example, for Llama:

    t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.long).type_as(self.inv_freq)

rwightman commented 7 months ago

Code instances where this is either definitely a concern or likely to be one (depending on the ranges involved):
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/llama/modeling_llama.py#L130-L131
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/llama/modeling_llama.py#L140
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/llama/modeling_llama.py#L168
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/llama/modeling_llama.py#L195
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/examples/research_projects/bertabs/modeling_bertabs.py#L265-L266
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/codegen/modeling_codegen.py#L55-L59
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/conditional_detr/modeling_conditional_detr.py#L437-L455
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/conditional_detr/modeling_conditional_detr.py#L496-L509
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/ctrl/modeling_ctrl.py#L47-L60
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/deformable_detr/modeling_deformable_detr.py#L494-L495
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/deformable_detr/modeling_deformable_detr.py#L620-L621
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/deformable_detr/modeling_deformable_detr.py#L1542-L1543
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py#L945-L946
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/deta/modeling_deta.py#L404-L405
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/deta/modeling_deta.py#L529-L530
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/deta/modeling_deta.py#L1453-L1454
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/detr/modeling_detr.py#L438-L439
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/falcon/modeling_falcon.py#L151-L152
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/falcon/modeling_falcon.py#L180
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/falcon/modeling_falcon.py#L208
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py#L823-L826
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/fsmt/modeling_fsmt.py#L1349-L1351
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/funnel/modeling_funnel.py#L235-L267
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/fuyu/image_processing_fuyu.py#L687-L690
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L547
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L576
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L604
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py#L255
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/gptj/modeling_gptj.py#L60
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/idefics/modeling_idefics.py#L480
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/kosmos2/modeling_kosmos2.py#L776-L780
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/m2m_100/modeling_m2m_100.py#L114-
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/mask2former/modeling_mask2former.py#L863-L864
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/mask2former/modeling_mask2former.py#L2132-L2133
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/maskformer/modeling_maskformer.py#L1354-L1355
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/mega/modeling_mega.py#L172-L174
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/mistral/modeling_mistral.py#L109-L110
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/mixtral/modeling_mixtral.py#L202-L203
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/mpt/modeling_mpt.py#L69-L70
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/musicgen/modeling_musicgen.py#L129-L131
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/nezha/modeling_nezha.py#L153-L155
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/nllb_moe/modeling_nllb_moe.py#L167-L169
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/oneformer/modeling_oneformer.py#L2803-L2804
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/pegasus_x/modeling_pegasus_x.py#L112-L113
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/persimmon/modeling_persimmon.py#L61-L62
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/persimmon/modeling_persimmon.py#L90-L91
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/persimmon/modeling_persimmon.py#L118-L119
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/phi/modeling_phi.py#L99
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/phi/modeling_phi.py#L128
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/phi/modeling_phi.py#L156
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/qwen2/modeling_qwen2.py#L116
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py#L417
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py#L1024-L1026
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py#L980-L982
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/speech_to_text/modeling_speech_to_text.py#L133-L136
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/speecht5/modeling_speecht5.py#L316-L318
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/speecht5/modeling_speecht5.py#L406-L407
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/swin2sr/modeling_swin2sr.py#L293-L296
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/swinv2/modeling_swinv2.py#L449-L451
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/table_transformer/modeling_table_transformer.py#L374-L375
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/trocr/modeling_trocr.py#L88-L90
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py#L315
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py#L445-L448
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/xglm/modeling_xglm.py#L160-L162
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/src/transformers/models/xlnet/modeling_xlnet.py#L1023-L1038
https://github.com/huggingface/transformers/blob/8278b1538ecc89dad8ebca510a31a86bc8645edb/examples/research_projects/visual_bert/modeling_frcnn.py#L171-L192

rwightman commented 7 months ago

If we look at the original Llama code, this issue is avoided.

https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/model.py#L100-L104

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

GPT-NeoX also avoids it, but the transformers port is affected (see above).

https://github.com/EleutherAI/gpt-neox/blob/63991555ec082c8f80c475f851d008193b10008c/megatron/model/positional_embeddings.py#L27-L34

        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
        sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
        if self.precision == torch.bfloat16:
            sinusoid_inp = sinusoid_inp.float()
        sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
        if self.precision == torch.bfloat16:
            sin, cos = sin.bfloat16(), cos.bfloat16()
        emb = torch.cat((sin, cos), dim=-1)
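
To see why the GPT-NeoX pattern (float32 sinusoid input, cast back to bfloat16 only at the end) matters, the sketch below compares two ways of producing a bfloat16 sin table against a float64 reference: doing everything in bfloat16, as happens once positions are created in (or overridden to) bfloat16, versus float32 math with a final cast. The 4096-position, 64-dim setup and the `sin_table` helper are arbitrary choices for illustration, not taken from either codebase.

```python
import torch

def sin_table(positions: torch.Tensor, inv_freq: torch.Tensor) -> torch.Tensor:
    # Same as torch.einsum("i,j->ij", positions, inv_freq).sin(), via broadcasting.
    return (positions[:, None] * inv_freq[None, :]).sin()

seq_len, dim = 4096, 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
t = torch.arange(seq_len, dtype=torch.float64)
reference = sin_table(t, inv_freq)  # float64 ground truth

# (a) Everything in bfloat16: positions above 256 are rounded before the sin.
all_bf16 = sin_table(t.to(torch.bfloat16), inv_freq.to(torch.bfloat16))
# (b) GPT-NeoX style: float32 math, cast to bfloat16 only at the end.
fp32_then_bf16 = sin_table(t.float(), inv_freq.float()).to(torch.bfloat16)

err_all_bf16 = (all_bf16.double() - reference).abs().max().item()
err_fp32_then_bf16 = (fp32_then_bf16.double() - reference).abs().max().item()
print(f"max |error|, all bfloat16:      {err_all_bf16:.4f}")
print(f"max |error|, float32 then cast: {err_fp32_then_bf16:.4f}")
```

Both tables end up in bfloat16 and cost the same memory; only the order of the cast differs.
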
rwightman commented 7 months ago

I believe this is the problem seen in https://github.com/microsoft/DeepSpeed/issues/4932, and I now see this may also be a dupe of https://github.com/huggingface/transformers/issues/28596

rwightman commented 7 months ago

Related to the zero.Init() dtype override for arange (which I did confirm is a problem with a test bench), there's an overlapping issue that has been brought up before, e.g. in #25681. I don't think it was fully addressed there, since that improvement focused on rescaling at runtime for longer sequence lengths. This one is caused by zero.Init() overriding the device arg of tensor-creation fns, so the init is done on a non-CPU device.

When a library like DeepSpeed forces the cached RoPE sin/cos/freq values to be calculated on the GPU, they come out wrong compared to the CPU calcs because of a rather nasty combination of floating-point ops (div, pow, outer product, conversion to low precision) whose results differ enough to have a significant impact: ~5e-4 in float16 and 2e-3 eps for Llama. This results in model logits differing by close to 1.0. That is with the calcs forced to float32 (explicitly avoiding low precision); even doing the calculations in double precision is not enough to avoid problematic differences between GPU and CPU.

The only approach that seems viable is ensuring the init of those constants is always done on CPU (which requires extra workarounds to prevent DeepSpeed from forcing it onto the GPU), and then casting to the computation dtype at the very last step before they're used. I trialed an approach related to an Eleuther workaround in their lib, but it likely has breaking concerns for other use cases like tracing, etc. https://github.com/microsoft/DeepSpeed/issues/4932#issuecomment-1911277956

EDIT: I also think we should force RoPE embeddings to be applied in float32 instead of the default computation dtype. I believe the original Llama does this, but transformers does not.
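
Applying RoPE in float32 could look like the following minimal sketch. It assumes the `rotate_half` convention of the Llama code in transformers (the two halves of the head dimension are rotated against each other); `apply_rotary_fp32` is a hypothetical name, not an existing transformers function. It upcasts the input, cos and sin to float32, rotates, and casts back to the input dtype.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into halves (x1, x2) and return (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_fp32(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: rotate in float32 regardless of the compute dtype,
    # then cast the result back to the dtype of the input.
    x32, cos32, sin32 = x.float(), cos.float(), sin.float()
    return (x32 * cos32 + rotate_half(x32) * sin32).to(x.dtype)

q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)  # (batch, heads, seq, head_dim)
cos, sin = torch.ones(16, 64), torch.zeros(16, 64)    # zero angle at every position
out = apply_rotary_fp32(q, cos, sin)
print(out.dtype, torch.equal(out, q))  # torch.bfloat16 True
```

With a zero angle the rotation is the identity, which makes the round trip bfloat16 -> float32 -> bfloat16 easy to check.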

gante commented 7 months ago

First things first: should we run RoPE in FP32? Are buffers a problem?

I've done a quick perplexity benchmark of meta-llama/Llama-2-7b-hf with BF16 (from main) vs BF16 with RoPE being computed in FP32 without buffers (from this commit), adapting the ppl benchmark scripts from this comment.

Here are the results:

[Figure: plot_perplexity_vram]

We can see a very tiny PPL improvement, almost negligible. This comes at the expense of an equally small increase in GPU memory requirements.

[Figure: plot_latency]

On the latency side, we see that going beyond the original context length is much more expensive -- the new sin/cos must be computed in FP32, which is slower.

👉 To me, this indicates that changing RoPE computations to FP32 is not worth it. Happy to do more experiments if you suspect this logic may be flawed in some settings/other models 🤗 (cc @ArthurZucker, we chatted about this a few days ago)

👉 @rwightman could this difference be more pronounced in DeepSpeed? I have no DeepSpeed experience.

gante commented 7 months ago

Regarding the DeepSpeed init issues, going to open a PR to fix it 💪

rwightman commented 7 months ago

@gante are you sure the perplexity test is representative of a wide enough range of use? In this particular user's case, I was testing with their input vectors, and the logits are significantly different when the embeddings are calculated in bfloat16. Computing the pos embeds in float32 on CPU and applying just the calculated embedding in lower precision wasn't too bad...

Similarly, when comparing the embedding values themselves, whether fully calculated on a different device or calculated in low precision, the differences are well beyond a range I'd be comfortable with... This might be something where we'd want to consider letting users make their own trade-offs via config...

gante commented 7 months ago

@rwightman Absolutely not, the data/model space to test on is too large to measure!

However, since the latency penalty of applying the change in all cases is non-negligible, and exposing the ability to recompute the buffer in fp32 adds yet another flag, I'd like to have a reproducible example of a failure case in transformers -- at least to fully understand what's going on. That's why I added the note that I'm fully open to running more experiments, as long as I have some pointers.

There are many competing requests. Sadly, it's hard to find the time to do a proper deep dive to find a failure mode 🤗

rwightman commented 7 months ago

@gante understood, with the cached values the runtime latency for the 'okay' case I described should be non-existent though... namely,

  1. calculate the values on the CPU, in float32 as a rule
  2. cast to the usage dtype (e.g. store the sin/cos embeds in bfloat16 on the target device)

There would be no memory or runtime overhead (other than when a new seq length is switched to), but the pos embed values would be significantly closer to their intended values.
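
The two steps above could be sketched as follows. `RotaryCache` is a hypothetical class, not the transformers implementation, and it assumes the Llama-style `cat((freqs, freqs))` layout. The tables are always computed on the CPU in float32, cast to the usage dtype and moved to the target device once, and rebuilt only when a longer sequence arrives. Keeping the computation on the CPU under DeepSpeed may need extra workarounds that are not shown here.

```python
import torch

class RotaryCache:
    # Hypothetical sketch: CPU/float32 computation, cached in the usage dtype.
    def __init__(self, dim: int, base: float = 10000.0,
                 dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
        self.dim, self.base, self.dtype, self.device = dim, base, dtype, device
        self.cached_len = 0
        self.recomputations = 0  # how often the tables were rebuilt
        self.cos = self.sin = None

    def _compute_on_cpu(self, seq_len: int):
        # Step 1: everything on the CPU, in float32, starting from integer aranges.
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float()
        inv_freq = 1.0 / (self.base ** (exponents / self.dim))
        t = torch.arange(seq_len, dtype=torch.int64, device="cpu").float()
        emb = torch.cat((torch.outer(t, inv_freq),) * 2, dim=-1)
        return emb.cos(), emb.sin()

    def get(self, seq_len: int):
        if seq_len > self.cached_len:
            cos, sin = self._compute_on_cpu(seq_len)
            # Step 2: cast to the usage dtype and move to the target device once.
            self.cos = cos.to(device=self.device, dtype=self.dtype)
            self.sin = sin.to(device=self.device, dtype=self.dtype)
            self.cached_len = seq_len
            self.recomputations += 1
        return self.cos[:seq_len], self.sin[:seq_len]

cache = RotaryCache(dim=64)
cos, sin = cache.get(2048)
cos, sin = cache.get(1024)  # served from the cache
cos, sin = cache.get(4096)  # longer sequence: rebuilt once
print(cache.recomputations, cos.shape, cos.dtype)  # 2 torch.Size([4096, 64]) torch.bfloat16
```

Between rebuilds, each call only slices the cached tables, which is the "no runtime overhead" property described above.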

ArthurZucker commented 7 months ago

@gante from my personal tests, changing inv_freq to float32 can increase MMLU performance by about 20 points. There are a few things to test:

I don't think perplexity is something we should ever use for these kinds of tests, but rather proper generations / benchmarks:

rwightman commented 7 months ago

Also, perplexity is an average score. I'm not overly familiar with the typical test data, but I assume it probably doesn't push corner cases? Is it well formed?

When I compared forward-pass outputs with those sin/cos embeds computed on CPU vs GPU, and in bfloat16 vs float16 vs float32, the differences were significant. The logit (model output) differences could also be quite significant, but I was looking at the worst case, not the average logit diff. The mean is pretty uninteresting: most logits were close, but the worst ones were well outside the range I'd consider reasonable... and it's the worst case that causes networks to blow up...
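
A comparison along those lines should report the worst case next to the mean. The sketch below uses a hypothetical `diff_stats` helper on synthetic tensors (1,000 "logits" with a small uniform error plus one outlier off by about 1.0); it is not taken from the test bench mentioned above.

```python
import torch

def diff_stats(a: torch.Tensor, b: torch.Tensor) -> dict:
    # Upcast first so the comparison itself adds no low-precision rounding.
    diff = (a.float() - b.float()).abs()
    return {"max": diff.max().item(), "mean": diff.mean().item()}

reference = torch.zeros(1000)
perturbed = reference + 1e-3  # small, uniform error everywhere...
perturbed[123] += 1.0         # ...plus a single logit that is off by ~1.0
stats = diff_stats(reference, perturbed)
print(stats)
```

The mean stays around 2e-3 while the max exceeds 1.0: a single bad logit barely moves an averaged score.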

gante commented 7 months ago

Perhaps my message above came across incorrectly 😅

I trust what you wrote, that it should be computed in FP32. I meant that we should have a few concrete failure examples to a) test against and prevent regressions and b) document why we made modeling changes (especially ones that increase HW requirements). Since I had yet to come across this particular issue and wasn't aware of the type of numerical issue (systemic drift vs infrequent failure), a few extra pointers were needed to speed things up. A reproducible script would be even better. We do request this on our contributors' issues 🤗

Now I have a more precise target, which facilitates search: infrequent large differences. I'm going to dig up a clear example and open a PR to add the corresponding RoPE fix and tests.

gante commented 7 months ago

An example of a failure case is below.

from transformers import AutoModelForCausalLM
import torch

model_1 = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/tiny-random-LlamaForCausalLM",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model_2 = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/tiny-random-LlamaForCausalLM",
    device_map="auto",
).to(torch.bfloat16)

# `torch_dtype=...` doesn't cast buffers created with an explicit dtype (the float32 `inv_freq`), `.to(...)` does
assert model_1.model.layers[0].self_attn.rotary_emb.inv_freq.dtype == torch.float32
assert model_2.model.layers[0].self_attn.rotary_emb.inv_freq.dtype == torch.bfloat16

# sequence length smaller than the initialized length (2048) -> cached sin/cos are reused, no problem
input_ids = torch.randint(0, 32000, (1, 1024)).to(model_1.device)
model_1_out = model_1(input_ids)
model_2_out = model_2(input_ids)
assert torch.allclose(model_1_out.logits, model_2_out.logits)

# sequence length larger than the initialized length (2048) -> problem: this assert fails on main
# why? beyond the initialized length, sin/cos have to be recomputed from the non-persistent
# `inv_freq` buffer, whose dtype differs between the two models
input_ids = torch.randint(0, 32000, (1, 2049)).to(model_1.device)
model_1_out = model_1(input_ids)
model_2_out = model_2(input_ids)
assert torch.allclose(model_1_out.logits, model_2_out.logits)

It is extremely easy to find the bug when .to() is used instead of torch_dtype -- but on main only past the original sequence length, due to the existing order of operations. Anything that fiddles with types at initialization time (like DeepSpeed) will run into the problem immediately, even before the sequence length is exceeded.

The same perplexity script can also find the problem when the model is cast with the .to() method:

[Figure: plot_perplexity_vram]

👉 moving forward with the PR to ensure this op stays in FP32 and these artefacts are no longer present

ArthurZucker commented 7 months ago

super nice 🤗

ArthurZucker commented 6 months ago

Also linked to #29285

github-actions[bot] commented 4 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker commented 4 months ago

I think this was fixed! For Llama and Gemma at least 🤗

gante commented 4 months ago

Yes -- the missing related bits are to make the logits of our llama == logits of the original llama (https://github.com/huggingface/transformers/pull/28837, which I'm tracking -- needs to be picked up from a fresh PR)

Closing this issue since it is sorted :)