huggingface / optimum-nvidia


Incorrect tensorrt_llm config class initialization #90

Open · Wojx opened this issue 6 months ago

Wojx commented 6 months ago

tensorrt_llm version: 0.9.0.dev2024030500
optimum-nvidia version: 0.1.0b3

I am trying to run a Llama 2 model, but the model init function has a small bug. Loading the model throws an error:

model = AutoModelForCausalLM.from_pretrained("Voicelab/trurl-2-13b")

trt_config = LlamaConfig(
TypeError: PretrainedConfig.__init__() missing 1 required positional argument: 'quantization'
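
For completeness, here is a minimal reproduction (a sketch; it assumes the AutoModelForCausalLM export from optimum.nvidia, and any Llama-architecture checkpoint should hit the same code path, trurl-2-13b is just the one I used):

from optimum.nvidia import AutoModelForCausalLM

# Loading any Llama-architecture checkpoint goes through LlamaConfig.from_config
# and fails with:
#   TypeError: PretrainedConfig.__init__() missing 1 required positional argument: 'quantization'
model = AutoModelForCausalLM.from_pretrained("Voicelab/trurl-2-13b")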

I found that AutoModelForCausalLM ends up calling LlamaConfig.from_config:

class LlamaConfig(TensorRTConfig):
    r"""
    This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the LLaMA-7B.

    Configuration objects inherit from [`TensorRTConfig`] and can be used to control the model outputs. Read the
    documentation from [`TensorRTConfig`] for more information.
    """

    @staticmethod
    def from_config(config: TransformersPretrainedConfig) -> "TensorRTConfig":
        # Retrieve the quantization from the transformers config (if provided)
        qmode, qconfig = TensorRTConfig.get_quantization_config(config)

        trt_config = LlamaConfig(
            architecture=config.architectures[0],
            dtype=dtype_to_str(config.torch_dtype),
            logits_dtype="float32",
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            hidden_act=config.hidden_act,
            intermediate_size=config.intermediate_size,
            norm_epsilon=config.rms_norm_eps,
            position_embedding_type="rope_gpt_neox",
            world_size=1,
            tp_size=1,
            pp_size=1,
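            # NOTE: the two kwargs below appear to target an older tensorrt_llm API;
            # the current PretrainedConfig (quoted further down) instead requires a
            # single `quantization` argument, hence the TypeError above.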
            quant_mode=qmode,
            quant_kwargs=qconfig.to_dict(),
            use_prompt_tuning=False,
            use_parallel_embedding=False,
            embedding_sharding_dim=0,
            share_embedding_table=False,
            max_lora_rank=64,
            head_size=config.hidden_size / config.num_attention_heads,
        )

        trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8)

        return trt_config

The TensorRTConfig class:

class TensorRTConfig(ABC, TensorRTPretrainedConfig):
    @staticmethod
    def get_quantization_config(
        config: PretrainedConfig,
    ) -> (QuantMode, QuantizationConfig):
        if hasattr(config, "quantization_config"):
            qconfig = config.quantization_config
            num_bits = qconfig.num_bits
            group_size = qconfig.group_size
            mode, quant_method = convert_quant_method_to_trt(
                qconfig.quant_method, num_bits
            )
            has_zero_point = qconfig.get("zero_point", False)
            exclude_modules = qconfig.get("module_to_not_convert", [])

            return mode, QuantizationConfig(
                quantization_algo=quant_method,
                kv_cache_quant_algo=None,
                group_size=group_size,
                has_zero_point=has_zero_point,
                exclude_modules=exclude_modules,
            )
        else:
            return QuantMode.from_description(), QuantizationConfig(None, None, None)

TensorRTConfig is based on PretrainedConfig from tensorrt_llm:

from tensorrt_llm.models import PretrainedConfig as TensorRTPretrainedConfig

However, PretrainedConfig requires a quantization argument. Here is the relevant code from the tensorrt_llm repo:

class PretrainedConfig:

    def __init__(self,
                 architecture: str,
                 dtype: str,
                 logits_dtype: str,
                 vocab_size: int,
                 max_position_embeddings: int,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 num_key_value_heads: int,
                 hidden_act: str,
                 intermediate_size: int,
                 norm_epsilon: float,
                 position_embedding_type: str,
                 world_size: int,
                 tp_size: int,
                 pp_size: int,
                 quantization: Union[QuantizationConfig, dict],
                 use_prompt_tuning: bool = False,
                 use_parallel_embedding: bool = False,
                 embedding_sharding_dim: int = 0,
                 share_embedding_table: bool = False,
                 max_lora_rank: int = 64,
                 head_size: int = None,
                 **kwargs):
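
A possible fix (just an untested sketch on my side): have LlamaConfig.from_config build the QuantizationConfig as before, but pass it as the single quantization argument that PretrainedConfig now requires, instead of the quant_mode / quant_kwargs keywords:

    @staticmethod
    def from_config(config: TransformersPretrainedConfig) -> "TensorRTConfig":
        # Retrieve the quantization from the transformers config (if provided);
        # only the QuantizationConfig is needed, QuantMode is no longer passed explicitly.
        _, qconfig = TensorRTConfig.get_quantization_config(config)

        trt_config = LlamaConfig(
            architecture=config.architectures[0],
            dtype=dtype_to_str(config.torch_dtype),
            logits_dtype="float32",
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            hidden_act=config.hidden_act,
            intermediate_size=config.intermediate_size,
            norm_epsilon=config.rms_norm_eps,
            position_embedding_type="rope_gpt_neox",
            world_size=1,
            tp_size=1,
            pp_size=1,
            # PretrainedConfig accepts Union[QuantizationConfig, dict] here
            quantization=qconfig,
            use_prompt_tuning=False,
            use_parallel_embedding=False,
            embedding_sharding_dim=0,
            share_embedding_table=False,
            max_lora_rank=64,
            head_size=config.hidden_size / config.num_attention_heads,
        )

        trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8)

        return trt_config

If this is the right direction, the same change would presumably be needed anywhere else in optimum-nvidia that still passes quant_mode / quant_kwargs.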