erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Please provide support for Llama 3 or provide an example of how to serve it using EasyDeL #144

Closed jchauhan closed 2 months ago

jchauhan commented 2 months ago

With a configuration similar to the one we used for Llama 2, we are getting garbage responses from Llama 3. Sample output:

````text
s[/S]

The above code will output:
Hello

The `INST` tag is used to indicate that the text should be displayed in an instance of the `Hello` class. The `S` tag is used to indicate that the text should be displayed in a sentence.
              [INST]  [INST]  [INST]  [INST]  will be the  [INST]  [INST]  and [INST]  and [INST]  and [INST]  and [INST]  C
```[INST]  and [INST]  and the  and [INST
``
[INST]  and the  and the  and the  and the  and the `s
``
``
[S] [INST]s[/S] and the [S][[INST]s[INST[/S][
``s[INST[/S
``
INST[/S]s[INST[/S][S][S][S][s[/S][s][s[/s[/INST[/S][s[/S][[INST[INST[s`<``s[INST[/S][S][S][S[INST[INST[/S][S][S][S][S
S
````
jchauhan commented 2 months ago

@erfanzar Could you suggest a solution? It's a bit urgent. Thanks.

The real issue is that in order to run Llama 3, `eos_token_id` needs to be a list of integers:

        eos_token_id (`Union[int, List[int]]`, *optional*):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

As per the Hugging Face examples:

```python
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
```
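
For reference, a quick way to check which token id corresponds to Llama 3's end-of-turn marker; this is a small illustrative snippet using the Hugging Face tokenizer, separate from EasyDeL:

```python
from transformers import AutoTokenizer

# Look up the Llama 3 end-of-turn token id from the instruct tokenizer.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
print(tokenizer.eos_token_id)                         # may be 128001 or 128009 depending on the tokenizer config
print(tokenizer.convert_tokens_to_ids("<|eot_id|>"))  # 128009
```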
erfanzar commented 2 months ago

@jchauhan Hi, actually you can use JAXServer, since that's a faster and safer option and it supports dynamic prompt templates.

erfanzar commented 2 months ago

Here's an example of using a Llama 3 model on Kaggle GPUs (T4 x2):

```python
import jax.numpy as jnp
import torch

# NOTE: the import path below is an assumption and may differ between EasyDeL versions.
from easydel import JAXServer, JAXServerConfig

server_config = JAXServerConfig(
    max_sequence_length=3072,
    max_new_tokens=2048,
    max_compile_tokens=512,
    pre_compile=False,
    eos_token_id=128009,  # id of Llama 3's <|eot_id|> end-of-turn token
    temperature=0.3,
    top_p=0.95,
    top_k=10
)
server = JAXServer.from_torch_pretrained(
    pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
    server_config=server_config,
    sharding_axis_dims=(1, 1, 1, -1),
    model_config_kwargs=dict(
        gradient_checkpointing="",
        use_scan_mlp=False,
        shard_attention_computation=False,
        use_sharded_kv_caching=True,
        attn_mechanism="local_ring"
    ),
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    auto_shard_params=True,
    load_in_8bit=True,
    input_shape=(1,2048),
    torch_dtype=torch.float16,
    device_map="cpu" # this one will be passed to transformers.AutoModelForCausalLM
)

prompt = server.format_chat(
    prompt="write a poem about stars",
    history=[],
    system="",
)

pl = 0  # number of characters already printed
for response, used_tokens in server.sample(prompt):
    print(response[pl:], end="")  # response is cumulative, so only print the new part
    pl = len(response)

# Here's a poem about stars:

# The stars shine bright in the midnight sky,
# A celestial show, beyond the eye,
# Their twinkling light, a beacon in the night,
# Guiding us through the darkness, a guiding light.

# Their beauty is a wonder, a celestial display,
# A reminder of the magic, that's always in.swing,
# A reminder of the mystery, of the stars up high,
# A reminder of the wonder, of the stars that twinkle bright,
# A reminder of the magic, of the stars that shine with a gentle light.

# So let us cherish the stars, and the magic that they bring to our lives.
```
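
If you want the full text rather than streaming it, the same `server.sample` generator can simply be drained; `response` is cumulative in the loop above, so the last value yielded is the whole generation. A minimal sketch under that assumption:

```python
# Collect the complete generation instead of printing it incrementally.
full_response = ""
for response, used_tokens in server.sample(prompt):
    full_response = response
print(full_response)
```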
jchauhan commented 2 months ago

I wanted to run it on a TPU v4 instance, and I finally got it working. I had to change the prompter.
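
Roughly, the prompter has to emit the Llama 3 instruct template instead of the Llama 2 `[INST]` format. A minimal sketch, assuming the standard Llama 3 special tokens (the function name is illustrative, not EasyDeL API):

```python
# Minimal sketch of a Llama 3 instruct prompt builder; names here are illustrative.
def format_llama3_prompt(prompt: str, history=None, system: str = "") -> str:
    text = "<|begin_of_text|>"
    if system:
        text += f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    for user_msg, assistant_msg in (history or []):
        text += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
        text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
    text += f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
    text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return text
```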

Thanks. It was easy.