AniZpZ / AutoSmoothQuant

An easy-to-use package for implementing SmoothQuant for LLMs
MIT License

Tensor shape error when loading quant Llama-2-70B #14

Open · MingLin-home opened this issue 8 months ago

MingLin-home commented 8 months ago

Hello! Thanks for the nice work!

I want to quantize Llama-2-70B. I was able to export the quantized model without any error. However, when I test the model:

python test_model.py \
  --model-path=$quant_model_path \
  --tokenizer-path=$model_path \
  --model-class=llama --prompt="something to say"

I encounter this error:

/home/mmilin/projects/ming_benchmark_vllm/autosmq-venv/lib/python3.9/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/mmilin/projects/ming_benchmark_vllm/autosmq-venv/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/mmilin/projects/ming_benchmark_vllm/autosmq-venv/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Loading checkpoint shards:   0%|                                                                                    | 0/15 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/mmilin/projects/ming_benchmark_vllm/AutoSmoothQuant/autosmoothquant/examples/test_model.py", line 62, in <module>
    main()
  File "/home/mmilin/projects/ming_benchmark_vllm/autosmq-venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/mmilin/projects/ming_benchmark_vllm/AutoSmoothQuant/autosmoothquant/examples/test_model.py", line 40, in main
    model = Int8LlamaForCausalLM.from_pretrained(args.model_path, quant_config, device_map="sequential")
  File "/home/mmilin/projects/ming_benchmark_vllm/autosmq-venv/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3706, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/mmilin/projects/ming_benchmark_vllm/autosmq-venv/lib/python3.9/site-packages/transformers/modeling_utils.py", line 4116, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/home/mmilin/projects/ming_benchmark_vllm/autosmq-venv/lib/python3.9/site-packages/transformers/modeling_utils.py", line 778, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/home/mmilin/projects/ming_benchmark_vllm/autosmq-venv/lib/python3.9/site-packages/accelerate/utils/modeling.py", line 345, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([1024, 8192]) in "weight" (which has shape torch.Size([8192, 8192])), this look incorrect.

BTW, I was able to convert and load the Llama-2-7B model without any error. Any idea how to fix this? It looks related to grouped-query attention.
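
For what it's worth, the numbers in the error line up with grouped-query attention. A minimal sketch of the shape arithmetic (values taken from the public Llama-2-70B config.json; this is not AutoSmoothQuant code):

hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8   # grouped-query attention; Llama-2-7B uses 32 (== num_attention_heads)
head_dim = hidden_size // num_attention_heads  # 128

q_proj_shape = (num_attention_heads * head_dim, hidden_size)   # (8192, 8192)
kv_proj_shape = (num_key_value_heads * head_dim, hidden_size)  # (1024, 8192)

# A model class that sizes k_proj/v_proj with num_attention_heads instead of
# num_key_value_heads allocates (8192, 8192) and cannot load the (1024, 8192)
# checkpoint tensor, which is exactly the ValueError above.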

Many thanks in advance!

AniZpZ commented 8 months ago

Hi! We haven't encountered this problem before. Could you please post your config.json for both models? And in which linear layer does the shape mismatch occur?
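
If it helps, here is a rough way to locate the mismatching layer. This is a sketch, not repo code: find_shape_mismatches is a hypothetical helper, it assumes the export is sharded as pytorch_model*.bin (adjust the glob for safetensors), and model is the Int8LlamaForCausalLM instance you are trying to load into.

import glob
import torch

def find_shape_mismatches(model, quant_model_path):
    # Collect tensor shapes from every checkpoint shard on disk
    ckpt_shapes = {}
    for shard in glob.glob(f"{quant_model_path}/pytorch_model*.bin"):
        state = torch.load(shard, map_location="cpu")
        ckpt_shapes.update({n: tuple(t.shape) for n, t in state.items()})
    # Compare against the shapes the model class actually allocates
    for name, param in model.state_dict().items():
        if name in ckpt_shapes and tuple(param.shape) != ckpt_shapes[name]:
            print(name, "checkpoint:", ckpt_shapes[name], "model:", tuple(param.shape))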

MingLin-home commented 7 months ago

> Hi! We haven't encountered this problem before. Could you please post your config.json for both models? And in which linear layer does the shape mismatch occur?

Sorry for the late reply. My config.json:

{
    "qkv": "per-tensor",
    "out": "per-token",
    "fc1": "per-tensor",
    "fc2": "per-token"
}
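
(For context: "per-tensor" and "per-token" name the int8 quantization granularity, one scale for the whole activation tensor versus one scale per token row. A generic illustration of the difference, not AutoSmoothQuant's kernels:)

import torch

def quantize_per_tensor(x: torch.Tensor):
    scale = x.abs().max() / 127.0          # single scale for all elements
    q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    return q, scale

def quantize_per_token(x: torch.Tensor):
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0  # one scale per token
    q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 8)                      # (tokens, hidden)
q_t, s_t = quantize_per_tensor(x)
q_k, s_k = quantize_per_token(x)
print(s_t.shape, s_k.shape)                # torch.Size([]) vs torch.Size([4, 1])

Per-token scales track outliers row by row, which generally preserves accuracy better at the cost of storing one scale per token.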

BTW, I am now able to load the exported Llama-2-70B model into vllm-w8a8, so I suspect this issue only exists in the AutoSmoothQuant repo.

gesanqiu commented 3 months ago

@MingLin-home Hi, have you solved this issue? I'm running into the same problem.