erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Transformers-like API for inference #128

Closed: Froggy111 closed this issue 3 months ago

Froggy111 commented 3 months ago

Is there a versatile, transformers-like API for inference (like model.generate()) for this? I tried JAXServer, but it is quite confusing, and I couldn't get flash attention to work. Could you provide some guidance? Thanks very much, I appreciate it.

Froggy111 commented 3 months ago

Also, how can we load quantized models (like GPTQ) onto TPUs?

erfanzar commented 3 months ago

Hello, and thanks for using EasyDeL.

No, you can't load quantized models into EasyDeL, but 80% of the LLMs from Hugging Face and PyTorch are supported.

As for using flash attention and the generate function, tell me exactly what you need and I can create an example for you.

Froggy111 commented 3 months ago

I need a flash-attention generate example for Mistral-7B and Mixtral 8x7B. I have been trying with Mistral, but it fails with an error saying block_q=128 has to be smaller than or equal to seq_len_q=1, and I can't figure out why. I am running on Google Cloud TPUs. Again, thanks very much for the help; it is really appreciated.

Froggy111 commented 3 months ago

Also, are there other ways to load quantized models in JAX?

Froggy111 commented 3 months ago

This is my code for reference:

import EasyDel, jax, transformers

tokenizer = transformers.AutoTokenizer.from_pretrained (
    pretrained_model_name_or_path = "mistralai/Mistral-7B-v0.1",
)
tokenizer.pad_token = tokenizer.eos_token
input_ids = tokenizer (
    ["hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello "
    ]* 4,
    return_tensors = "jax",
    pad_to_multiple_of = 128,
    padding = True,
)
print(type(input_ids))
print(input_ids)
attention_mask = input_ids.attention_mask
print(type(attention_mask))
print(attention_mask)
input_ids = input_ids.input_ids
print(type(input_ids))
print(input_ids)

model, params = EasyDel.AutoEasyDelModelForCausalLM.from_pretrained (
    pretrained_model_name_or_path = "mistralai/Mistral-7B-v0.1",
    device = jax.devices('cpu')[0],
    device_map = "auto",
    dtype = jax.numpy.bfloat16,
    param_dtype = jax.numpy.bfloat16,
    precision = jax.lax.Precision("fastest"),
    sharding_axis_dims = (1, -1, 1, 1),
    sharding_axis_names = ("dp", "fsdp", "tp", "sp"),
    backend = "tpu",
    input_shape = (4, 2048),
    config_kwargs = {
        "attn_mechanism": "flash",
    },
)
print(type(model))
print(model)
print(type(params))
print(params.keys())
#print(params)
# transformers.GenerationConfig()
generated_ids = model.generate (
    input_ids = input_ids,
    attention_mask = attention_mask,
    params = {"params": params},
    generation_config = transformers.GenerationConfig (
        max_new_tokens = 1024,
        eos_token_id = tokenizer.eos_token_id,
        pad_token_id = tokenizer.pad_token_id,
        bos_token_id = tokenizer.bos_token_id,
        temperature = 0.7,
        do_sample = True,
        num_beams = 1,
        top_p = 100,
        top_k = 100,
        repetition_penalty = 0.01,
    ),
    max_new_tokens = 1024,
)
print(generated_ids)
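# Flax generate() returns an output object whose token ids are in `.sequences`;
# a batch of sequences is decoded with batch_decode rather than decode.
output = tokenizer.batch_decode (
    generated_ids.sequences,
    skip_special_tokens = True,
    clean_up_tokenization_spaces = True)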
print(output)
erfanzar commented 3 months ago

I know why you're getting an error in the generation process. Give me 5 hours and I'll fix it.

erfanzar commented 3 months ago

Can you try running that code again?

Froggy111 commented 3 months ago

Hi, it still does not work.

Code:

import jax, transformers
from EasyDeL.lib.python import EasyDel
from jax.sharding import PartitionSpec
from typing import Sequence, Optional

tokenizer = transformers.AutoTokenizer.from_pretrained (
    pretrained_model_name_or_path = "mistralai/Mistral-7B-v0.1",
)
tokenizer.pad_token = tokenizer.eos_token
input_ids = tokenizer (
    ["hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello "
    ]* 4,
    return_tensors = "jax",
    pad_to_multiple_of = 128,
    padding = True,
)
print(type(input_ids))
print(input_ids)
attention_mask = input_ids.attention_mask
print(type(attention_mask))
print(attention_mask)
input_ids = input_ids.input_ids
print(type(input_ids))
print(input_ids)

def load_model(
        pretrained_model_name_or_path: str,
        device=jax.devices('cpu')[0],  # Device to be used in order to Load Model on (Offload device)
        dtype: jax.numpy.dtype = jax.numpy.float32,
        param_dtype: jax.numpy.dtype = jax.numpy.float32,
        precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"),
        sharding_axis_dims: Sequence[int] = (1, -1, 1, 1),
        sharding_axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"),
        query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
        generation_query_partition_spec = PartitionSpec(("dp", "fsdp"), "tp", None, None),
        key_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
        value_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
        bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
        attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
        use_shard_map: bool = False,
        input_shape: Sequence[int] = (1, 1),
        backend: Optional[str] = None,
        config_kwargs: dict = None,
):
    model, params = EasyDel.AutoEasyDelModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        device=device,
        dtype=dtype,
        param_dtype=param_dtype,
        precision=precision,
        sharding_axis_names=sharding_axis_names,
        sharding_axis_dims=sharding_axis_dims,
        query_partition_spec=query_partition_spec,
        generation_query_partition_spec=generation_query_partition_spec,
        key_partition_spec=key_partition_spec,
        value_partition_spec=value_partition_spec,
        bias_partition_spec=bias_partition_spec,
        attention_partition_spec=attention_partition_spec,
        use_shard_map=use_shard_map,
        input_shape=input_shape,
        backend=backend,
        config_kwargs=config_kwargs
    )
    return model, params

model, params = load_model (
    pretrained_model_name_or_path = "mistralai/Mistral-7B-v0.1",
    dtype = jax.numpy.bfloat16,
    param_dtype = jax.numpy.bfloat16,
    precision = jax.lax.Precision("fastest"),
    input_shape = (4, 2048),
    config_kwargs = {
        "attn_mechanism": "flash",
    },
    backend = "tpu"
)

# model, params = EasyDel.AutoEasyDelModelForCausalLM.from_pretrained (
#     pretrained_model_name_or_path = "mistralai/Mistral-7B-v0.1",
#     device = jax.devices('cpu')[0],
#     device_map = "auto",
#     dtype = jax.numpy.bfloat16,
#     param_dtype = jax.numpy.bfloat16,
#     precision = jax.lax.Precision("fastest"),
#     sharding_axis_dims = (1, -1, 1, 1),
#     #sharding_axis_names = ("dp", "fsdp", "tp", "sp"),
#     #backend = "tpu",
#     input_shape = (4, 2048),
#     config_kwargs = {
#         "attn_mechanism": "flash",
#     },
# )
print(type(model))
print(model)
print(type(params))
print(params.keys())
#print(params)
# transformers.GenerationConfig()
generated_ids = model.generate (
    input_ids = input_ids,
    attention_mask = attention_mask,
    params = {"params": params},
    generation_config = transformers.GenerationConfig (
        max_new_tokens = 1024,
        eos_token_id = tokenizer.eos_token_id,
        pad_token_id = tokenizer.pad_token_id,
        bos_token_id = tokenizer.bos_token_id,
        temperature = 0.7,
        do_sample = True,
        num_beams = 1,
        top_p = 100,
        top_k = 100,
        repetition_penalty = 0.01,
    ),
    #max_new_tokens = 1024,
)
print(generated_ids)
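# Flax generate() returns an output object whose token ids are in `.sequences`;
# a batch of sequences is decoded with batch_decode rather than decode.
output = tokenizer.batch_decode (
    generated_ids.sequences,
    skip_special_tokens = True,
    clean_up_tokenization_spaces = True)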
print(output)

The error is still the same as before: ValueError: block_q=128 should be smaller or equal to q_seq_len=1
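The error appears to come from the block-size check in the TPU flash-attention kernel. During `model.generate()`, the prompt is first processed in one pass, and after that each decoding step feeds a single new token, so the query sequence length becomes 1 while `block_q` is still 128. The sketch below mimics that check in plain Python, together with the workaround the maintainer describes later in this thread (a query block of 1 while generating). `check_block_q` and `pick_block_q` are hypothetical helpers, not EasyDeL or JAX functions.

```python
def check_block_q(block_q: int, q_seq_len: int) -> None:
    """Mimic the kernel constraint: a query block cannot be longer than the query."""
    if block_q > q_seq_len:
        raise ValueError(
            f"block_q={block_q} should be smaller or equal to q_seq_len={q_seq_len}"
        )


def pick_block_q(q_seq_len: int, default_block_q: int = 128) -> int:
    """Hypothetical helper: use a query block of 1 for single-token decoding steps."""
    is_generating = q_seq_len == 1
    return 1 if is_generating else default_block_q


# Prompt pass: the whole padded prompt (e.g. 512 tokens) is attended at once.
check_block_q(pick_block_q(512), 512)  # passes, block_q stays 128

# Decoding step: one new token, so q_seq_len == 1.
check_block_q(pick_block_q(1), 1)  # passes only because block_q drops to 1
```

Calling `check_block_q(128, 1)` directly reproduces the message from this thread.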

erfanzar commented 3 months ago

Fixed. Can you try again? I can run the code on my side.

Also, change config_kwargs to this:

config_kwargs = {
    "attn_mechanism": "flash",
    "gradient_checkpointing": ""
}
Froggy111 commented 3 months ago

Can you send the code you used? I am still getting the same issue. Thanks.

erfanzar commented 3 months ago

Have you updated EasyDeL? The argument you are getting the error from has been removed.

erfanzar commented 3 months ago
import jax, transformers
import EasyDel
from jax.sharding import PartitionSpec

dev_len = 6

tokenizer = transformers.AutoTokenizer.from_pretrained (
    pretrained_model_name_or_path = "mistralai/Mistral-7B-v0.1",
)
tokenizer.pad_token = tokenizer.eos_token
input_ids = tokenizer (
    ["hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello "
    ]* dev_len,
    return_tensors = "jax",
    max_length = 512,
    padding = "max_length",
)

attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids

model, params = EasyDel.AutoEasyDelModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path = "mistralai/Mistral-7B-v0.1",
    device = jax.devices('cpu')[0],
    device_map = "auto",
    dtype = jax.numpy.bfloat16,
    param_dtype = jax.numpy.bfloat16,
    precision = jax.lax.Precision("fastest"),
    sharding_axis_dims = (1, -1, 1, 1),

    input_shape = (1, 2048),

    query_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
    generation_query_partition_spec=PartitionSpec(("dp", "fsdp"), None, None, "tp"),
    key_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
    value_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
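    # Note: this maps the mesh axis "sp" to two dimensions, which JAX rejects
    # (see the NamedSharding error reported in the next comment).
    bias_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "sp"),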
    attention_partition_spec=PartitionSpec(("dp", "fsdp"), None,"sp", "tp"),

    config_kwargs = {
        "attn_mechanism": "flash",
        "gradient_checkpointing" : ""
    },
)

generated_ids = model.generate (
    input_ids = input_ids,
    attention_mask = attention_mask,
    params = {"params": params},
    generation_config = transformers.GenerationConfig (
        max_new_tokens = 1024,
        max_length = 512,
        eos_token_id = tokenizer.eos_token_id,
        pad_token_id = tokenizer.pad_token_id,
        bos_token_id = tokenizer.bos_token_id,
        temperature = 0.7,
        do_sample = True,
        num_beams = 1,
        top_p = 0.1,
        top_k = 2,
        repetition_penalty = 1.25,
    ),
)
print(generated_ids)
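# Flax generate() returns an output object whose token ids are in `.sequences`;
# a batch of sequences is decoded with batch_decode rather than decode.
output = tokenizer.batch_decode (
    generated_ids.sequences,
    skip_special_tokens = True,
    clean_up_tokenization_spaces = True)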
print(output)
Froggy111 commented 3 months ago

Using jax[tpu]==0.4.22 and the latest commit of EasyDeL on main, on a Google Cloud TPU v4-8 VM: running your code after changing input_shape to (4, 2048) and dev_len to 4, I get the following error:

ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec(('dp', 'fsdp'), None, ('sp',), ('sp',)) has duplicate entries for `sp`

And when I change bias_partition_spec to

bias_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", None),

I get the old error

ValueError: block_q=128 should be smaller or equal to q_seq_len=1
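The first error is independent of the attention kernel. As the message says, JAX lets each mesh axis shard at most one array dimension, and the example's `bias_partition_spec` names `sp` for both of the last two dimensions. A minimal check for this, sketched in plain Python with tuples standing in for `jax.sharding.PartitionSpec` (`duplicate_mesh_axes` is a hypothetical helper):

```python
from typing import Iterable, Optional, Tuple, Union

# A PartitionSpec-like entry: None (not sharded), one mesh axis name, or a tuple of names.
SpecEntry = Optional[Union[str, Tuple[str, ...]]]


def duplicate_mesh_axes(spec: Iterable[SpecEntry]) -> list:
    """Return mesh axis names that are used more than once across the spec."""
    seen, dupes = set(), []
    for entry in spec:
        if entry is None:
            continue
        names = (entry,) if isinstance(entry, str) else tuple(entry)
        for name in names:
            if name in seen and name not in dupes:
                dupes.append(name)
            seen.add(name)
    return dupes


original = (("dp", "fsdp"), None, "sp", "sp")  # bias spec from the example
fixed = (("dp", "fsdp"), None, "sp", None)     # the change tried above
print(duplicate_mesh_axes(original))  # ['sp']
print(duplicate_mesh_axes(fixed))     # []
```

Replacing the second `sp` with `None` clears the sharding error, which is why the original `block_q` error shows up again: that one is a separate problem.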

If I keep input_shape at (1, 2048) instead of changing it to (4, 2048), I get this error:

ValueError: shard_map applied to the function 'functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 4, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

* args[0] of shape float32[1,32,2048,128], where args[0] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)'s parameter 'q', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', 'tp'), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1

* args[1] of shape float32[1,32,2048,128], where args[1] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)'s parameter 'k', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', 'tp'), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1

* args[2] of shape float32[1,32,2048,128], where args[2] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)'s parameter 'v', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', 'tp'), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1

* args[3] of shape bfloat16[1,32,2048,2048], where args[3] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)'s parameter 'ab', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', 'sp'), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)' appropriately.
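This error is about batch size rather than attention. The q, k, v and bias arrays keep a batch dimension of 1 (from `input_shape=(1, 2048)`), while the in_specs shard that dimension over `('dp', 'fsdp')`, which spans 1 × 4 = 4 devices on this mesh. Every sharded dimension must be divisible by the product of the mesh axis sizes it maps to, which is why `input_shape=(4, 2048)` gets past this check. A plain-Python sketch of the rule; `indivisible_axes` is a hypothetical helper, and the shapes and mesh are taken from the log above:

```python
from math import prod
from typing import Dict, Optional, Sequence, Tuple, Union

SpecEntry = Optional[Union[str, Tuple[str, ...]]]


def indivisible_axes(shape: Sequence[int], spec: Sequence[SpecEntry],
                     mesh: Dict[str, int]) -> list:
    """List (array_axis, axis_size, mesh_size) where mesh_size does not divide axis_size."""
    problems = []
    for axis, (size, entry) in enumerate(zip(shape, spec)):
        if entry is None:
            continue  # unsharded dimension, no constraint
        names = (entry,) if isinstance(entry, str) else tuple(entry)
        mesh_size = prod(mesh[name] for name in names)
        if size % mesh_size != 0:
            problems.append((axis, size, mesh_size))
    return problems


# Mesh from the error: shape (1, 4, 1, 1) with names ('dp', 'fsdp', 'tp', 'sp').
mesh = {"dp": 1, "fsdp": 4, "tp": 1, "sp": 1}
q_spec = (("dp", "fsdp"), None, "sp", "tp")

print(indivisible_axes((1, 32, 2048, 128), q_spec, mesh))  # [(0, 1, 4)]
print(indivisible_axes((4, 32, 2048, 128), q_spec, mesh))  # []
```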
erfanzar commented 3 months ago

I think the issue is fixed now. The problem was in detecting the generation step, which is what switches the block sizes; the check changed from

is_generating = query_states.shape[1] == 1
query_sequence_partition = self.generation_query_partition_spec if is_generating else self.query_partition_spec
bias_partition_spec = self.generation_bias_partition_spec if is_generating else self.bias_partition_spec
block_q = 1 if is_generating else self.block_q
block_q_major_dkv = 1 if is_generating else self.block_q_major_dkv
block_q_dkv = 1 if is_generating else self.block_q_dkv
block_q_dq = 1 if is_generating else self.block_q_dq

to

is_generating = query_states.shape[2] == 1
query_sequence_partition = self.generation_query_partition_spec if is_generating else self.query_partition_spec
bias_partition_spec = self.generation_bias_partition_spec if is_generating else self.bias_partition_spec
block_q = 1 if is_generating else self.block_q
block_q_major_dkv = 1 if is_generating else self.block_q_major_dkv
block_q_dkv = 1 if is_generating else self.block_q_dkv
block_q_dq = 1 if is_generating else self.block_q_dq
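Why the index matters: judging by the shapes in the earlier error log (`float32[1,32,2048,128]` for q, k and v), `query_states` is laid out as (batch, num_heads, seq_len, head_dim) at this point. Axis 1 is then the number of heads (32 in that log), so the old check never recognized a decoding step and `block_q` stayed at 128 while the query length was 1. A small illustration under that layout assumption; the two functions are hypothetical stand-ins for the condition in the snippet above:

```python
from typing import Sequence


def is_generating_old(query_shape: Sequence[int]) -> bool:
    # Original check: axis 1, which holds num_heads in this layout.
    return query_shape[1] == 1


def is_generating_new(query_shape: Sequence[int]) -> bool:
    # Fixed check: axis 2, the query sequence length.
    return query_shape[2] == 1


# Assumed layout: (batch, num_heads, seq_len, head_dim).
prompt_pass = (4, 32, 2048, 128)
decode_step = (4, 32, 1, 128)

print(is_generating_old(decode_step))  # False: 32 heads != 1, so block_q stays 128
print(is_generating_new(decode_step))  # True: block_q can drop to 1
print(is_generating_new(prompt_pass))  # False
```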
Froggy111 commented 3 months ago

Now, I am getting this issue:

RuntimeError: Internal TPU kernel compiler error: Not implemented: Non-trivial layouts unsupported

The MLIR operation involved:
  %130 = "tpu.repeat"(%129) {dimension = 1 : i32, in_layout = [#tpu.vpad<"32,{0,0},(1,128)">], out_layout = [#tpu.vpad<"32,{0,0},(1,128)">], times = 1 : i32} : (vector<1x128xf32>) -> vector<1x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

This happens with jax[tpu]==0.4.22 and 0.4.23; 0.4.24 and above do not work at all, as it seems some APIs were removed.

erfanzar commented 3 months ago

;\ Which TPU version are you using? Can you upgrade JAX to 0.4.25?

Froggy111 commented 3 months ago

With jax-0.4.25, jaxlib-0.4.25 and libtpu-nightly-0.1.dev20240224 I get: AttributeError: 'Config' object has no attribute 'define_bool_state'

erfanzar commented 3 months ago

Is the error coming from EasyDeL or FJFormer? I never defined anything named define_bool_state, nor do I use one.
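The message means some installed code looks up `define_bool_state` on JAX's `Config` object (typically through `jax.config`), and the installed JAX does not provide it. The full traceback normally names the calling file. If it is truncated, one way to find the caller is to search the installed sources for the attribute name. A minimal sketch; `find_references` is a hypothetical helper:

```python
import os
from typing import Iterable, List


def find_references(roots: Iterable[str], needle: str = "define_bool_state") -> List[str]:
    """Return paths of .py files under the given roots whose source contains `needle`."""
    hits = []
    for root in roots:
        for dirpath, _dirnames, filenames in os.walk(root):
            for name in filenames:
                if not name.endswith(".py"):
                    continue  # only Python sources can make the failing call
                path = os.path.join(dirpath, name)
                try:
                    with open(path, encoding="utf-8", errors="ignore") as fh:
                        if needle in fh.read():
                            hits.append(path)
                except OSError:
                    continue  # unreadable file: skip it
    return sorted(hits)
```

For example, `import site` followed by `find_references(site.getsitepackages())` lists every installed module that still mentions `define_bool_state`, which shows whether the call comes from EasyDeL, FJFormer or another dependency.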

erfanzar commented 3 months ago

It seems to be an environment issue that is both related and unrelated to EasyDeL and FJFormer at the same time. Since debugging code like this won't get us anywhere, I recommend you join the JAXLLM Discord server so we can communicate better.
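When chasing an environment issue like this, it helps to report the exact version of every package involved. A small sketch using the standard library's `importlib.metadata`; the distribution names in `PACKAGES` are assumptions based on the packages mentioned in this thread and may need adjusting:

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed distribution names for the packages discussed in this thread.
PACKAGES = ("jax", "jaxlib", "libtpu-nightly", "fjformer", "EasyDeL", "transformers")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for pkg, ver in installed_versions().items():
    print(f"{pkg:>15}: {ver or 'not installed'}")
```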