erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Llama RoPE bug #5

Closed. HeegyuKim closed this issue 1 year ago.

HeegyuKim commented 1 year ago

In your code, the default rotary type is 'lm2', but the generation output is strange.

The model is meta-llama/Llama-2-7b-hf.

rope type = 'lm2'

<s>Hi, how are you doing? 🙂�\nI'm from South Korea, 😁😁��\nHi,\nI am a ��̅�\nI am a friendly and I'm new to you?\nI'm looking

rope type = 'complex'

<s>Hi, how are you doing? 🙂\nMy name is Jasmine and I am the creator of this blog.\nMy name is Jasmine, I’m 23319 years old, I’m from Germany.Hi, my name is Jasmine

This is my test code. I changed the parameter names of the Llama model so I could load the Hugging Face weights with from_pt=True. When I used weights converted with converter.py, the generation output was also strange.

from EasyDel.modules.llama.modelling_llama_flax import (
    FlaxLlamaForCausalLM
)
from transformers import LlamaTokenizer, LlamaForCausalLM
import jax
import torch

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

device = jax.devices('cpu')[0]
with jax.default_device(device):
    print("loading")
    model = FlaxLlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", from_pt=True)

    print("generating")
    ids = tokenizer(["Hi, how are you doing? "], return_tensors="np")
    outs = model.generate(
        ids["input_ids"],
        do_sample=True,
        max_length=64,
        params={"params": model.params}
    )

    print(outs)
    print(tokenizer.batch_decode(outs.sequences))
erfanzar commented 1 year ago

Thanks for the report. I'll change the default Llama rotary type to 'complex' in version 0.0.20. If you're wondering why there are three rotary embedding types in EasyDeL: when using flash attention and training, 'lm2' and 'open' are faster than 'complex', but 'complex' is slower and more accurate, and I'm still trying to find a better way around that.
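
For context, here is a minimal, self-contained JAX sketch (not EasyDeL's actual code; rope_complex and rope_sincos are illustrative names) of a rotary embedding applied directly in complex arithmetic versus with explicit sin/cos tables. With the same interleaved pairing the two agree to floating-point precision, so speed and accuracy differences between variants presumably come from implementation details such as pair ordering, the precision of the cached tables, and how they are fused into flash attention.

import jax
import jax.numpy as jnp


def rope_complex(x, theta=10000.0):
    # x: (seq_len, head_dim); treat consecutive feature pairs as complex numbers
    seq_len, dim = x.shape
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))
    angles = jnp.outer(jnp.arange(seq_len), freqs)      # (seq_len, dim // 2)
    rotations = jnp.exp(1j * angles)                    # e^{i * m * theta_j}
    pairs = x.reshape(seq_len, dim // 2, 2)
    as_complex = jax.lax.complex(pairs[..., 0], pairs[..., 1])
    rotated = as_complex * rotations
    return jnp.stack([rotated.real, rotated.imag], axis=-1).reshape(seq_len, dim)


def rope_sincos(x, theta=10000.0):
    # the same rotation written with explicit sin/cos tables on interleaved pairs
    seq_len, dim = x.shape
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))
    angles = jnp.outer(jnp.arange(seq_len), freqs)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]                 # "real" / "imaginary" halves
    return jnp.stack([x1 * cos - x2 * sin,
                      x1 * sin + x2 * cos], axis=-1).reshape(seq_len, dim)


# with the same pairing convention the two formulations agree to float32 precision
x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
print(jnp.max(jnp.abs(rope_complex(x) - rope_sincos(x))))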

And if you want easier hosting and faster generation, I recommend using JAXServer. If you have any other problems, please let me know.

erfanzar commented 1 year ago

I added from_pt support to the Llama models, so you don't need to change the parameter names yourself anymore :) Read NOTE.md if you'd like to see how it works.
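
In other words, loading a Hugging Face PyTorch checkpoint should now work without any manual renaming. A minimal sketch, assuming version 0.0.20 or later (and torch installed for the on-the-fly conversion):

from EasyDel.modules.llama.modelling_llama_flax import FlaxLlamaForCausalLM

# from_pt=True converts the PyTorch checkpoint to Flax weights directly
model = FlaxLlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", from_pt=True)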