Snowflake-Labs / snowflake-arctic


Error when serving with the Hugging Face inference tutorial #16

Closed: JF-D closed this issue 4 months ago

JF-D commented 4 months ago

Hi Arctic team, great work! I followed the Hugging Face inference tutorial to run inference, but I hit the following error:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [24:34<00:00,  7.56s/it]
WARNING:root:Some parameters are on the meta device device because they were offloaded to the disk.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:31999 for open-end generation.
Traceback (most recent call last):
  File "/mnt/afs/jfduan/LLMInfer/snowflake-arctic/inference/hf_infer.py", line 28, in <module>
    outputs = model.generate(input_ids=input_ids, max_new_tokens=20)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 1572, in generate
    result = self._greedy_search(
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 2477, in _greedy_search
    outputs = self(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1708, in forward
    outputs = self.model(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1397, in forward
    layer_outputs = decoder_layer(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1087, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 808, in forward
    query_states = self.q_proj(hidden_states)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 347, in pre_forward
    set_module_tensor_to_device(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 358, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([7168, 7168]) in "weight" (which has shape torch.Size([100352, 516])), this look incorrect.

Can you help me resolve this? Thanks a lot!

jeffra commented 4 months ago

Hi @JF-D! Thanks for trying this out. Can you tell me a bit more about your setup? Specifically:

  1. total number and type of GPUs
  2. transformers and deepspeed versions
  3. Did you make any changes to the example code? If so, can you paste it here?
JF-D commented 4 months ago
  1. I am using 8xA100 GPUs.
  2. transformers==4.40.0.dev0, deepspeed==0.14.2. I followed these instructions to install the dependencies:

     # we recommend setting up a virtual environment for this
     virtualenv arctic-venv
     source arctic-venv/bin/activate

     # faster ckpt download speed
     pip install huggingface_hub[hf_transfer]

     # clone the vllm repo and check out the arctic branch
     git clone -b arctic https://github.com/Snowflake-Labs/vllm.git
     cd vllm
     pip install -e .

     # clone Hugging Face transformers and check out the arctic branch
     git clone -b arctic https://github.com/Snowflake-Labs/transformers.git

     # install deepspeed
     pip install "deepspeed>=0.14.2"

  3. I didn't change the example code. The full code is listed below:

import os

# enable hf_transfer for faster ckpt download
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from deepspeed.linear.config import QuantizationConfig

tokenizer = AutoTokenizer.from_pretrained(
    "Snowflake/snowflake-arctic-instruct",
    trust_remote_code=True
)

quant_config = QuantizationConfig(q_bits=8)

model = AutoModelForCausalLM.from_pretrained(
    "Snowflake/snowflake-arctic-instruct",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    ds_quantization_config=quant_config,
    max_memory={i: "150GiB" for i in range(8)},
    torch_dtype=torch.bfloat16)

messages = [{"role": "user", "content": "What is 1 + 1 "}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(input_ids=input_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))



BTW, loading checkpoints takes ~30 minutes on my server, which is very slow.
jeffra commented 4 months ago

Excellent, one quick follow-up question before diving into the other details. Are these 40GB or 80GB A100s?

w.r.t. slow load times, we are working on uploading pre-quantized checkpoints to HF. Hopefully that will bring load times down a bit.

JF-D commented 4 months ago

They are 80GB A100s. I think with the quantization config, I should be able to run a simple example.
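
A back-of-envelope check supports that expectation (assuming Arctic-instruct's roughly 480B total parameters; that figure comes from the model card, not this thread):

params = 480e9               # assumption: ~480B total parameters
gpu_mem_gb = 8 * 80          # aggregate across 8x A100-80GB
bf16_gb = params * 2 / 1e9   # 2 bytes/param -> ~960 GB, does not fit
fp8_gb = params * 1 / 1e9    # 1 byte/param  -> ~480 GB, fits (before activations/KV cache)
print(f"bf16 needs ~{bf16_gb:.0f} GB vs {gpu_mem_gb} GB available")
print(f"fp8  needs ~{fp8_gb:.0f} GB vs {gpu_mem_gb} GB available")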

jeffra commented 4 months ago

Gotcha, yeah I think 8xA100-80GB should work here. We have not tested this exactly since I don't have immediate access to this hardware. I have seen that error message previously due to some tensors being moved to CPU by device_map="auto". This shouldn't happen if there's enough memory on the GPUs for everything though, which we have confirmed is the case with 8xH100-80GB.
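
One quick way to confirm whether offloading happened is to inspect the hf_device_map that accelerate attaches to the model after from_pretrained; a minimal sketch, assuming the model object from the example above:

# Sketch: list modules that accelerate placed on "cpu" or "disk".
# Any such entry means offloading happened, which matches the error above.
offloaded = {name: dev for name, dev in model.hf_device_map.items()
             if dev in ("cpu", "disk")}
print(f"{len(offloaded)} modules offloaded off-GPU")
for name, dev in list(offloaded.items())[:10]:
    print(f"  {name} -> {dev}")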

jeffra commented 4 months ago

Also, if you haven't already, can you try changing q_bits=6 in the quant config?
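
For reference, that is a one-line change to the example (a sketch of the suggested variation):

# 6-bit quantization instead of the 8-bit used in the tutorial example
from deepspeed.linear.config import QuantizationConfig

quant_config = QuantizationConfig(q_bits=6)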

JF-D commented 4 months ago

OK! Let me give it a try and get back to you.

JF-D commented 4 months ago

Unfortunately, setting q_bits=6 hits the same error:

[2024-04-29 10:42:13,954] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
Using /mnt/afs/jfduan/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /mnt/afs/jfduan/.cache/torch_extensions/py310_cu121/fp_quantizer/build.ninja...
Building extension module fp_quantizer...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fp_quantizer...
Time to load fp_quantizer op: 0.3470158576965332 seconds
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [03:56<00:00,  1.22s/it]
WARNING:root:Some parameters are on the meta device device because they were offloaded to the disk.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:31999 for open-end generation.
Traceback (most recent call last):
  File "/mnt/afs/jfduan/LLMInfer/snowflake-arctic/inference/hf_infer.py", line 28, in <module>
    outputs = model.generate(input_ids=input_ids, max_new_tokens=20)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 1572, in generate
    result = self._greedy_search(
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 2477, in _greedy_search
    outputs = self(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1708, in forward
    outputs = self.model(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1397, in forward
    layer_outputs = decoder_layer(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1087, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 808, in forward
    query_states = self.q_proj(hidden_states)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 347, in pre_forward
    set_module_tensor_to_device(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 358, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([7168, 7168]) in "weight" (which has shape torch.Size([100352, 516])), this look incorrect.
sfc-gh-reyazda commented 4 months ago

Hi @JF-D,

Can you please try this PR? Thanks, Reza

JF-D commented 4 months ago

Thanks! @sfc-gh-reyazda

I tried the PR you mentioned and hit the following error:

[2024-04-30 13:41:09,352] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.0), only 1.0.0 is known to be compatible
Using /mnt/afs/jfduan/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /mnt/afs/jfduan/.cache/torch_extensions/py310_cu121/fp_quantizer/build.ninja...
/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module fp_quantizer...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fp_quantizer...
Time to load fp_quantizer op: 0.3507523536682129 seconds
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [14:56<00:00,  4.60s/it]
WARNING:root:Some parameters are on the meta device device because they were offloaded to the disk.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:31999 for open-end generation.
Traceback (most recent call last):
  File "/mnt/afs/jfduan/LLMInfer/snowflake-arctic/inference/hf_infer.py", line 29, in <module>
    outputs = model.generate(input_ids=input_ids, max_new_tokens=20)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 1572, in generate
    result = self._greedy_search(
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/generation/utils.py", line 2477, in _greedy_search
    outputs = self(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1708, in forward
    outputs = self.model(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1397, in forward
    layer_outputs = decoder_layer(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 1087, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/LLMInfer/transformers-arctic/src/transformers/models/arctic/modeling_arctic.py", line 808, in forward
    query_states = self.q_proj(hidden_states)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/deepspeed/linear/quantization.py", line 137, in forward
    return F.linear(input, self.weight.dequantized(), self.bias)
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/deepspeed/linear/quantization.py", line 73, in dequantized
    return self.quantizer.dequantize(self.data,
  File "/mnt/afs/jfduan/env/miniconda3/envs/arctic/lib/python3.10/site-packages/deepspeed/ops/fp_quantizer/quantize.py", line 89, in dequantize
    assert (self.orig_dtype is not None), \
AssertionError: [De-quantization Error]: you need to call quantize before dequantizing!
sfc-gh-reyazda commented 4 months ago

This is very strange! It means that the quantizer being used to dequantize the weight does not have self.orig_dtype set properly, which in turn means the quantizer for that weight was never called (otherwise it would have been set here). So this suggests we are probably using different versions of transformers, as I am not able to reproduce the issue you are seeing. I tried this on an older commit of Snowflake-Labs/transformers: 6b1fe691bf8c34318f1beb5124db1162d93f047e. Which branch/commit are you using?
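
To illustrate the invariant being described, a toy sketch (hypothetical code, not DeepSpeed's actual implementation): dequantize() assumes quantize() already ran and recorded the original dtype:

# Toy illustration of the quantize/dequantize contract described above.
class ToyQuantizer:
    def __init__(self):
        self.orig_dtype = None  # only set when quantize() runs

    def quantize(self, tensor):
        self.orig_dtype = tensor.dtype  # recorded here
        return tensor  # real code would actually compress the tensor

    def dequantize(self, data):
        # fails exactly like the AssertionError above if quantize() never ran
        assert self.orig_dtype is not None, \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        return data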

JF-D commented 4 months ago

I checked the transformers version; the latest commit is the same as the one you tried (6b1fe691bf8c34318f1beb5124db1162d93f047e).

JF-D commented 4 months ago

I found the error. When trying to quantize the weights, DeepSpeed finds the tensor on the meta device instead of the GPU, so the tensor is never quantized (here).

But I think I should be able to run the Arctic model with FP8 quantization on 8x80GB A100s. It's quite strange. Maybe something is wrong with Hugging Face accelerate?
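
A minimal sketch to verify this from the loaded model (assuming the model object from the example above):

# Count parameters that were never materialized off the meta device;
# these are the weights the quantizer silently skipped.
meta_params = [name for name, p in model.named_parameters()
               if p.device.type == "meta"]
print(f"{len(meta_params)} parameters still on the meta device")
print(meta_params[:5])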

JF-D commented 4 months ago

I think I found the reason: transformers is not aware of the DeepSpeed quantization config, so it produces a wrong automatic placement with accelerate (here).

sfc-gh-reyazda commented 4 months ago

How about explicitly specifying it:


quant_config = QuantizationConfig(q_bits=8)

model = AutoModelForCausalLM.from_pretrained(
    "/checkpoint/2b-v30",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    ds_quantization_config=quant_config,
    max_memory={i: "150GiB" for i in range(8)},
    torch_dtype=torch.bfloat16)
JF-D commented 4 months ago

I have already set the config as follows:

quant_config = QuantizationConfig(q_bits=8)

model = AutoModelForCausalLM.from_pretrained(
    "Snowflake/snowflake-arctic-instruct",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    ds_quantization_config=quant_config,
    max_memory={i: "80GiB" for i in range(8)},
    torch_dtype=torch.bfloat16)

transformers cannot pick up the quantization config set via ds_quantization_config. Notably, I am using 80GB A100s, so I set max_memory to 80GiB per GPU, and this leads to a wrong mapping. I can run the example by setting max_memory to 160GiB to mimic the quantization effect.
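
For anyone hitting the same issue, the working call from this thread looks like the sketch below; the only change is over-stating max_memory at roughly 2x the real per-GPU capacity so that accelerate, which plans placement against the unquantized bf16 sizes, keeps everything on GPU:

import torch
from transformers import AutoModelForCausalLM
from deepspeed.linear.config import QuantizationConfig

quant_config = QuantizationConfig(q_bits=8)

model = AutoModelForCausalLM.from_pretrained(
    "Snowflake/snowflake-arctic-instruct",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    ds_quantization_config=quant_config,
    # ~2x the actual 80GiB per A100-80GB: accelerate sizes the bf16 weights,
    # not the quantized ones, so real usage stays within 80GiB per GPU
    max_memory={i: "160GiB" for i in range(8)},
    torch_dtype=torch.bfloat16)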

jeffra commented 4 months ago

Ohh yes, you have to set max_memory to ~2x the actual memory available so that accelerate does the right thing. To confirm, you are running successfully now after making this change, right?

We are actively working on adding DeepSpeed quantization support to HFQuantizer to replace the current approach. That should smooth out this path once it's live.

JF-D commented 4 months ago

Yes! I can run successfully after setting max_memory to ~2x the actual memory available. Thanks for the help!

jeffra commented 4 months ago

Excellent, glad to hear :) I'll close this for now then, please re-open if there are remaining issues though.