huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add Mixtral Model to Flax #29319

Open · Additrien opened this issue 7 months ago

Additrien commented 7 months ago

Feature request

I would like to implement the Mixtral model in Flax.

Motivation

I am in the process of learning Flax, and I have almost finished converting the model to Flax.

Your contribution

I could submit a PR with the model implementation.

Additrien commented 7 months ago

I'm taking the liberty of pinging you, @sanchit-gandhi.

Additrien commented 7 months ago

@ArthurZucker ?

ArthurZucker commented 7 months ago

We don't have any mixture-of-experts models in JAX, and I am not entirely sure optimised code exists for that yet! WDYT @sanchit-gandhi?
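
For reference, a minimal, deliberately unoptimised sketch of what a sparse MoE block could look like in Flax linen is shown below. The class name, layer names, and sizes are made up for illustration (this is not the actual Mixtral port), and every expert densely processes every token, which is exactly the part that would need optimised routing/dispatch to be efficient:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class SparseMoeSketch(nn.Module):
        intermediate_size: int = 256
        num_experts: int = 8
        num_experts_per_tok: int = 2

        @nn.compact
        def __call__(self, hidden_states):
            batch, seq_len, hidden = hidden_states.shape
            x = hidden_states.reshape(-1, hidden)  # (tokens, hidden)

            # Router: one logit per expert for every token, softmaxed, then top-k.
            router_logits = nn.Dense(self.num_experts, use_bias=False, name="gate")(x)
            routing_weights = jax.nn.softmax(router_logits, axis=-1)
            topk_weights, topk_idx = jax.lax.top_k(routing_weights, self.num_experts_per_tok)
            topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)

            # Naive dispatch: every expert runs on every token and the outputs are
            # masked by the routing weights. Correct, but not the optimised kernel
            # a production MoE would want.
            out = jnp.zeros_like(x)
            for e in range(self.num_experts):
                gate = nn.Dense(self.intermediate_size, use_bias=False, name=f"expert_{e}_w1")(x)
                up = nn.Dense(self.intermediate_size, use_bias=False, name=f"expert_{e}_w3")(x)
                down = nn.Dense(hidden, use_bias=False, name=f"expert_{e}_w2")(jax.nn.silu(gate) * up)
                # Routing weight of expert e per token (0 if e is not in that token's top-k).
                w = jnp.where(topk_idx == e, topk_weights, 0.0).sum(axis=-1, keepdims=True)
                out = out + w * down
            return out.reshape(batch, seq_len, hidden), router_logits

    # Example: initialise and run on a dummy batch.
    # params = SparseMoeSketch().init(jax.random.PRNGKey(0), jnp.ones((1, 5, 64)))
    # y, logits = SparseMoeSketch().apply(params, jnp.ones((1, 5, 64)))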

Additrien commented 7 months ago

Thanks @ArthurZucker for your feedback.

I'm not sure if optimized code exists yet, but here are the results I have obtained so far:

    from transformers import AutoTokenizer, MixtralModel, FlaxMixtralModel, MixtralConfig

    model_path = "hf-internal-testing/Mixtral-tiny"

    # Enable every auxiliary output so the two ports can be compared in full.
    config = MixtralConfig.from_pretrained(model_path)
    config.output_router_logits = True
    config.output_attentions = True
    config.output_hidden_states = True
    config.max_position_embeddings = 32768

    # Load the same tiny checkpoint into the Flax port and the PyTorch reference.
    fx_model = FlaxMixtralModel.from_pretrained(model_path, config=config)
    pt_model = MixtralModel.from_pretrained(model_path, config=config)

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    text = "Hello my name is"
    fx_inputs = tokenizer(text, return_tensors="jax")
    pt_inputs = tokenizer(text, return_tensors="pt")

    # Run both models on identical inputs and print the outputs side by side.
    fx_outputs = fx_model(**fx_inputs)
    pt_outputs = pt_model(**pt_inputs)

    print("Flax Last Hidden State: ", fx_outputs.last_hidden_state)
    print("Pytorch Last Hidden State: ", pt_outputs.last_hidden_state)

    print("Flax attentions: ", fx_outputs.attentions)
    print("Pytorch attentions: ", pt_outputs.attentions)

    print("Flax output_router_logits: ", fx_outputs.router_logits)
    print("Pytorch output_router_logits: ", pt_outputs.router_logits)
  Flax Last Hidden State:  [[[-1.491014    1.2440004   2.347837   ...  1.465887   -1.9995561
      2.3321059 ]
    [-1.0622324   0.62142754  1.8054494  ... -0.5777101  -0.43433675
      1.6081295 ]
    [-0.7068592  -0.43685007 -2.091109   ... -1.5816478   0.7205353
      2.4112375 ]
    [-0.3168698   1.1649094   1.0265665  ...  0.8728067  -0.7838142
    -1.1911072 ]
    [ 0.83987314 -1.8839872  -0.8021642  ...  0.36293986 -0.21989252
      1.9756591 ]]]

  Pytorch Last Hidden State:  tensor([[[-1.4910,  1.2440,  2.3478,  ...,  1.4659, -1.9996,  2.3321],
          [-1.0622,  0.6214,  1.8054,  ..., -0.5777, -0.4343,  1.6081],
          [-0.7069, -0.4369, -2.0911,  ..., -1.5816,  0.7205,  2.4112],
          [-0.3169,  1.1649,  1.0266,  ...,  0.8728, -0.7838, -1.1911],
          [ 0.8399, -1.8840, -0.8022,  ...,  0.3629, -0.2199,  1.9757]]],
        grad_fn=<MulBackward0>)

  Flax attentions:  (Array([[[[1.        , 0.        , 0.        , 0.        , 0.        ],
          [0.51858807, 0.48141193, 0.        , 0.        , 0.        ],
          [0.30058202, 0.42184675, 0.2775713 , 0.        , 0.        ],
          [0.35005265, 0.18267702, 0.2365187 , 0.23075159, 0.        ],
          [0.24478972, 0.13809378, 0.17894691, 0.18818794, 0.24998164]],

          [[1.        , 0.        , 0.        , 0.        , 0.        ],
          [0.31261998, 0.68738   , 0.        , 0.        , 0.        ],
          [0.43397075, 0.3420607 , 0.22396858, 0.        , 0.        ],
          [0.27358305, 0.23000337, 0.14827491, 0.34813863, 0.        ],
          [0.16551231, 0.14784649, 0.13323058, 0.25625977, 0.2971509 ]],
    ...
          [[1.        , 0.        , 0.        , 0.        , 0.        ],
          [0.42223635, 0.5777636 , 0.        , 0.        , 0.        ],
          [0.30812478, 0.38982448, 0.3020508 , 0.        , 0.        ],
          [0.1543934 , 0.36634043, 0.24219993, 0.23706625, 0.        ],
          [0.23449035, 0.17589284, 0.13383494, 0.1932569 , 0.262525  ]],

          [[1.        , 0.        , 0.        , 0.        , 0.        ],
          [0.4659806 , 0.5340194 , 0.        , 0.        , 0.        ],
          [0.27027905, 0.37045106, 0.35926986, 0.        , 0.        ],
          [0.24646765, 0.23269121, 0.21503593, 0.3058052 , 0.        ],
          [0.1953705 , 0.16371022, 0.19928072, 0.1873459 , 0.25429267]]]],      dtype=float32))

  Pytorch attentions:  (tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.5186, 0.4814, 0.0000, 0.0000, 0.0000],
            [0.3006, 0.4218, 0.2776, 0.0000, 0.0000],
            [0.3501, 0.1827, 0.2365, 0.2308, 0.0000],
            [0.2448, 0.1381, 0.1789, 0.1882, 0.2500]],

          [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.3126, 0.6874, 0.0000, 0.0000, 0.0000],
            [0.4340, 0.3421, 0.2240, 0.0000, 0.0000],
            [0.2736, 0.2300, 0.1483, 0.3481, 0.0000],
            [0.1655, 0.1478, 0.1332, 0.2563, 0.2972]],
    ...
          [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.4222, 0.5778, 0.0000, 0.0000, 0.0000],
            [0.3081, 0.3898, 0.3021, 0.0000, 0.0000],
            [0.1544, 0.3663, 0.2422, 0.2371, 0.0000],
            [0.2345, 0.1759, 0.1338, 0.1933, 0.2625]],

          [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.4660, 0.5340, 0.0000, 0.0000, 0.0000],
            [0.2703, 0.3705, 0.3593, 0.0000, 0.0000],
            [0.2465, 0.2327, 0.2150, 0.3058, 0.0000],
            [0.1954, 0.1637, 0.1993, 0.1873, 0.2543]]]],
        grad_fn=<SoftmaxBackward0>))

  Flax output_router_logits:  (Array([[ 0.05653389,  0.30332795,  0.07461646,  0.5048188 , -0.12729573,
          0.51455855,  0.37247908,  0.73305213],
        [-1.0776594 , -0.07945219, -0.15837692,  0.01201573, -0.30832142,
          -0.640889  , -0.5488462 , -0.00608852],
    ...
        [-0.6823018 ,  0.2800284 ,  0.48748446,  0.46143585, -0.2073719 ,
          0.19337937, -2.0465307 ,  0.3276461 ],
        [ 0.27823162, -0.10238633, -0.5455977 , -0.18515491, -0.69956577,
          0.853526  ,  0.85062325, -1.0652022 ]], dtype=float32))

  Pytorch output_router_logits:  (tensor([[ 0.0565,  0.3033,  0.0746,  0.5048, -0.1273,  0.5146,  0.3725,  0.7331],
          [-1.0777, -0.0795, -0.1584,  0.0120, -0.3083, -0.6409, -0.5488, -0.0061],
    ...
          [-0.6823,  0.2800,  0.4875,  0.4614, -0.2074,  0.1934, -2.0465,  0.3276],
          [ 0.2782, -0.1024, -0.5456, -0.1852, -0.6996,  0.8535,  0.8506, -1.0652]],
        grad_fn=<MmBackward0>))
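
To turn the eyeball comparison above into a pass/fail check, something like the following could be appended to the snippet. The tolerance is only a guess, not the value used by the library's official PT/Flax equivalence tests:

    import numpy as np

    # Turn the side-by-side prints into explicit checks. The atol below is a guess.
    fx_hidden = np.asarray(fx_outputs.last_hidden_state)
    pt_hidden = pt_outputs.last_hidden_state.detach().numpy()
    assert np.allclose(fx_hidden, pt_hidden, atol=1e-3), "last_hidden_state mismatch"

    for fx_attn, pt_attn in zip(fx_outputs.attentions, pt_outputs.attentions):
        assert np.allclose(np.asarray(fx_attn), pt_attn.detach().numpy(), atol=1e-3)

    for fx_logits, pt_logits in zip(fx_outputs.router_logits, pt_outputs.router_logits):
        assert np.allclose(np.asarray(fx_logits), pt_logits.detach().numpy(), atol=1e-3)

    print("Flax and PyTorch outputs match within tolerance")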
ArthurZucker commented 7 months ago

Flax is most interesting for speed. As long as we keep the RAM usage bounded by a 7B model and get the same speedups, this should be good; that is what I mean by results 😉 Feel free to open a PR and work on this, we just won't be able to provide a lot of help for now!
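
For anyone picking this up, a rough sketch of the kind of speed comparison meant by "results" could look like the following, reusing the models and inputs from the snippet earlier in the thread. It assumes the work-in-progress FlaxMixtralModel accepts a `params` kwarg like the library's other Flax models, and the numbers are purely illustrative:

    import time
    import torch
    import jax

    # Rough micro-benchmark sketch for the kind of speed "results" mentioned above.
    # Assumes the WIP FlaxMixtralModel follows the usual Flax model API and accepts
    # a `params` kwarg; timings depend entirely on hardware and XLA.
    @jax.jit
    def fx_forward(params, input_ids, attention_mask):
        return fx_model(input_ids, attention_mask=attention_mask, params=params).last_hidden_state

    # The first call triggers XLA compilation; exclude it from the measurement.
    fx_forward(fx_model.params, fx_inputs["input_ids"], fx_inputs["attention_mask"]).block_until_ready()

    start = time.perf_counter()
    for _ in range(10):
        fx_forward(fx_model.params, fx_inputs["input_ids"], fx_inputs["attention_mask"]).block_until_ready()
    print(f"Flax (jit): {(time.perf_counter() - start) / 10:.4f} s per forward")

    with torch.no_grad():
        pt_model(**pt_inputs)  # warm-up
        start = time.perf_counter()
        for _ in range(10):
            pt_model(**pt_inputs)
    print(f"PyTorch:    {(time.perf_counter() - start) / 10:.4f} s per forward")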