Apologies for the delayed response; we're currently on a 5-day public holiday.
Indeed, this is a significant issue overlooked by autoround, stemming from the root cause you've already highlighted: the lm-head shares weights with the embedding, a characteristic present in many models but absent in Llama3.
Option 1: A straightforward approach is to disable LM-head quantization if its weights are tied to the embedding. Embedding operations are typically fast, and tied weights will complicate algorithmic processes.
Option 2: Set tied weights to false, allowing the embedding to retain float weights while quantizing the LM-head.
Option 3: Quantize the shared weights. To align with linear quantization, we need to configure the embedding layer's group dimension as the output channel and ensure the kernel supports the embedding layer.
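For Option 1, a minimal sketch of the tied-weight check is below. This is an illustration only, not AutoRound's actual logic; `facebook/opt-125m` is used purely as an example checkpoint.

```python
# Sketch: decide whether to quantize the lm-head based on weight tying.
# Assumes a Hugging Face Transformers causal LM; illustration only.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

in_emb = model.get_input_embeddings()
out_emb = model.get_output_embeddings()  # the lm-head for causal LMs

tied = (
    getattr(model.config, "tie_word_embeddings", False)
    or (out_emb is not None and out_emb.weight.data_ptr() == in_emb.weight.data_ptr())
)

quant_lm_head = not tied  # Option 1: skip lm-head quantization when weights are tied
print(f"tie_word_embeddings={model.config.tie_word_embeddings}, quantize lm_head: {quant_lm_head}")
```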
@wenhuach21 Thanks for the response despite being on vacation!
Option 1: A straightforward approach is to disable LM-head quantization if its weights are tied to the embedding. Embedding operations are typically fast, and tied weights will complicate algorithmic processes.
This may be the best temporary solution until Option 3 is implemented, or at minimum it should come with a big fat warning.
Option 2: Set tied weights to false, allowing the embedding to retain float weights while quantizing the LM-head.
This is the fix I implemented in vLLM's OPT model loading code. It works, but it doesn't make much real-world sense: VRAM usage actually increases, because the quantized lm_head layer now exists side-by-side with the still-loaded non-quantized embedding layer. It is OK for testing, but it should be disabled, or at least trigger a user warning, in real usage.
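For reference, a hedged sketch of what that untying step looks like on the Transformers side. This is an illustration, not the actual vLLM patch; cloning the tensor is the key step so the two modules stop sharing storage.

```python
# Sketch: Option 2 - break the tie so lm_head can be quantized independently.
# Illustration only; the real fix lives in the model loading code.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

out_emb = model.get_output_embeddings()
in_emb = model.get_input_embeddings()

if out_emb.weight.data_ptr() == in_emb.weight.data_ptr():
    # Give the lm-head its own copy of the weights; the embedding stays float.
    out_emb.weight = torch.nn.Parameter(in_emb.weight.detach().clone())
    model.config.tie_word_embeddings = False

# Downside noted above: the float embedding and the (soon-to-be) quantized
# lm_head now both live in memory, so VRAM usage goes up instead of down.
```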
Option 3: Quantize the shared weights. To align with linear quantization, we need to configure the embedding layer's group dimension as the output channel and ensure the kernel supports the embedding layer.
This may be the best long-term option. Option 2 is not great, as it bloats VRAM and goes against the point of quantizing for deployment.
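To make my reading of Option 3 concrete, here is a rough sketch (my own illustration, not AutoRound's kernel, and deliberately simplified to one group per vocabulary row): if each output channel (vocab row) carries its own scale/zero-point, the embedding lookup can dequantize just the rows it needs while the lm-head reuses the same quantized matrix.

```python
# Sketch: quantize a shared [vocab, hidden] weight group-wise along the
# output channel (one group per vocab row here, for simplicity).
import torch

def quantize_per_row(w: torch.Tensor, bits: int = 4):
    qmax = 2**bits - 1
    w_min = w.min(dim=1, keepdim=True).values
    w_max = w.max(dim=1, keepdim=True).values
    scale = (w_max - w_min).clamp(min=1e-8) / qmax
    zero = torch.round(-w_min / scale)
    q = torch.clamp(torch.round(w / scale + zero), 0, qmax).to(torch.uint8)
    return q, scale, zero

def dequantize_rows(q, scale, zero, token_ids):
    # Embedding lookup: only the requested vocab rows are dequantized.
    return (q[token_ids].float() - zero[token_ids]) * scale[token_ids]

vocab, hidden = 8, 4
shared = torch.randn(vocab, hidden)          # tied embedding / lm_head weight
q, scale, zero = quantize_per_row(shared)

tokens = torch.tensor([1, 5, 5, 2])
emb_out = dequantize_rows(q, scale, zero, tokens)      # embedding path
logits = emb_out @ ((q.float() - zero) * scale).t()    # lm-head path reuses q
```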
@Qubitium Yes, Option 3 seems to be the optimal choice; however, it may present a challenge to the algorithm. We have chosen Option 1 as a temporary solution.
While testing OPT with `quant_lm_head=True`, here are the resulting weights post-quantization:

weight keys: `['lm_head.g_idx', 'lm_head.qweight', 'lm_head.qzeros', 'lm_head.scales', 'model.decoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', ...`

`model.decoder.embed_tokens.weight` is not quantized, but `lm_head` is. Unfortunately, vLLM's model code (and maybe HF Transformers as well) ignores this `lm_head` layer during weight loading? I confirmed this for vLLM but am not 100% sure for Transformers.

But OPT's `lm_head` is actually the same as (soft-linked to) `model.decoder.embed_tokens` in vLLM's code, and this appears to be true in Transformers as well. I checked the original weights: `lm_head` exists in the checkpoint, but its size and values are exactly the same as `embed_tokens`, so the model authors appear to think `lm_head` should be ignored on load.

https://github.com/huggingface/transformers/blob/0ae789e04330e15a90e34cd723c851a8ab8d7ec5/src/transformers/models/opt/modeling_opt.py#L1001
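A quick runtime check of that soft link on the Transformers side (a sketch; it should print `True` if the weights are tied, which is what the loader behavior described here suggests):

```python
# Sketch: confirm the "soft link" at runtime in Transformers.
from transformers import OPTForCausalLM

model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
tied = (model.lm_head.weight.data_ptr()
        == model.model.decoder.embed_tokens.weight.data_ptr())
print(tied)  # True means both names point at the same tensor after loading
```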
In vLLM's model loading code for OPT, the `lm_head` weights are skipped and soft-linked to the embeddings. This appears to be the same for HF Transformers as well.

https://github.com/vllm-project/vllm/blob/26f2fb51133c85ad8a57a87c8037f750dda757f4/vllm/model_executor/models/opt.py#L288
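For context, the pattern being described is roughly the following. This is a simplified, hypothetical sketch of a skip-and-tie weight loader, not the actual vLLM source; see the linked `opt.py` for the real code.

```python
# Sketch: the "skip lm_head, reuse embed_tokens" loader pattern described above.
import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 8, hidden_size: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # lm_head is "soft-linked": the embedding matrix doubles as the head.
        return hidden_states @ self.embed_tokens.weight.t()

    def load_weights(self, checkpoint: dict) -> None:
        for name, tensor in checkpoint.items():
            if name.startswith("lm_head"):
                # The skip in question: any lm_head tensor (float or the
                # quantized qweight/qzeros/scales) is silently dropped here.
                continue
            if name == "embed_tokens.weight":
                self.embed_tokens.weight.data.copy_(tensor)

model = TinyTiedLM()
ckpt = {"embed_tokens.weight": torch.randn(8, 4), "lm_head.weight": torch.randn(8, 4)}
model.load_weights(ckpt)  # lm_head.weight never lands anywhere
```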
So my naive question is: who is correct? AutoRound is correctly finding and quantizing the `lm_head` layer, yet this layer is actually ignored by model loaders? (This is in relation to the testing I am doing for the vLLM PR: https://github.com/vllm-project/vllm/pull/4442#issuecomment-2085491133)

This becomes an issue when loading the quant, as vLLM completely skips `lm_head` layers (pre- or post-quant); I assume the model code writer figured there was no point loading the same weights twice when the tensor sizes and values are exactly the same.

I am new to all the layers/modules, so forgive me if my question itself is based on false premises. Thank you! I hope to have intel/autoround model support merged into vLLM soon.
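One possible guard for the loading side (my own suggestion, not existing vLLM or AutoRound behavior): detect quantized `lm_head` tensors in the checkpoint before silently skipping them, and warn the user.

```python
# Sketch: warn when a checkpoint contains quantized lm_head tensors that a
# tied-weight loader is about to drop. Hypothetical helper, names are my own.
import warnings

QUANT_SUFFIXES = (".qweight", ".qzeros", ".scales", ".g_idx")

def check_lm_head_quant(checkpoint_keys):
    dropped = [k for k in checkpoint_keys
               if k.startswith("lm_head") and k.endswith(QUANT_SUFFIXES)]
    if dropped:
        warnings.warn(
            f"Checkpoint contains quantized lm_head tensors {dropped}, but the "
            "loader ties lm_head to embed_tokens and will skip them; the model "
            "would silently fall back to the float embedding weights."
        )
    return dropped

keys = ["lm_head.g_idx", "lm_head.qweight", "lm_head.qzeros", "lm_head.scales",
        "model.decoder.embed_tokens.weight"]
check_lm_head_quant(keys)
```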
Here are the original weights before quantization:
https://huggingface.co/facebook/opt-125m
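A quick sketch for verifying that duplication directly from the checkpoint (assumes `torch` and `huggingface_hub` are installed; the file and key names follow what is reported in this thread):

```python
# Sketch: compare lm_head.weight and embed_tokens.weight in facebook/opt-125m.
import torch
from huggingface_hub import hf_hub_download

path = hf_hub_download("facebook/opt-125m", "pytorch_model.bin")
state = torch.load(path, map_location="cpu")

emb = state["model.decoder.embed_tokens.weight"]   # key names as reported above
head = state["lm_head.weight"]
print(emb.shape, head.shape)   # same shape
print(torch.equal(emb, head))  # True if the tensors are exact duplicates
```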
So in the original OPT-125M model weights, `model.decoder.embed_tokens.weight` and `lm_head.weight` both exist, and the sizes and even the values of the tensors are exactly the same!

@robertgshaw2-neuralmagic Is this a bug in the vLLM OPT model code? Why is it skipping the `lm_head` layer when it actually exists (even though it is a duplicate of `embed_tokens`)?

@wenhuach21 @WeiweiZhang1