erfanzar / EasyDeL

Accelerate and optimize performance with streamlined training and serving options in JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

[Feature Request] Add support for tiiuae/falcon-11B #152

Closed s-smits closed 6 months ago

s-smits commented 6 months ago

The model to consider.

https://huggingface.co/tiiuae/falcon-11B

The closest model EasyDeL already supports.

tiiuae/falcon-7b, tiiuae/falcon-40b

What's your difficulty in supporting the model you want?

🚀 The feature, motivation and pitch

Falcon-11B is trained on multilingual data, so there is a lot of potential to serve this model where those languages are preferred. Functional fp16 training would be a great addition, in my opinion.

Additional context

The main architectural changes between the two Falcon configurations are as follows (a quick programmatic check follows the list):

  1. New Decoder Architecture:

    • Falcon-7B has new_decoder_architecture: false, which means it uses the original or a previous version of the decoder architecture.
    • Falcon-11B specifies new_decoder_architecture: true, indicating a newer version of the decoder architecture.
  2. Number of Attention Heads:

    • Falcon-7B uses num_attention_heads: 71.
    • Falcon-11B significantly decreases this number to num_attention_heads: 32.
  3. Number of Hidden Layers:

    • Falcon-11B has num_hidden_layers: 60, almost double Falcon-7B's num_hidden_layers: 32.
  4. Feedforward Network Size:

    • Falcon-11B includes details about the feedforward network with ffn_hidden_size: 16384 and ff_factor: 4, which are absent in Falcon-7B.
  5. Tied Word Embeddings:

    • Falcon-7B does not mention tie_word_embeddings, which might imply the default setting is used (could be either true or false depending on the model's standard configuration).
    • Falcon-11B explicitly states tie_word_embeddings: false.
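
A minimal sketch of that check, assuming the transformers library (with native Falcon support) is installed:

    from transformers import AutoConfig

    # Pull both configs from the Hub and compare the fields discussed above.
    cfg_7b = AutoConfig.from_pretrained("tiiuae/falcon-7b")
    cfg_11b = AutoConfig.from_pretrained("tiiuae/falcon-11B")

    for key in ("new_decoder_architecture", "num_attention_heads",
                "num_hidden_layers", "ffn_hidden_size", "tie_word_embeddings"):
        print(key, getattr(cfg_7b, key, "<absent>"), getattr(cfg_11b, key, "<absent>"))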

The tokenizer is unchanged. However, the architecture has changed from:

    "model_type": "falcon",
    "architectures": [
        "FalconForCausalLM"
    ],
    "pre_weights": [
        {
            "name": "transformer.word_embeddings.weight",
            "is_embed": true
        }
    ],
    "post_weights": [
        {
            "name": "transformer.ln_f.weight"
        },
        {
            "name": "transformer.ln_f.bias"
        },
        {
            "name": "lm_head.weight",
            "is_embed": true
        }
    ],
    "num_layers_config_key": "num_hidden_layers",
    "layer_templates": {
        "weights": [
            {
                "name": "transformer.h.${layer_index}.ln_attn.bias"
            },
            {
                "name": "transformer.h.${layer_index}.ln_attn.weight"
            },
            {
                "name": "transformer.h.${layer_index}.ln_mlp.bias"
            },
            {
                "name": "transformer.h.${layer_index}.ln_mlp.weight"
            },
            {
                "name": "transformer.h.${layer_index}.mlp.dense_4h_to_h.weight"
            },
            {
                "name": "transformer.h.${layer_index}.mlp.dense_h_to_4h.weight"
            },
            {
                "name": "transformer.h.${layer_index}.self_attention.dense.weight"
            },
            {
                "name": "transformer.h.${layer_index}.self_attention.query_key_value.weight"
            }
        ]
    }
}

to

    "model_type": "falcon",
    "architectures": [
        "FalconForCausalLM"
    ],
    "pre_weights": [
        {
            "name": "transformer.word_embeddings.weight",
            "is_embed": true
        }
    ],
    "post_weights": [
        {
            "name": "transformer.ln_f.weight"
        },
        {
            "name": "transformer.ln_f.bias"
        },
        {
            "name": "lm_head.weight",
            "is_embed": true
        }
    ],
    "num_layers_config_key": "num_hidden_layers",
    "layer_templates": {
        "weights": [
            {
                "name": "transformer.h.${layer_index}.input_layernorm.bias"
            },
            {
                "name": "transformer.h.${layer_index}.input_layernorm.weight"
            },
            {
                "name": "transformer.h.${layer_index}.mlp.dense_4h_to_h.weight"
            },
            {
                "name": "transformer.h.${layer_index}.mlp.dense_h_to_4h.weight"
            },
            {
                "name": "transformer.h.${layer_index}.self_attention.dense.weight"
            },
            {
                "name": "transformer.h.${layer_index}.self_attention.query_key_value.weight"
            }
        ]
    }
}
which means the per-layer weight names have changed: the separate ln_attn and ln_mlp norms are gone, replaced by a single input_layernorm.
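
A minimal sketch that diffs the two per-layer weight templates quoted above (layer prefix stripped):

    # Per-layer weight names taken verbatim from the two templates.
    old_layer = {
        "ln_attn.bias", "ln_attn.weight",
        "ln_mlp.bias", "ln_mlp.weight",
        "mlp.dense_4h_to_h.weight", "mlp.dense_h_to_4h.weight",
        "self_attention.dense.weight", "self_attention.query_key_value.weight",
    }
    new_layer = {
        "input_layernorm.bias", "input_layernorm.weight",
        "mlp.dense_4h_to_h.weight", "mlp.dense_h_to_4h.weight",
        "self_attention.dense.weight", "self_attention.query_key_value.weight",
    }
    print("removed:", sorted(old_layer - new_layer))  # ln_attn.*, ln_mlp.*
    print("added:  ", sorted(new_layer - old_layer))  # input_layernorm.*

Loading the new checkpoint with code that only knows the old names fails accordingly; the traceback below comes from ScandEval's vLLM backend: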
model-00001-of-00005.safetensors: 100%|████| 4.98G/4.98G [18:21<00:00, 4.52MB/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/local/bin/scandeval", line 8, in <module>
[rank0]:     sys.exit(benchmark())
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1157, in __call__
[rank0]:     return self.main(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1078, in main
[rank0]:     rv = self.invoke(ctx)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1434, in invoke
[rank0]:     return ctx.invoke(self.callback, **ctx.params)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 783, in invoke
[rank0]:     return __callback(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/cli.py", line 332, in benchmark
[rank0]:     benchmarker(model=models)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/benchmarker.py", line 770, in __call__
[rank0]:     return self.benchmark(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/benchmarker.py", line 593, in benchmark
[rank0]:     benchmark_output = self._benchmark_single(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/benchmarker.py", line 720, in _benchmark_single
[rank0]:     results, metadata_dict, model, tokenizer = dataset(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/benchmark_dataset.py", line 601, in __call__
[rank0]:     return self.benchmark(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/benchmark_dataset.py", line 146, in benchmark
[rank0]:     model, tokenizer = load_model(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/model_loading.py", line 52, in load_model
[rank0]:     model, tokenizer = setup.load_model(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/model_setups/hf.py", line 311, in load_model
[rank0]:     model = VLLMModel(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/vllm_models.py", line 132, in __init__
[rank0]:     self._model = self._initialise(vllm_kwargs=vllm_kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/scandeval/vllm_models.py", line 145, in _initialise
[rank0]:     model = LLM(**vllm_kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 123, in __init__
[rank0]:     self.llm_engine = LLMEngine.from_engine_args(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 292, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 160, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/executor_base.py", line 41, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 23, in _init_executor
[rank0]:     self._init_non_spec_worker()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 69, in _init_non_spec_worker
[rank0]:     self.driver_worker.load_model()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 118, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 164, in load_model
[rank0]:     self.model = get_model(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 224, in load_model
[rank0]:     model.load_weights(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/falcon.py", line 418, in load_weights
[rank0]:     param = params_dict[name]
[rank0]: KeyError: 'transformer.h.12.input_layernorm.weight'
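
Schematically (this is not vLLM's actual code), the loader builds its parameter dict from the old layer names and then indexes it with a new-style checkpoint key:

    # Hypothetical reproduction of the failure mode above.
    params_dict = {"transformer.h.12.ln_attn.weight": object()}  # old-style keys only
    name = "transformer.h.12.input_layernorm.weight"             # new-style checkpoint key
    param = params_dict[name]                                    # KeyError, as in the traceback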

Lastly, I have tried to implement it myself, but my JAX knowledge is limited. The config should be close to this, if I'm not mistaken:

src/python/easydel/modules/falcon2/modelling_falcon2_flax.py

from typing import Sequence, Optional

from jax.sharding import PartitionSpec

from ..easydel_modelling_utils import EasyDeLPretrainedConfig

class Falcon2Config(EasyDeLPretrainedConfig):
    model_type: str = "falcon"
    attribute_map = {
        "num_hidden_layers": "num_hidden_layers",
        "num_attention_heads": "num_attention_heads",
        "lm_head.weight": {"name": "lm_head.weight", "is_embed": True},
    }

    def __init__(
            self,
            vocab_size: int = 65024,
            hidden_size: int = 4096,
            num_hidden_layers: int = 60,  # Falcon-11B value (see the config diff above)
            num_attention_heads: int = 32,
            layer_norm_epsilon: float = 1e-5,
            initializer_range: float = 0.02,
            use_cache: bool = True,
            hidden_dropout: float = 0.0,
            attention_dropout: float = 0.0,
            num_kv_heads=8,
            alibi: bool = False,
            new_decoder_architecture: bool = True,
            multi_query: bool = False,  # may actually need to be True
            parallel_attn: bool = True,
            bias: bool = False,
            max_position_embeddings: int = 8192,  # updated default value
            rope_theta: float = 500042.0,
            rope_scaling=None,
            bos_token_id: int = 11,
            eos_token_id: int = 11,
            gradient_checkpointing: str = "",
            bits: Optional[int] = None,
            axis_dims: Sequence[int] = (1, -1, 1, 1),
            axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"),
            **kwargs,
    ):
        self.vocab_size = vocab_size
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embeddings = max_position_embeddings
        self.use_cache = use_cache
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.alibi = alibi
        self.bias = bias
        self.gradient_checkpointing = gradient_checkpointing
        self.parallel_attn = parallel_attn
        self.num_kv_heads = num_kv_heads
        self.new_decoder_architecture = new_decoder_architecture
        self.bits = bits
        self.from_pt = False

        self._rope_scaling_validation()  # Validate rope_scaling

        super().__init__(
            axis_dims=axis_dims,
            axis_names=axis_names,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            bits=bits,
            **kwargs
        )

    @property
    def head_dim(self):
        return self.hidden_size // self.num_attention_heads

    @property
    def rotary(self):
        return not self.alibi

    @staticmethod
    def get_partition_rules(fully_sharded_data_parallel: bool = False):
        return (
            ('word_embeddings/embedding', PartitionSpec("dp", ("fsdp", "sp"))),
            ('self_attention/query_key_value/(kernel)', PartitionSpec("dp", ("fsdp", "sp"))),
            ('self_attention/dense/(kernel)', PartitionSpec("dp", ("fsdp", "sp"))),
            ('mlp/dense_4h_to_h/(kernel)', PartitionSpec("dp", ("fsdp", "sp"))),
            ('mlp/dense_h_to_4h/(kernel)', PartitionSpec("dp", ("fsdp", "sp"))),
            ('transformer/input_layernorm/scale', PartitionSpec(("fsdp", "sp"))),
            ('transformer/input_layernorm/bias', PartitionSpec(("fsdp", "sp"))),
            ('.*', PartitionSpec(("fsdp", "sp")))
        ) if not fully_sharded_data_parallel else (
            ('word_embeddings/embedding', PartitionSpec(("fsdp", "sp"))),
            ('self_attention/query_key_value/(kernel|bias)', PartitionSpec(("fsdp", "sp"))),
            ('self_attention/dense/(kernel|bias)', PartitionSpec(("fsdp", "sp"))),
            ('mlp/dense_4h_to_h/(kernel|bias)', PartitionSpec(("fsdp", "sp"))),
            ('mlp/dense_h_to_4h/(kernel|bias)', PartitionSpec(("fsdp", "sp"))),
            ('transformer/input_layernorm/scale', PartitionSpec(("fsdp", "sp"))),
            ('transformer/input_layernorm/bias', PartitionSpec(("fsdp", "sp"))),
            ('.*', PartitionSpec(("fsdp", "sp")))
        )

    @staticmethod
    def get_mesh_names():
        return "dp", "fsdp", "tp", "sp"

    def add_jax_args(
            self,
            # Defaults kept in sync with __init__ (Falcon-11B values).
            vocab_size: int = 65024,
            hidden_size: int = 4096,
            num_hidden_layers: int = 60,
            num_attention_heads: int = 32,
            layer_norm_epsilon: float = 1e-5,
            initializer_range: float = 0.02,
            use_cache: bool = True,
            hidden_dropout: float = 0.0,
            attention_dropout: float = 0.0,
            num_kv_heads=8,
            alibi: bool = False,
            new_decoder_architecture: bool = True,
            multi_query: bool = False,
            parallel_attn: bool = True,
            bias: bool = False,
            max_position_embeddings: int = 8192,  # updated default value
            rope_theta: float = 500042.0,
            rope_scaling=None,
            bos_token_id: int = 11,
            eos_token_id: int = 11,
            gradient_checkpointing: str = "",
            bits: Optional[int] = None,
            **kwargs,
    ):

        basics = dict(
            bits=bits,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            layer_norm_epsilon=layer_norm_epsilon,
            rope_theta=rope_theta,
            initializer_range=initializer_range,
            use_cache=use_cache,
            bos_token_id=bos_token_id,
            num_kv_heads=num_kv_heads,
            eos_token_id=eos_token_id,
            max_position_embeddings=max_position_embeddings,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            multi_query=multi_query,
            alibi=alibi,
            bias=bias,
            parallel_attn=parallel_attn,
            rope_scaling=rope_scaling,
            gradient_checkpointing=gradient_checkpointing,
            new_decoder_architecture=new_decoder_architecture,
            **kwargs
        )
        # Only set attributes that were not already populated by __init__.
        for key_states, value_states in basics.items():
            if not hasattr(self, key_states):
                setattr(self, key_states, value_states)

        self.from_pt = False

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if self.alibi:
            raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
erfanzar commented 6 months ago

Hello and thanks for using EasyDeL

Sure, I'll support the Falcon model within the next 24 hours.

s-smits commented 6 months ago

Great, thank you. My suggestion would be to start with a new model type, 'falcon2', since there are quite a few architectural changes. There are also some rope-scaling differences, but I haven't had time to look into them deeply.

erfanzar commented 6 months ago

Yes, I can do that; tomorrow falcon2 and aya will be available.

erfanzar commented 6 months ago

Falcon 11B is now supported with the new architecture:

pip install git+https://github.com/erfanzar/EasyDeL.git -U

(Flash attention is supported too, in case you're not using ALiBi.)
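
After upgrading, loading should look roughly like this (a sketch assuming the AutoEasyDeLModelForCausalLM entry point; the exact signature may differ between EasyDeL versions):

    from easydel import AutoEasyDeLModelForCausalLM

    # Fetch the PyTorch checkpoint and convert it to Flax parameters.
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained("tiiuae/falcon-11B")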