microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

Unable to load fp6 QuantizationConfig via DeepSpeedInference #5247

Closed Qubitium closed 5 months ago

Qubitium commented 5 months ago

Describe the bug

Unable to use/test fp6 quantization in DeepSpeed 0.14 in inference mode on a GPT2 model. There is little documentation on usage right now, so I am not sure whether I have the wrong init method. I am passing quant based on the changes in the committed fp6 PR by @loadams.

To Reproduce

Crash Error

[2024-03-09 02:14:12,754] [INFO] [logging.py:96:log_dist] [Rank -1] DeepSpeed info: version=0.14.0, git-hash=unknown, git-branch=unknown
> Model file corrupted, type: <class 'pydantic.v1.error_wrappers.ValidationError'>, message: 1 validation error for DeepSpeedInferenceConfig
quant -> quantization_mode
 extra fields not permitted (type=value_error.extra)

Code

 import torch
 import deepspeed

 # gpt2_model is an already loaded GPT2 model (loading code not shown)
 qconfig = deepspeed.inference.v2.config_v2.QuantizationConfig(quantization_mode="wf6af16")
 engine = deepspeed.init_inference(
     model=gpt2_model,
     dtype=torch.float16,
     replace_with_kernel_inject=True,
     max_out_tokens=2048,
     quant=qconfig,
 )

But the DeepSpeedInferenceConfig code does have a quant field:

class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
    """ Sets parameters for DeepSpeed Inference Engine. """

    replace_with_kernel_inject: bool = Field(False, alias="kernel_inject")
    """
    Set to true to inject inference kernels for models such as, Bert, GPT2,
    GPT-Neo and GPT-J.  Otherwise, the injection_dict provides the names of two
    linear layers as a tuple:
    `(attention_output projection, transformer output projection)`
    """

    dtype: DtypeEnum = torch.float16
    """
    Desired model data type, will convert model to this type.
    Supported target types: `torch.half`, `torch.int8`, `torch.float`
    """

    tensor_parallel: DeepSpeedTPConfig = Field({}, alias="tp")
    """
    Configuration for tensor parallelism used to split the model across several
    GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`.
    """

    enable_cuda_graph: bool = False
    """
    Use this flag for capturing the CUDA-Graph of the inference ops, so that it
    can run faster using the graph replay method.
    """

    use_triton: bool = False
    """
    Use this flag to use triton kernels for inference ops.
    """

    triton_autotune: bool = False
    """
    Use this flag to enable triton autotuning.
    Turning it on is better for performance but increase the 1st runtime for
    autotuning.
    """

    zero: DeepSpeedZeroConfig = {}
    """
    ZeRO configuration to use with the Inference Engine. Expects a dictionary
    containing values for :any:`DeepSpeedZeroConfig`.
    """

    triangular_masking: bool = Field(True, alias="tm")
    """
    Controls the type of masking for attention scores in transformer layer.
    Note that the masking is application specific.
    """

    moe: Union[bool, DeepSpeedMoEConfig] = {}
    """
    Specify if the type of Transformer is MoE. Expects a dictionary containing
    values for :any:`DeepSpeedMoEConfig`.
    """

    quant: QuantizationConfig = {}
    """
    NOTE: only works for int8 dtype.
    Quantization settings used for quantizing your model using the MoQ.  The
    setting can be one element or a tuple. If one value is passed in, we
    consider it as the number of groups used in quantization. A tuple is passed
    in if we want to mention that there is extra-grouping for the MLP part of a
    Transformer layer (e.g. (True, 8) shows we quantize the model using 8
    groups for all the network except the MLP part that we use 8 extra
    grouping). Expects a dictionary containing values for
    :any:`QuantizationConfig`.
    """
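
A possible reading of the error path `quant -> quantization_mode`: the `quant` field itself is accepted, but the config class that field is typed with rejects the nested key `quantization_mode`. The sketch below reproduces that mechanism with the standard library only. It is an analogy, not DeepSpeed or pydantic code: `build_strict`, `V1StyleQuantConfig`, and `InferenceConfig` are hypothetical names, and the assumption is that the v1 inference `QuantizationConfig` does not declare `quantization_mode` while the v2 class does.

```python
from dataclasses import dataclass, fields, is_dataclass


class ExtraFieldError(ValueError):
    """Raised when a config dict has a key the target class does not declare."""


def build_strict(cls, values, path=""):
    """Build dataclass `cls` from a dict, rejecting undeclared keys."""
    declared = {f.name: f for f in fields(cls)}
    kwargs = {}
    for key, value in values.items():
        loc = f"{path} -> {key}" if path else key
        if key not in declared:
            raise ExtraFieldError(f"{loc}: extra fields not permitted")
        nested = declared[key].type
        # Recurse when the declared field type is itself a config class.
        if isinstance(value, dict) and is_dataclass(nested):
            value = build_strict(nested, value, loc)
        kwargs[key] = value
    return cls(**kwargs)


@dataclass
class V1StyleQuantConfig:
    # Hypothetical stand-in for the config type of the `quant` field.
    enabled: bool = True


@dataclass
class InferenceConfig:
    # Hypothetical stand-in for the top-level inference config.
    dtype: str = "fp16"
    quant: V1StyleQuantConfig = None


try:
    build_strict(InferenceConfig, {"quant": {"quantization_mode": "wf6af16"}})
except ExtraFieldError as err:
    message = str(err)
    print(message)  # quant -> quantization_mode: extra fields not permitted
```

If this reading is right, the presence of a `quant` field on `DeepSpeedInferenceConfig` is not enough: the object passed in must match the config class that field expects.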

Expected behavior

The model loads successfully.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
evoformer_attn ......... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/root/miniconda3/lib/python3.11/site-packages/torch']
torch version .................... 2.2.1+cu121
deepspeed install path ........... ['/root/miniconda3/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.14.0, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.3
deepspeed wheel compiled w. ...... torch 2.2, cuda 12.1


xiaoxiawu-microsoft commented 5 months ago

Hmm, please start your test with our deepspeed-mii: https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-fp6/03-05-2024/README.md#4-how-to-begin-with-deepspeed-fp6--

loadams commented 5 months ago

@Qubitium - can you please share your pydantic version from pip list?
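
As an alternative to scanning `pip list`, the installed versions can be read from inside the same Python environment that produced the error. This is a minimal sketch using the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical, and the package list is only a guess at what is relevant here.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Print the versions most relevant to this report.
for name in ("pydantic", "deepspeed", "torch"):
    print(f"{name}=={installed_version(name)}")
```

Running it with the interpreter used for inference avoids reporting versions from a different environment.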