Open · Additrien opened this issue 7 months ago
I'm taking the liberty of pinging you, @sanchit-gandhi.
@ArthurZucker?
We don't have any mixture-of-experts models in JAX, and I am not entirely sure optimised code exists for that yet! WDYT @sanchit-gandhi?
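For context, the routing that would need porting might look roughly like this in Flax linen. This is a minimal sketch with a dense dispatch fallback (every expert runs on every token, then results are masked); the module name, parameter names, and plain `Dense` experts are illustrative stand-ins, not Mixtral's actual implementation:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class SparseMoeBlock(nn.Module):
    """Illustrative top-k expert routing; real Mixtral experts are SwiGLU MLPs."""
    num_experts: int = 8
    top_k: int = 2
    hidden_dim: int = 32

    @nn.compact
    def __call__(self, hidden_states):  # (tokens, hidden_dim)
        # The router produces one logit per expert for every token.
        router_logits = nn.Dense(self.num_experts, use_bias=False)(hidden_states)
        routing_weights = jax.nn.softmax(router_logits, axis=-1)
        # Keep only the top-k experts per token and renormalize their weights.
        topk_weights, topk_idx = jax.lax.top_k(routing_weights, self.top_k)
        topk_weights = topk_weights / jnp.sum(topk_weights, axis=-1, keepdims=True)
        # Dense fallback: run every expert and combine via a routing mask.
        # (Optimised kernels dispatch tokens sparsely; this is fine for
        # correctness checks on tiny models.)
        expert_outputs = jnp.stack(
            [nn.Dense(self.hidden_dim)(hidden_states) for _ in range(self.num_experts)],
            axis=1,
        )  # (tokens, num_experts, hidden_dim)
        mask = jax.nn.one_hot(topk_idx, self.num_experts)      # (tokens, top_k, num_experts)
        combine = jnp.einsum("tk,tke->te", topk_weights, mask)  # (tokens, num_experts)
        return jnp.einsum("te,teh->th", combine, expert_outputs), router_logits


x = jnp.ones((5, 32))
moe = SparseMoeBlock()
params = moe.init(jax.random.PRNGKey(0), x)
out, logits = moe.apply(params, x)  # out: (5, 32), logits: (5, 8)
```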
Thanks @ArthurZucker for your feedback.
I'm not sure if optimized code exists yet either, but here are the results I have obtained so far:
```python
from transformers import AutoTokenizer, MixtralModel, FlaxMixtralModel, MixtralConfig

model_path = "hf-internal-testing/Mixtral-tiny"

# Enable all optional outputs so the two implementations can be compared.
config = MixtralConfig.from_pretrained(model_path)
config.output_router_logits = True
config.output_attentions = True
config.output_hidden_states = True
config.max_position_embeddings = 32768

# Load the same checkpoint in both frameworks. Note: from_pt=True may be
# needed on the Flax side if the checkpoint only ships PyTorch weights.
fx_model = FlaxMixtralModel.from_pretrained(model_path, config=config)
pt_model = MixtralModel.from_pretrained(model_path, config=config)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
text = "Hello my name is"
fx_inputs = tokenizer(text, return_tensors="jax")
pt_inputs = tokenizer(text, return_tensors="pt")

fx_outputs = fx_model(**fx_inputs)
pt_outputs = pt_model(**pt_inputs)

print("Flax Last Hidden State: ", fx_outputs.last_hidden_state)
print("Pytorch Last Hidden State: ", pt_outputs.last_hidden_state)
print("Flax attentions: ", fx_outputs.attentions)
print("Pytorch attentions: ", pt_outputs.attentions)
print("Flax output_router_logits: ", fx_outputs.router_logits)
print("Pytorch output_router_logits: ", pt_outputs.router_logits)
```
```
Flax Last Hidden State: [[[-1.491014 1.2440004 2.347837 ... 1.465887 -1.9995561
2.3321059 ]
[-1.0622324 0.62142754 1.8054494 ... -0.5777101 -0.43433675
1.6081295 ]
[-0.7068592 -0.43685007 -2.091109 ... -1.5816478 0.7205353
2.4112375 ]
[-0.3168698 1.1649094 1.0265665 ... 0.8728067 -0.7838142
-1.1911072 ]
[ 0.83987314 -1.8839872 -0.8021642 ... 0.36293986 -0.21989252
1.9756591 ]]]
Pytorch Last Hidden State: tensor([[[-1.4910, 1.2440, 2.3478, ..., 1.4659, -1.9996, 2.3321],
[-1.0622, 0.6214, 1.8054, ..., -0.5777, -0.4343, 1.6081],
[-0.7069, -0.4369, -2.0911, ..., -1.5816, 0.7205, 2.4112],
[-0.3169, 1.1649, 1.0266, ..., 0.8728, -0.7838, -1.1911],
[ 0.8399, -1.8840, -0.8022, ..., 0.3629, -0.2199, 1.9757]]],
grad_fn=<MulBackward0>)
Flax attentions: (Array([[[[1. , 0. , 0. , 0. , 0. ],
[0.51858807, 0.48141193, 0. , 0. , 0. ],
[0.30058202, 0.42184675, 0.2775713 , 0. , 0. ],
[0.35005265, 0.18267702, 0.2365187 , 0.23075159, 0. ],
[0.24478972, 0.13809378, 0.17894691, 0.18818794, 0.24998164]],
[[1. , 0. , 0. , 0. , 0. ],
[0.31261998, 0.68738 , 0. , 0. , 0. ],
[0.43397075, 0.3420607 , 0.22396858, 0. , 0. ],
[0.27358305, 0.23000337, 0.14827491, 0.34813863, 0. ],
[0.16551231, 0.14784649, 0.13323058, 0.25625977, 0.2971509 ]],
...
[[1. , 0. , 0. , 0. , 0. ],
[0.42223635, 0.5777636 , 0. , 0. , 0. ],
[0.30812478, 0.38982448, 0.3020508 , 0. , 0. ],
[0.1543934 , 0.36634043, 0.24219993, 0.23706625, 0. ],
[0.23449035, 0.17589284, 0.13383494, 0.1932569 , 0.262525 ]],
[[1. , 0. , 0. , 0. , 0. ],
[0.4659806 , 0.5340194 , 0. , 0. , 0. ],
[0.27027905, 0.37045106, 0.35926986, 0. , 0. ],
[0.24646765, 0.23269121, 0.21503593, 0.3058052 , 0. ],
[0.1953705 , 0.16371022, 0.19928072, 0.1873459 , 0.25429267]]]], dtype=float32))
Pytorch attentions: (tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5186, 0.4814, 0.0000, 0.0000, 0.0000],
[0.3006, 0.4218, 0.2776, 0.0000, 0.0000],
[0.3501, 0.1827, 0.2365, 0.2308, 0.0000],
[0.2448, 0.1381, 0.1789, 0.1882, 0.2500]],
[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3126, 0.6874, 0.0000, 0.0000, 0.0000],
[0.4340, 0.3421, 0.2240, 0.0000, 0.0000],
[0.2736, 0.2300, 0.1483, 0.3481, 0.0000],
[0.1655, 0.1478, 0.1332, 0.2563, 0.2972]],
...
[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4222, 0.5778, 0.0000, 0.0000, 0.0000],
[0.3081, 0.3898, 0.3021, 0.0000, 0.0000],
[0.1544, 0.3663, 0.2422, 0.2371, 0.0000],
[0.2345, 0.1759, 0.1338, 0.1933, 0.2625]],
[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4660, 0.5340, 0.0000, 0.0000, 0.0000],
[0.2703, 0.3705, 0.3593, 0.0000, 0.0000],
[0.2465, 0.2327, 0.2150, 0.3058, 0.0000],
[0.1954, 0.1637, 0.1993, 0.1873, 0.2543]]]],
grad_fn=<SoftmaxBackward0>))
Flax output_router_logits: (Array([[ 0.05653389, 0.30332795, 0.07461646, 0.5048188 , -0.12729573,
0.51455855, 0.37247908, 0.73305213],
[-1.0776594 , -0.07945219, -0.15837692, 0.01201573, -0.30832142,
-0.640889 , -0.5488462 , -0.00608852],
...
[-0.6823018 , 0.2800284 , 0.48748446, 0.46143585, -0.2073719 ,
0.19337937, -2.0465307 , 0.3276461 ],
[ 0.27823162, -0.10238633, -0.5455977 , -0.18515491, -0.69956577,
0.853526 , 0.85062325, -1.0652022 ]], dtype=float32))
Pytorch output_router_logits: (tensor([[ 0.0565, 0.3033, 0.0746, 0.5048, -0.1273, 0.5146, 0.3725, 0.7331],
[-1.0777, -0.0795, -0.1584, 0.0120, -0.3083, -0.6409, -0.5488, -0.0061],
...
[-0.6823, 0.2800, 0.4875, 0.4614, -0.2074, 0.1934, -2.0465, 0.3276],
[ 0.2782, -0.1024, -0.5456, -0.1852, -0.6996, 0.8535, 0.8506, -1.0652]],
grad_fn=<MmBackward0>))
```
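To go beyond eyeballing the printed tensors, the outputs can also be compared numerically. This is a small sketch reusing the variables from the snippet above; the `1e-3` tolerance is my own assumption, not an official equivalence threshold:

```python
import numpy as np

# Compare the Flax and PyTorch hidden states element-wise.
fx_hidden = np.asarray(fx_outputs.last_hidden_state)
pt_hidden = pt_outputs.last_hidden_state.detach().numpy()

max_diff = np.max(np.abs(fx_hidden - pt_hidden))
print(f"max abs diff: {max_diff:.2e}")
assert np.allclose(fx_hidden, pt_hidden, atol=1e-3), "outputs diverge"
# The same check can be repeated for attentions and router_logits.
```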
Flax is most interesting for speed. As long as we keep the RAM usage bounded by a 7B model and get the same speedups, this should be good; that is what I mean by results 😉 Feel free to open a PR and work on this, we just won't be able to provide a lot of help for now!
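For a rough sense of the speed side, the Flax forward pass can be jitted and timed. This is a minimal sketch reusing `fx_model` and `fx_inputs` from above, not a rigorous benchmark; repeated calls with the same input shape reuse the compiled function:

```python
import time
import jax

# Jit the forward pass; the model's params are closed over as constants.
jitted = jax.jit(lambda input_ids: fx_model(input_ids=input_ids).last_hidden_state)

# First call compiles; exclude it from timing.
_ = jitted(fx_inputs["input_ids"]).block_until_ready()

start = time.perf_counter()
_ = jitted(fx_inputs["input_ids"]).block_until_ready()
print(f"jitted forward: {time.perf_counter() - start:.4f}s")
```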
Feature request
I would like to implement the Mixtral model in Flax.
Motivation
I am in the process of learning Flax, and I have almost finished converting the model to Flax.
Your contribution
I could submit a PR with the model implementation.