huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

MLPs with `FbgemmFp8Linear` on Llama-405b-FP8 do not handle batch sizes >1 correctly #32868

Closed vgel closed 4 days ago

vgel commented 4 weeks ago

System Info

transformers 4.44.0
torch 2.4.0+cu121
fbgemm_gpu 0.8.0+cu121

Who can help?

@ArthurZucker (also maybe @SunMarc?)

Reproduction

>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model_name = "meta-llama/Meta-Llama-3.1-405B-FP8"
>>> base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
>>> train_strs = [
...  'The Golden Gate Bridge AI is talking to the user.\n\nUser: Who are you?\n\nAI: That',
...  'An AI is talking to the user.\n\nUser: Who are you?\n\nAI: That',
...  'The Golden Gate Bridge AI is talking to the user.\n\nUser: Who are you?\n\nAI: I',
...  'An AI is talking to the user.\n\nUser: Who are you?\n\nAI: I',
...  'The Golden Gate Bridge AI is talking to the user.\n\nUser: Who are you?\n\nAI: I can'
... ]
>>> encoded_batch = tokenizer(train_strs[:32], padding=True, return_tensors="pt").to(base_model.device)
>>> out = base_model(**encoded_batch)
...
File python3.10/site-packages/transformers/models/llama/modeling_llama.py:751, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    749 hidden_states = self.post_attention_layernorm(hidden_states)
    750 hidden_states = self.mlp(hidden_states)
--> 751 hidden_states = residual + hidden_states
    753 outputs = (hidden_states,)
    755 if output_attentions:

RuntimeError: The size of tensor a (23) must match the size of tensor b (736) at non-singleton dimension 1
>>> # Note 736 = 32 * 23, the maximum sequence length

After some digging in pdb, I tracked it down to the quantized MLPs:

>>> emd = base_model.model.embed_tokens(encoded_batch.input_ids)
>>> emd.shape
torch.Size([32, 23, 16384])
>>> type(base_model.model.layers[0].mlp.up_proj)
<class 'torch.nn.modules.linear.Linear'>
>>> base_model.model.layers[0].mlp(emd).shape
torch.Size([32, 23, 16384])
>>> type(base_model.model.layers[1].mlp.up_proj)
<class 'transformers.integrations.fbgemm_fp8.FbgemmFp8Linear'>
>>> base_model.model.layers[1].mlp(emd).shape
torch.Size([736, 16384])  # <-- wrong!! batch and sequence dims flattened together

I was able to patch it with this monkeypatch:

>>> class FixedQuantedMLP(torch.nn.Module):
...     """Wraps a quantized MLP and restores the input's (batch, seq_len, hidden) shape."""
...     def __init__(self, mlp):
...         super().__init__()
...         self.mlp = mlp
...
...     def forward(self, x):
...         shape = x.shape
...         x = self.mlp(x)          # may come back flattened to (batch * seq_len, hidden)
...         return x.reshape(shape)  # restore the original leading dimensions

>>> def fix_layer_mlp(layer):
...     layer.old_mlp = layer.mlp
...     layer.mlp = FixedQuantedMLP(layer.mlp)

>>> for layer in base_model.model.layers: fix_layer_mlp(layer)

...which made `base_model.generate` work as expected.
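
For completeness, a quick usage sketch of how the patched model can be exercised on the same batch. It reuses base_model, encoded_batch, and tokenizer from the repro above; max_new_tokens=16 is an arbitrary illustrative value, not something from the original session:

# Sanity check after applying the monkeypatch; reuses names from the repro above.
# max_new_tokens=16 is arbitrary, chosen only for illustration.
with torch.no_grad():
    generated = base_model.generate(**encoded_batch, max_new_tokens=16)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))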

Expected behavior

The quantized MLP layers should not flatten the batch and sequence dimensions into one; the output should keep the same leading shape as the input. I suspect these lines are at fault, but I'm not sure:

https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/integrations/fbgemm_fp8.py#L50-L52
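
For context, a rough sketch (reconstructed from the traceback, not copied from that file) of what appears to happen: the forward flattens the leading dimensions with view(-1, hidden), presumably because the fbgemm row-wise kernels operate on 2D matrices, and nothing restores them afterwards. In plain torch, with the shapes from this report:

import torch

x = torch.randn(32, 23, 16384)              # (batch, seq_len, hidden), as in the traceback
x_2d = x.view(-1, x.shape[-1])              # (736, 16384), since 736 = 32 * 23
out_2d = torch.empty(x_2d.shape[0], 16384)  # stand-in for what FbgemmFp8Linear returns
# If out_2d is returned as-is, the residual add in LlamaDecoderLayer sees
# (32, 23, 16384) + (736, 16384) and raises the size-mismatch error shown above.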

SunMarc commented 3 weeks ago

Thanks for the detailed report @vgel! This is indeed a bug: I forgot that the view call flattens the batch and sequence dimensions and that the original shape is never restored. Would you like to open a PR to fix this? As you tested, you just need to reshape the tensor back to its original shape after the quantize_fp8_per_row op.
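
A minimal sketch of what that fix could look like, assuming FbgemmFp8Linear.forward flattens its input with view(-1, hidden) before calling quantize_fp8_per_row; the attribute names (input_scale_ub, weight_scale) and kernel arguments are assumed from the linked integration file and may not match the eventual PR exactly:

import torch
from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear

def patched_forward(self, x):
    # Remember the leading (batch, seq_len) dims; -1 covers out_features.
    output_shape = (*x.shape[:-1], -1)
    x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
        x.view(-1, x.shape[-1]), None, self.input_scale_ub
    )
    output = torch.ops.fbgemm.f8f8bf16_rowwise(
        x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
    )
    if self.bias is not None:
        output = output + self.bias
    # Restore (batch, seq_len, out_features) so the residual add downstream lines up.
    return output.reshape(output_shape)

# As a stopgap until the PR lands, this could be bound with:
# FbgemmFp8Linear.forward = patched_forward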

vgel commented 2 weeks ago

> Thanks for the detailed report @vgel! This is indeed a bug: I forgot that the view call flattens the batch and sequence dimensions and that the original shape is never restored. Would you like to open a PR to fix this? As you tested, you just need to reshape the tensor back to its original shape after the quantize_fp8_per_row op.

Sure, just opened a PR!