huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

FP8 weight load failed: IndexError: list index out of range #2409

Open icyxp opened 3 months ago

icyxp commented 3 months ago

System Info

2024-08-13T06:17:44.049654Z ERROR shard-manager: text_generation_launcher: Shard complete standard error output:

2024-08-13 06:17:41.545 | INFO     | text_generation_server.utils.import_utils:<module>:75 - Detected system cuda
/opt/conda/lib/python3.10/site-packages/text_generation_server/utils/sgmv.py:18: UserWarning: Could not import SGMV kernel from Punica, falling back to loop.
  warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:159: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:232: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout):
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py:508: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py:567: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
[rank1]: Traceback (most recent call last):

[rank1]:   File "/opt/conda/bin/text-generation-server", line 8, in <module>
[rank1]:     sys.exit(app())

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 109, in serve
[rank1]:     server.serve(

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 274, in serve
[rank1]:     asyncio.run(

[rank1]:   File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
[rank1]:     return loop.run_until_complete(main)

[rank1]:   File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
[rank1]:     return future.result()

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 229, in serve_inner
[rank1]:     model = get_model_with_lora_adapters(

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 1223, in get_model_with_lora_adapters
[rank1]:     model = get_model(

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 780, in get_model
[rank1]:     return FlashCausalLM(

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 896, in __init__
[rank1]:     model = model_class(prefix, config, weights)

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 528, in __init__
[rank1]:     self.model = FlashLlamaModel(prefix, config, weights)

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 418, in __init__
[rank1]:     FlashLlamaLayer(

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 346, in __init__
[rank1]:     self.self_attn = FlashLlamaAttention(

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 166, in __init__
[rank1]:     self.query_key_value = load_attention(config, prefix, weights, index)

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 94, in load_attention
[rank1]:     base_layer = TensorParallelColumnLinear.load_multi(

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/tensor_parallel.py", line 173, in load_multi
[rank1]:     weight = weights.get_multi_weights_col(prefixes, dim=dim)

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/utils/weights.py", line 373, in get_multi_weights_col
[rank1]:     return self.weights_loader.get_multi_weights_col(self, prefixes, dim)

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/fp8.py", line 140, in get_multi_weights_col
[rank1]:     scale = [

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/fp8.py", line 141, in <listcomp>
[rank1]:     weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/text_generation_server/utils/weights.py", line 270, in get_sharded
[rank1]:     size = slice_.get_shape()[dim]

[rank1]: IndexError: list index out of range
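Note on the last frame: `get_sharded` is called with `dim=0`, and `slice_.get_shape()[dim]` raises `IndexError: list index out of range`. That can only happen if the shape list is empty, which suggests the `weight_scale` tensor is a 0-dimensional scalar (a per-tensor scale) with no axis to shard. This is an inference from the traceback, not something confirmed in the thread. A minimal sketch in plain Python; `shard_bounds` is a hypothetical helper for illustration, not TGI's code:

```python
def shard_bounds(shape, dim, rank, world_size):
    """Return (start, stop) of this rank's block along `dim`.

    Hypothetical helper, not TGI's implementation. A scalar tensor
    (shape == []) has no dimension to split, so None is returned to
    signal that every rank should load the full value.
    """
    if len(shape) == 0:
        return None  # scalar: nothing to shard
    size = shape[dim]
    block = (size + world_size - 1) // world_size  # ceil division
    start = rank * block
    stop = min(start + block, size)
    return start, stop


# Reproduce the failure mode: indexing dim 0 of an empty shape list.
scalar_shape = []
try:
    scalar_shape[0]
    raised = False
except IndexError:
    raised = True
```

With the guard, a scalar scale is loaded whole on every rank, while 1-D or 2-D tensors are still split along `dim`.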

Reproduction

none

Expected behavior

none

icyxp commented 3 months ago

I used this project to convert the model to FP8: https://github.com/neuralmagic/AutoFP8

icyxp commented 3 months ago

They are not the same type of problem. Mine is an FP8 loading problem, and his is a Marlin problem. @drbh

The config has "activation_scheme": "dynamic", and the checkpoint contains these tensors:

model.layers.0.mlp.down_proj.weight              < F8_E4M3
model.layers.0.mlp.down_proj.weight_scale        < F32
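The listing above shows dtypes but not shapes, and the shape of `weight_scale` is what matters for the `IndexError`. One way to check it without loading any weights is to read the safetensors header directly: the file starts with an 8-byte little-endian length, followed by that many bytes of JSON describing each tensor's `dtype`, `shape`, and `data_offsets`. A small sketch using only the standard library; the function names are hypothetical, and whether this checkpoint really stores scalar scales is an assumption to verify:

```python
import json
import struct


def read_safetensors_header(path):
    """Return the JSON header of a .safetensors file as a dict."""
    with open(path, "rb") as f:
        # First 8 bytes: header length as an unsigned little-endian u64.
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))


def scalar_scales(header):
    """Names of `*.weight_scale` tensors whose shape is empty (0-dim)."""
    return sorted(
        name
        for name, meta in header.items()
        if name != "__metadata__"  # optional free-form metadata entry
        and name.endswith(".weight_scale")
        and meta["shape"] == []
    )
```

Calling `scalar_scales(read_safetensors_header(path))` on one shard of the converted model (where `path` is a placeholder for a local file) would list every scale stored as a scalar. A non-empty result would be consistent with the failure in `get_sharded`.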