erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0
167 stars, 19 forks

Output of TinyLlama differs between EasyDeL and the Hugging Face Transformers API #86

Closed: jchauhan closed this issue 5 months ago

jchauhan commented 5 months ago

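Describe the bug: the output of TinyLlama (TinyLlama/TinyLlama-1.1B-Chat-v1.0) served with EasyDeL differs from the output produced through the Hugging Face Transformers API for the same prompt.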

To Reproduce

Output from the Hugging Face Transformers API in a local environment:

<|system|>
You are an oracle who knows the anwers of everything. There should not any uncertainity in your answers</s>
<|user|>
how many stars there in the universe?
</s>
<|assistant|>

<|system|>
You are an oracle who knows the anwers of everything. There should not any uncertainity in your answers</s>
<|user|>
how many stars there in the universe?
</s>
<|assistant|>
The number of stars in the universe is estimated to be in the millions or billions. However, the precise number is still being studied and calculated. It is believed that the current number of stars in the universe is at least 100 billion, with a range of 10 to 100 trillion. The number of stars and galaxies is increasing with time, indicating that the universe is expanding.

Output from TinyLlama using EasyDeL:

Command used to serve TinyLlama (run from IPython):

run examples/serving/causal-lm/tinyllama-2-chat.py --pretrained_model_name_or_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0" --max_length=2048 --max_new_tokens=2048 --max_compile_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype="fp16" --use_prefix_tokenizer --mesh_axes_shape 1 -1 1 1
llama
/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/gradio-4.12.0-py3.10.egg/gradio/components/base.py:182: UserWarning: show_label has no effect when container is False.
  warnings.warn("show_label has no effect when container is False.")
Sharding Params: 100%|███████████████████████████████████████████████████████████████████████| 201/201 [00:25<00:00,  7.74it/s]
Compiling Model Forwards Greedy/Non-Greedy(Generate)
Compiling Greedy Functions
Compiling Non-Greedy(Generate) Functions
Launching App ...
/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/gradio-4.12.0-py3.10.egg/gradio/components/base.py:182: UserWarning: show_label has no effect when container is False.
  warnings.warn("show_label has no effect when container is False.")
Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://d74594468abe170a71.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
<IPython.core.display.HTML object>
Launching Server APIS (Fire) ...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
    - Avoid using `tokenizers` before the fork if possible
    - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

INFO:     Started server process [975861]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:2059 (Press CTRL+C to quit)
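The `--mesh_axes_shape 1 -1 1 1` flag in the command above is parsed by the script's argparse definition (`nargs="+"`, `type=int`) into four integers that line up with the default `--mesh_axes_names` (dp, fsdp, tp, sp). The sketch below reproduces that parsing and, assuming the `-1` entry is resolved the way NumPy's `reshape` treats `-1` (it absorbs all remaining devices), shows the mesh shape that would result for a given device count. `resolve_mesh_shape` is a hypothetical helper for illustration, not EasyDeL's implementation.

```python
import argparse
import math
from typing import List, Sequence, Tuple


def parse_mesh_args(argv: List[str]) -> argparse.Namespace:
    # Mirrors the two mesh flags defined in the serving script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mesh_axes_names", default=["dp", "fsdp", "tp", "sp"], nargs="+")
    parser.add_argument("--mesh_axes_shape", default=[1, -1, 1, 1], nargs="+", type=int)
    return parser.parse_args(argv)


def resolve_mesh_shape(shape: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    # Hypothetical helper: assumes -1 absorbs all remaining devices, the way
    # numpy.reshape treats -1. This is not taken from EasyDeL's source.
    if list(shape).count(-1) > 1:
        raise ValueError("at most one mesh axis may be -1")
    known = math.prod(d for d in shape if d != -1)
    if -1 not in shape:
        if known != num_devices:
            raise ValueError(f"mesh {tuple(shape)} needs {known} devices, got {num_devices}")
        return tuple(shape)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot fill mesh {tuple(shape)}")
    return tuple(num_devices // known if d == -1 else d for d in shape)


args = parse_mesh_args(["--mesh_axes_shape", "1", "-1", "1", "1"])
print(dict(zip(args.mesh_axes_names, args.mesh_axes_shape)))
# -> {'dp': 1, 'fsdp': -1, 'tp': 1, 'sp': 1}
print(resolve_mesh_shape(args.mesh_axes_shape, num_devices=8))
# -> (1, 8, 1, 1)
```

Under that assumption, `1 -1 1 1` places every available device on the fsdp axis, so on a single GPU the mesh is simply (1, 1, 1, 1).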

Output:

[screenshot of the EasyDeL output, not preserved]

The prompt was:

<|system|>
You are an oracle who knows the answers of everything. There should not any uncertainity in your answers</s>
<|user|>
how many stars there in the universe?
</s>
<|assistant|>
erfanzar commented 5 months ago

Hi, can I have access to the code? There are many tricks you have to use for partitioning and similar things in JAX. As a first step, I recommend changing max_compile_tokens to 128.
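If `max_compile_tokens` means what the script's help text says ("The maximum number of tokens to generate per stream"), then a reply of up to `max_new_tokens` tokens is produced in chunks, and raising the value from 32 to 128 cuts the number of chunks. The arithmetic sketch below assumes each compiled call emits exactly `max_compile_tokens` tokens; that is an assumption, not confirmed EasyDeL behavior, and `num_generation_rounds` is a hypothetical name.

```python
import math


def num_generation_rounds(max_new_tokens: int, max_compile_tokens: int) -> int:
    # Assumption: every compiled generation call emits max_compile_tokens
    # tokens, so the number of calls is the ceiling of the ratio.
    if max_compile_tokens <= 0:
        raise ValueError("max_compile_tokens must be positive")
    return math.ceil(max_new_tokens / max_compile_tokens)


for chunk in (32, 128):
    print(f"max_compile_tokens={chunk}: {num_generation_rounds(2048, chunk)} rounds")
# -> 64 rounds for 32, 16 rounds for 128
```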

jchauhan commented 5 months ago

Here is the code. I changed it only slightly from the Llama chat example that was provided.

import typing

import termcolor

import EasyDel
import jax.lax
from EasyDel.serve import JAXServer, JAXServerConfig
from fjformer.checkpoint import get_dtype
from transformers import AutoTokenizer
import argparse

DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant in the role of an operator. Always answer " \
                        "as helpfully as possible, while being safe. Your answers should not" \
                        " include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
                        "illegal content. Please ensure that your responses are socially unbiased " \
                        "and positive in nature.\nIf a question does not make any sense, or is not " \
                        "factually coherent, explain why instead of answering something not correct. If " \
                        "you don't know the answer to a question, please don't share false information. " \
                        "Sometimes you will receive extra data between the " \
                        "tags [EXTRA-DATA] and [/EXTRA-DATA]; if you receive it, you have to answer " \
                        "based on that extra data."

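def get_prompt_llama2_format(message: str, chat_history,
                             system_prompt: typing.Optional[str]) -> str:
    # Builds a TinyLlama chat prompt (<|system|>/<|user|>/<|assistant|> tags).
    # chat_history is accepted for interface compatibility but is not used.
    if system_prompt is None:
        # Avoid sending the literal string "None" as the system prompt.
        system_prompt = DEFAULT_SYSTEM_PROMPT
    texts = [f'<|system|>\n{system_prompt}</s>\n<|user|>\n']
    texts.append(f'{message}\n</s>\n<|assistant|>\n')
    prompt = "".join(texts)
    print(prompt)  # debug: echo the exact prompt sent to the model
    return prompt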

class TinyLlama2Host(JAXServer):
    def __init__(self, config=None):
        super().__init__(config=config)

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return get_prompt_llama2_format(instruction, [], system)

    @staticmethod
    def format_chat(history: typing.List[str], prompt: str, system: typing.Union[str, None]) -> str:
        return get_prompt_llama2_format(prompt, history, system)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Argument parser for Llama2.")
    parser.add_argument(
        '--pretrained_model_name_or_path',
        default='meta-llama/Llama-2-7b-chat-hf',
        help='HuggingFace Repo to load model From'
    )
    parser.add_argument(
        "--contains_auto_format",
        default=False,
        action="store_true",
        help="Whether the input text contains auto-format tokens.",
    )
    parser.add_argument(
        "--max_length",
        default=4096,
        type=int,
        help="The maximum length of the input text.",
    )
    parser.add_argument(
        "--max_new_tokens",
        default=2048,
        type=int,
        help="The maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--max_compile_tokens",
        default=32,
        type=int,
        help="The maximum number of tokens to generate per stream.",
    )
    parser.add_argument(
        "--temperature",
        default=0.6,
        type=float,
        help="The temperature of the sampling distribution.",
    )
    parser.add_argument(
        "--top_p",
        default=0.95,
        type=float,
        help="The top-p probability cutoff for the sampling distribution.",
    )
    parser.add_argument(
        "--top_k",
        default=50,
        type=int,
        help="The top-k number of tokens to keep for the sampling distribution.",
    )
    parser.add_argument(
        "--logging",
        default=False,
        action="store_true",
        help="Whether to log the generation process.",
    )
    parser.add_argument(
        "--mesh_axes_names",
        default=["dp", "fsdp", "tp", "sp"],
        nargs="+",
        help="The names of the mesh axes.",
    )
    parser.add_argument(
        "--mesh_axes_shape",
        default=[1, -1, 1, 1],
        nargs="+",
        type=int,
        help="The shapes of the mesh axes.",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        help="The data type to use for the generation.",
    )
    parser.add_argument(
        "--use_prefix_tokenizer",
        default=False,
        action="store_true",
        help="Whether to use a prefix tokenizer.",
    )
    args = parser.parse_args()
    configs = JAXServerConfig(
        contains_auto_format=args.contains_auto_format,
        max_length=args.max_length,
        max_new_tokens=args.max_new_tokens,
        max_compile_tokens=args.max_compile_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        logging=args.logging,
        mesh_axes_names=args.mesh_axes_names,
        mesh_axes_shape=args.mesh_axes_shape,
        dtype=args.dtype,
        use_prefix_tokenizer=args.use_prefix_tokenizer
    )

    server = TinyLlama2Host.from_torch_pretrained(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        server_config=configs
    )
    try:
        termcolor.cprint(
            'Launching App ...',
            color="cyan",
            force_color=True
        )
        server.gradio_inference().launch(share=True)
        termcolor.cprint(
            'Launching Server APIS (Fire) ...',
            color="cyan",
            force_color=True
        )
        server.fire()
    except KeyboardInterrupt:
        print('Exiting ...')
        server.end()
        exit(0)
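In the script above, `format_chat` receives `history`, but `get_prompt_llama2_format` ignores it, so each turn reaches the model without earlier context. A minimal sketch that keeps the same layout as that function and also includes earlier turns is shown below. It assumes `history` is a list of (user, assistant) pairs (the script's type hint says `List[str]`, so the shape the Gradio UI actually passes may differ), and the separators used for assistant turns are an assumption; comparing the result with `tokenizer.apply_chat_template(...)` from Transformers is a good check. `build_tinyllama_prompt` and `DEFAULT_SYSTEM` are hypothetical names.

```python
from typing import List, Optional, Sequence, Tuple

# Hypothetical placeholder used when no system prompt is given.
DEFAULT_SYSTEM = "You are a helpful assistant."


def build_tinyllama_prompt(
    message: str,
    history: Sequence[Tuple[str, str]],
    system_prompt: Optional[str] = None,
) -> str:
    # Same layout as get_prompt_llama2_format, plus earlier (user, assistant) turns.
    parts: List[str] = [f"<|system|>\n{system_prompt or DEFAULT_SYSTEM}</s>\n"]
    for user_turn, assistant_turn in history:
        parts.append(f"<|user|>\n{user_turn}\n</s>\n")
        # Assumed separator for completed assistant turns.
        parts.append(f"<|assistant|>\n{assistant_turn}</s>\n")
    # Final user turn, then the open assistant tag the model continues from.
    parts.append(f"<|user|>\n{message}\n</s>\n<|assistant|>\n")
    return "".join(parts)


print(build_tinyllama_prompt("how many stars there in the universe?", []))
```

With an empty history the output is byte-for-byte what `get_prompt_llama2_format` produces, so the two can be swapped without changing single-turn behavior.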
jchauhan commented 5 months ago

It worked with the value of 128 that you suggested. How did you arrive at this number?

erfanzar commented 5 months ago

That's a bug in the Hugging Face Transformers library: it passes the tokens as -2 when the length of the passed tokens is less than or equal to 32 or 64, so I suggested trying 128. (This might happen for a wide range of models, and it is not yet fully handled in FJFormer or EasyDeL.)
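One way to apply this workaround defensively in the serving script is to clamp small values before building `JAXServerConfig`. The helper below is hypothetical, not part of EasyDeL. Its minimum of 128 comes only from this thread (32 and 64 were reported as problematic, and 128 worked); values between 65 and 127 were not tested here.

```python
import warnings

# Threshold taken from this thread: 32/64 misbehaved, 128 worked.
MIN_SAFE_COMPILE_TOKENS = 128


def safe_max_compile_tokens(requested: int, minimum: int = MIN_SAFE_COMPILE_TOKENS) -> int:
    # Returns the requested value, or `minimum` (with a warning) when it is lower.
    if requested < minimum:
        warnings.warn(
            f"max_compile_tokens={requested} is below {minimum}; using {minimum} instead"
        )
        return minimum
    return requested


print(safe_max_compile_tokens(32))
print(safe_max_compile_tokens(256))
```

In the script it would be used as `max_compile_tokens=safe_max_compile_tokens(args.max_compile_tokens)` when constructing the config.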