erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Output Differs Between Hugging Face Transformers and EasyDeL #116

Closed: jchauhan closed this issue 4 months ago

jchauhan commented 4 months ago

Describe the bug

We are running a model on a TPU v3-8 instance using EasyDeL. It worked well after your suggestion in https://github.com/erfanzar/EasyDeL/issues/114.

However, the output produced by the EasyDeL-hosted model is wrong:

Question: Which of the following is an example of monosomy?
Options:

46,XX
47,XXX
69,XYY
45,X
Please provide your choice first and then provide explanations if possible.

The correct answer is 46,XX.

Monosomy is a condition where a person has only one copy of a particular chromosome. In the case of 46,XX, a person has only one X chromosome. This is the most common form of monosomy and is typically associated with Turner syndrome, a genetic disorder that affects females.

47,XXX is also a form of monosomy, but it is a different type of chromosome. In this case, a person has three X chromosomes instead of the typical two. This condition is known as triple X syndrome and is also a genetic disorder that affects females.

69,XYY is a form of monosomy that involves an extra Y chromosome. This condition is known as Klinefelter syndrome and is also a genetic disorder that affects males.

45,X is a form of monosomy that involves a missing X chromosome. This condition is known as Turner syndrome and is a genetic disorder that affects females.

By contrast, the output produced by the same model hosted with Hugging Face is:

Question: Which of the following is an example of monosomy?
Options:
- 46,XX
- 47,XXX
- 69,XYY
- 45,X

Please provide your choice first and then provide explanations if possible.

### Assistant Output:
The correct answer is 45,X.

Monosomy is a condition where a person has only one copy of a particular chromosome. In this case, the person has only one X chromosome, which is a form of Turner syndrome. This condition is usually caused by a missing or partially deleted X chromosome.

The other options are not examples of monosomy:

- 46,XX: This is a normal karyotype, where a person has two X chromosomes.
- 47,XXX: This is a form of trisomy, where a person has three X chromosomes.
- 69,XYY: This is a form of trisomy, where a person has three X chromosomes and an extra Y chromosome.

Above is the correct answer.

To Reproduce

Run the following code to reproduce it:

```python
import json
from abc import ABC
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import jax
import transformers
from absl import flags
from absl.app import run
from fjformer import get_dtype
from jax.sharding import Mesh, PartitionSpec

from EasyDel import JAXServer, JAXServerConfig
from EasyDel.modules.auto_easydel_model import AutoEasyDelModelForCausalLM
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, Qwen2Prompter
from EasyDel.serve.prompters.base_prompter import BasePrompter

FLAGS = flags.FLAGS
flags.DEFINE_enum("prompter_type", enum_values=("gemma", "llama", "openchat", "qwen2", "medllama"), help="Prompter to be used to prompt the model", default="medllama")
flags.DEFINE_string("pretrained_model_name_or_path", default="AdaptLLM/medicine-chat", help="The pretrained model path in huggingface.co/models")
flags.DEFINE_integer("max_compile_tokens", default=256, help="Maximum number of compiled tokens")
flags.DEFINE_integer("max_new_tokens_ratio", default=20, help="max new tokens ratio to be multiplied for max_compile_tokens for max_new_tokens")
flags.DEFINE_integer("max_sequence_length", default=2048, help="max sequence length to be used in the model")
flags.DEFINE_enum("dtype", enum_values=("bf16", "fp16", "fp32"), default="bf16", help="The data type of the model")
flags.DEFINE_list("sharding_axis_dims", default=[1, 1, 1, -1], help="Sharding Axis dimensions for the model")
flags.DEFINE_bool("use_sharded_kv_caching", default=False, help="whether to use sharded kv for Large Sequence model up to 1M")
flags.DEFINE_bool("scan_ring_attention", default=True, help="whether to scan ring attention for Large Sequence model up to 1M (works with attn_mechanism='ring')")
flags.DEFINE_bool("use_scan_mlp", default=True, help="whether to scan MLP or FFN Layers for Large Sequence model up to 1M")
flags.DEFINE_enum("attn_mechanism", enum_values=["normal", "flash", "ring", "splash"], default="normal", help="The attention mechanism to be used in the model")
flags.DEFINE_integer("block_k", default=128, help="the number of chunks for key block in attention (Works with flash, splash, ring Attention mechanism)")
flags.DEFINE_integer("block_q", default=128, help="the number of chunks for query block in attention (Works with flash, splash, ring Attention mechanism)")
flags.DEFINE_bool("share_gradio", default=True, help="whether to share gradio app")
flags.DEFINE_string("gradio_root_path", default="", help="Root path to host the Gradio server")


class MedLlama2Prompter(BasePrompter, ABC):
    def __init__(self):
        user_prefix = "[INST]"
        assistant_prefix = "[/INST]"
        super().__init__(
            user_message_token=user_prefix,
            assistant_message_token=assistant_prefix,
            prompter_type="medllama",
            end_of_turn_token="",
        )

    def format_history_prefix(
            self,
            history: list[list[str]],
            system_message: str,
    ):
        prompt = ""
        for user, assistant in history:
            prompt += f"{self.user_message_token}{user} "
            prompt += f"{self.assistant_message_token}{assistant} "
        print("format_history_prefix", prompt)
        return prompt

    def format_message(
            self,
            prompt: str,
            history: list[list[str]],
            system_message: Optional[str],
            prefix: Optional[str]
    ) -> str:
        dialogs = prefix if prefix is not None else ""

        for user, assistant in history:
            dialogs += f"{self.user_message_token}{user}"
            dialogs += f"{self.assistant_message_token}{assistant}"

        dialogs += f"{self.user_message_token}{prompt}"
        dialogs += self.assistant_message_token
        print("format_message", dialogs)
        return dialogs


def main(argv):
    server_config = JAXServerConfig(
        max_sequence_length=FLAGS.max_sequence_length,
        max_compile_tokens=FLAGS.max_compile_tokens,
        max_new_tokens=FLAGS.max_compile_tokens * FLAGS.max_new_tokens_ratio,
        dtype=FLAGS.dtype
    )
    prompters = {
        "gemma": GemmaPrompter(),
        "llama": Llama2Prompter(),
        "openchat": OpenChatPrompter(),
        "qwen2": Qwen2Prompter(),
        "medllama": MedLlama2Prompter()
    }
    prompter: BasePrompter = prompters[FLAGS.prompter_type]

    FLAGS.sharding_axis_dims = tuple([int(s) for s in FLAGS.sharding_axis_dims])

    class JAXServerMedLlama(JAXServer):
        @staticmethod
        def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
            return prompter.format_message(
                history=[],
                prompt=prompt,
                system_message=system,
                prefix=None
            )

        @staticmethod
        def format_instruct(system: str, instruction: str) -> str:
            return prompter.format_message(
                prefix=None,
                system_message=system,
                prompt=instruction,
                history=[]
            )

        @classmethod
        def from_torch_pretrained(
                cls,
                server_config: JAXServerConfig,
                pretrained_model_name_or_path: str,
                device=jax.devices('cpu')[0],
                dtype: jax.numpy.dtype = jax.numpy.float32,
                param_dtype: jax.numpy.dtype = jax.numpy.float32,
                precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"),
                sharding_axis_dims: Sequence[int] = (1, -1, 1, 1),
                sharding_axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"),
                query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
                key_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
                value_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
                bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
                attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
                use_shard_map: bool = False,
                input_shape: Sequence[int] = (1, 1),
                shard_fns: Optional[Mapping[tuple, Callable]] = None,
                backend: Optional[str] = None,
                add_params_field: bool = True,
                do_memory_log: bool = False,
                model_config_kwargs: Optional[Mapping[str, Any]] = None,
                verbose: bool = True,
                **kwargs
        ):
            model, params = AutoEasyDelModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                device=device,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                sharding_axis_names=sharding_axis_names,
                sharding_axis_dims=sharding_axis_dims,
                query_partition_spec=query_partition_spec,
                attention_partition_spec=attention_partition_spec,
                value_partition_spec=value_partition_spec,
                key_partition_spec=key_partition_spec,
                bias_partition_spec=bias_partition_spec,
                use_shard_map=use_shard_map,
                shard_fns=shard_fns,
                input_shape=input_shape,
                backend=backend,
                config_kwargs=model_config_kwargs,
                **kwargs
            )

            rule = (
                ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

                ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
                ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

                ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
                ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
                ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),

                ("input_layernorm/kernel", PartitionSpec(None)),
                ("post_attention_layernorm/kernel", PartitionSpec(None)),

                ("model/norm/kernel", PartitionSpec(None)),
                ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
                (".*", PartitionSpec(None)),
            )

            # Override the model config's partition rules with the custom rule set above.
            model.config.get_partition_rules = lambda _: rule

            return cls.from_parameters(
                model=model,
                config_model=model.config,
                tokenizer=transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path),
                params=params,
                server_config=server_config,
                verbose=verbose,
                do_memory_log=do_memory_log,
                add_params_field=add_params_field
            )

    server = JAXServerMedLlama.from_torch_pretrained(
        server_config=server_config,
        pretrained_model_name_or_path=FLAGS.pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        dtype=get_dtype(dtype=FLAGS.dtype),
        param_dtype=get_dtype(dtype=FLAGS.dtype),
        precision=jax.lax.Precision("fastest"),
        sharding_axis_dims=FLAGS.sharding_axis_dims,
        sharding_axis_names=("dp", "fsdp", "tp", "sp"),
        input_shape=(1, server_config.max_sequence_length),
        model_config_kwargs=dict(
            fully_sharded_data_parallel=True,
            attn_mechanism=FLAGS.attn_mechanism,
            scan_mlp_chunk_size=FLAGS.max_compile_tokens,
            use_scan_mlp=FLAGS.use_scan_mlp,
            scan_ring_attention=FLAGS.scan_ring_attention,
            block_k=FLAGS.block_k,
            block_q=FLAGS.block_q,
            use_sharded_kv_caching=FLAGS.use_sharded_kv_caching
        )
    )

    server.gradio_inference().launch(
        root_path=FLAGS.gradio_root_path,
        server_name="0.0.0.0",
        server_port=7680,
        show_api=True,
        share=FLAGS.share_gradio
    )


if __name__ == "__main__":
    run(main)
```



Is anything missing?
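For reference, the custom `rule` tuple in the script maps parameter paths to sharding specs by regular expression and ends with a catch-all `.*`. The sketch below shows how such a table resolves, assuming the rules are applied first-match-wins with `re.search` over a "/"-joined parameter path; that is an assumption about how EasyDeL/fjformer consume `get_partition_rules`, not something this thread confirms. `PartitionSpec` objects are replaced by plain tuples so it runs without JAX, and `match_rule` and the example paths are hypothetical.

```python
import re

# Plain tuples stand in for jax.sharding.PartitionSpec (hypothetical simplification).
RULES = (
    ("model/embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    ("self_attn/o_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/gate_proj/kernel", (("fsdp", "sp"), "tp")),
    ("mlp/down_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/up_proj/kernel", (("fsdp", "sp"), "tp")),
    ("input_layernorm/kernel", (None,)),
    ("post_attention_layernorm/kernel", (None,)),
    ("model/norm/kernel", (None,)),
    ("lm_head/kernel", (("fsdp", "sp"), "tp")),
    (".*", (None,)),  # catch-all: must stay last, or it would shadow every rule after it
)


def match_rule(param_path: str, rules=RULES):
    """Return the spec of the first rule whose regex is found in the parameter path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(match_rule("params/model/layers/0/self_attn/k_proj/kernel"))
print(match_rule("params/model/layers/0/mlp/down_proj/kernel"))
print(match_rule("params/model/layers/0/self_attn/rotary_emb/inv_freq"))
```

Because the search is unanchored, a pattern such as `lm_head/kernel` matches anywhere in the path, and any parameter not named explicitly falls through to the replicated `.*` spec.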
jchauhan commented 4 months ago

There were some warnings when the script started:

site-packages/transformers/generation/configuration_utils.py:406: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
/home/neo/research/belie/.venv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:411: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
erfanzar commented 4 months ago

Set `top_k` and `top_p` to 1 and `do_sample` to `False` for both models; then both models will return the same result. The difference happens because you are using sampling in Hugging Face or non-greedy decoding in EasyDeL. I recommend reading about these parameters.
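The suggestion comes down to how the next token is chosen. As the warnings above note, `temperature` and `top_p` only matter in sample-based generation. The plain-Python sketch below (not the Transformers or EasyDeL implementation; all function names are hypothetical and the nucleus filter is simplified) contrasts greedy decoding with temperature/top-k/top-p sampling, using the `temperature=0.9` and `top_p=0.6` values from the warnings on a made-up logits vector.

```python
import math
import random


def softmax(logits):
    """Convert raw logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filter_top_k_top_p(probs, top_k=0, top_p=1.0):
    """Keep the top-k tokens, then the smallest prefix reaching top_p mass; renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


def next_token(logits, do_sample, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Pick the next token id either greedily or by sampling."""
    if not do_sample:
        # Greedy decoding: temperature/top_k/top_p are ignored.
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax([x / temperature for x in logits])
    candidates = filter_top_k_top_p(probs, top_k=top_k, top_p=top_p)
    ids = list(candidates)
    return rng.choices(ids, weights=[candidates[i] for i in ids], k=1)[0]


logits = [2.0, 1.5, 0.3, -1.0]
rng = random.Random(0)
greedy = next_token(logits, do_sample=False)
sampled = {next_token(logits, True, temperature=0.9, top_p=0.6, rng=rng) for _ in range(200)}
top1 = {next_token(logits, True, top_k=1, rng=rng) for _ in range(200)}
print(greedy, sorted(sampled), sorted(top1))
```

With `top_k=1`, sampling can only ever return the argmax, so it collapses to the same choice as greedy decoding; that is why matching these settings on both sides should make the two backends comparable.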

jchauhan commented 4 months ago

I have tried multiple options; in all cases, the EasyDeL model generates the wrong result. However, the Hugging Face code gives the correct result 1 out of 10 times.

The answer should be 45,X.

Here is the Hugging Face code; it is worth checking again. Please also check my EasyDeL code. Do you see anything fishy?


!pip install bitsandbytes datasets accelerate loralib
!pip install transformers@git+https://github.com/huggingface/transformers.git@main
!pip install peft@git+https://github.com/huggingface/peft.git
!pip install --upgrade datasets
!pip install evaluate

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("AdaptLLM/medicine-chat", device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("AdaptLLM/medicine-chat")

# Put your input here:
user_input = '''Question: Which of the following is an example of monosomy?
Options:
- 46,XX
- 47,XXX
- 69,XYY
- 45,X

Please provide your choice first and then provide explanations if possible.'''

# Apply the prompt template and system prompt of LLaMA-2-Chat demo for chat models (NOTE: NO prompt template is required for base models!)
our_system_prompt = "\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" # Please do NOT change this
prompt = f"<s>[INST] <<SYS>>{our_system_prompt}<</SYS>>\n\n{user_input} [/INST]"

# # NOTE:
# # If you want to apply your own system prompt, please integrate it into the instruction part following our system prompt like this:
# your_system_prompt = "Please, answer this question faithfully."
# prompt = f"<s>[INST] <<SYS>>{our_system_prompt}<</SYS>>\n\n{your_system_prompt}\n{user_input} [/INST]"

inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
outputs = model.generate(input_ids=inputs, max_length=4096)[0]

answer_start = int(inputs.shape[-1])
pred = tokenizer.decode(outputs[answer_start:], skip_special_tokens=True)

print(f'### User Input:\n{user_input}\n\n### Assistant Output:\n{pred}')
jchauhan commented 4 months ago

Finally found the issue: there was a slight difference between the input given to the Transformers model and the input given to the EasyDeL model.

Thanks a lot!
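The thread does not say exactly what the input difference was. One difference visible by comparing the two scripts is the prompt template: the Hugging Face script wraps the question in `<s>[INST] <<SYS>>…<</SYS>>\n\n… [/INST]`, while `MedLlama2Prompter.format_message` (with an empty history and no prefix) produces `[INST]…[/INST]` with no system block and no surrounding spaces. The sketch below rebuilds both strings and diffs them. The helper names are hypothetical, and `SYSTEM_PROMPT` is a shortened placeholder for the full `our_system_prompt`.

```python
import difflib

USER_INPUT = (
    "Question: Which of the following is an example of monosomy?\n"
    "Options:\n- 46,XX\n- 47,XXX\n- 69,XYY\n- 45,X\n\n"
    "Please provide your choice first and then provide explanations if possible."
)
# Placeholder: substitute the full `our_system_prompt` string from the Hugging Face script.
SYSTEM_PROMPT = "\nYou are a helpful, respectful and honest assistant.\n"


def hf_llama2_chat_prompt(user_input: str, system_prompt: str) -> str:
    """Template from the Hugging Face script (LLaMA-2-Chat style with a <<SYS>> block)."""
    return f"<s>[INST] <<SYS>>{system_prompt}<</SYS>>\n\n{user_input} [/INST]"


def medllama_prompter_prompt(user_input: str) -> str:
    """What MedLlama2Prompter.format_message builds for history=[] and prefix=None."""
    return f"[INST]{user_input}[/INST]"


hf = hf_llama2_chat_prompt(USER_INPUT, SYSTEM_PROMPT)
easydel = medllama_prompter_prompt(USER_INPUT)
for line in difflib.unified_diff(hf.splitlines(), easydel.splitlines(), "transformers", "easydel", lineterm=""):
    print(line)
```

Note also that `format_message` never uses its `system_message` argument, so no system prompt reaches the EasyDeL model in this setup.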