s-smits closed this issue 6 months ago
Hello and thanks for using EasyDeL
Sure, I'll support the Falcon model in the next 24 hours.
Great, thank you. My suggestion would be to start with a new model type 'falcon2', because there are quite a few architectural changes. There are also some RoPE scaling differences, but I did not find the time to look into them deeply.
Yes, I can do that. Tomorrow falcon2 and Aya will be available.
Falcon 11B is now supported with the new architecture.
pip install git+https://github.com/erfanzar/EasyDeL.git -U
(Flash attention is supported too, in case you're not using ALiBi.)
The model to consider.
https://huggingface.co/tiiuae/falcon-11B
The closest model EasyDeL already supports.
tiiuae/falcon-7b, tiiuae/falcon-40b
What's your difficulty of supporting the model you want?
🚀 The feature, motivation and pitch
Falcon-11B is trained on multilingual data, so there is a lot of potential to serve this model where those languages are preferred. Functional, working training in fp16 would be a great addition, in my opinion.
Additional context
The main architectural changes between the two configurations of the Falcon model are:

- **New Decoder Architecture:** Falcon-7B has `new_decoder_architecture: false`, which means it uses the original or a previous version of the decoder architecture; Falcon-11B has `new_decoder_architecture: true`, indicating a newer version of the decoder architecture.
- **Number of Attention Heads:** Falcon-7B has `num_attention_heads: 71`, while Falcon-11B has `num_attention_heads: 32`.
- **Number of Hidden Layers:** Falcon-11B has `num_hidden_layers: 60`, almost double the number in Falcon-7B, which has `num_hidden_layers: 32`.
- **Feedforward Network Size:** Falcon-11B sets `ffn_hidden_size: 16384` and `ff_factor: 4`, which are absent in Falcon-7B.
- **Tied Word Embeddings:** Falcon-11B has `tie_word_embeddings: false`.

The tokenizer has remained consistent. However, the architecture has been changed from:
to
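As a sanity check when wiring up the new config class, the differences listed above can be diffed programmatically. A minimal sketch using plain dicts holding only the values quoted in this issue (in practice these would come from each checkpoint's `config.json`; the `config_diff` helper is made up for illustration):

```python
# Values quoted above for the two checkpoints (not full config.json contents).
falcon_7b = {
    "new_decoder_architecture": False,
    "num_attention_heads": 71,
    "num_hidden_layers": 32,
}
falcon_11b = {
    "new_decoder_architecture": True,
    "num_attention_heads": 32,
    "num_hidden_layers": 60,
    "ffn_hidden_size": 16384,
    "ff_factor": 4,
    "tie_word_embeddings": False,
}


def config_diff(a, b):
    """Map each differing key to (value_in_a, value_in_b); missing keys show as None."""
    return {
        k: (a.get(k), b.get(k))
        for k in sorted(set(a) | set(b))
        if a.get(k) != b.get(k)
    }


for key, (old, new) in config_diff(falcon_7b, falcon_11b).items():
    print(f"{key}: {old} -> {new}")
```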
Lastly, I have tried to implement it myself, but my JAX knowledge is limited. The config should be close to this, if I'm not mistaken:
src/python/easydel/modules/falcon2/modelling_falcon2_flax.py
```python
from typing import Sequence, Optional

from jax.sharding import PartitionSpec

from ..easydel_modelling_utils import EasyDeLPretrainedConfig


class Falcon2Config(EasyDeLPretrainedConfig):
    model_type: str = "falcon"
    attribute_map = {
        "num_hidden_layers": "num_hidden_layers",
        "num_attention_heads": "num_attention_heads",
        "lm_head.weight": {"name": "lm_head.weight", "is_embed": True},
    }
```
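For the layers themselves, the practical effect of `new_decoder_architecture: true` is that attention and MLP run in parallel, each reading its own layernorm of the layer input, rather than both branches sharing one input layernorm as in the older parallel-attention layout. A minimal NumPy sketch of just the residual wiring (the sublayers are stand-in lambdas, not EasyDeL code, and the norms here are parameter-free for brevity; in the real model `ln_attn` and `ln_mlp` carry separate learned weights):

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm stand-in.
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def old_falcon_block(x, attn, mlp):
    # Older Falcon layout (parallel_attn): a single input layernorm
    # feeds both the attention branch and the MLP branch.
    h = layer_norm(x)
    return x + attn(h) + mlp(h)


def new_falcon_block(x, attn, mlp):
    # new_decoder_architecture: true -- attention and MLP each read
    # their own layernorm of the layer input, summed into one residual.
    return x + attn(layer_norm(x)) + mlp(layer_norm(x))


# Tiny stand-in sublayers, just to exercise the wiring.
attn = lambda h: 0.5 * h
mlp = lambda h: 0.25 * h
x = np.random.default_rng(0).normal(size=(2, 4))
out = new_falcon_block(x, attn, mlp)
```

With parameter-free norms the two layouts coincide numerically; the distinction only matters once `ln_attn` and `ln_mlp` have their own learned scale and bias, which is why the port needs two norm parameter sets per layer.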