erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

AMD Hardware Support #88

Closed ThePerfectComputer closed 5 months ago

ThePerfectComputer commented 5 months ago

Not exactly a bug, but I'm about to try to get EasyDel to work on some AMD GPU servers I've got, and might need some help. Would it be possible to pay for support to get EasyDel working on these servers?

ThePerfectComputer commented 5 months ago

I can provide ssh access.

clintg6 commented 5 months ago

I've got EasyDel installed on an AMD GPU cluster. You need to use the latest ROCm JAX Docker builds (ROCm 6, JAX 0.4.23) from https://hub.docker.com/r/rocm/jax-build/ to get it running. Heads up, though: multi-GPU support doesn't work for most models. I tested Mixtral, Falcon, and MPT, and all of them failed. So far EasyDel has only worked for Llama2 inference.
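Before testing multi-GPU behaviour inside such a container, it can help to confirm what JAX itself sees. The following is a minimal sketch using only standard JAX calls; it is not part of EasyDel, and the backend string that gets printed depends on the JAX build in use.

```python
import jax

# Every device JAX can see; on a working ROCm build these should be the AMD GPUs.
devices = jax.devices()
# Name of the default backend as reported by JAX (the exact string depends on the build).
backend = jax.default_backend()

print(f"backend: {backend}")
print(f"device count: {len(devices)}")
for d in devices:
    print(f"  id={d.id} platform={d.platform} kind={d.device_kind}")

if len(devices) < 2:
    print("Only one device is visible, so multi-GPU sharding cannot be exercised here.")
```

If only a CPU device is listed, the problem is in the JAX/ROCm installation rather than in EasyDel.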

erfanzar commented 5 months ago

Falcon models run fine, but there are some problems with the new flash attention and splash attention. Those problems will be fixed within the next 24 hours.

erfanzar commented 5 months ago

I've got EasyDel installed on an AMD GPU cluster. You need to use the latest ROCm JAX Docker builds (ROCm 6, JAX 0.4.23) from https://hub.docker.com/r/rocm/jax-build/ to get it running. Heads up, though: multi-GPU support doesn't work for most models. I tested Mixtral, Falcon, and MPT, and all of them failed. So far EasyDel has only worked for Llama2 inference.

Thank you for telling me that Falcon models didn't work. I forgot to check their attention mechanism after updating the project structure.

Falcon models now work, with an error of only -7.1724294e-08.

Mixtral models are also working fine as far as I have checked. You can give me the repo IDs for the MPT and Mixtral models so I can check those too.
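How the error figure above was computed is not stated in the thread. One plausible reading is a signed mean difference between the outputs of a reference implementation and a ported one, which can come out negative. A minimal sketch of such a comparison follows; `compare_outputs` is a hypothetical helper and the arrays are synthetic stand-ins for model logits, not real Falcon outputs.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-5):
    """Summarise the difference between two arrays of model outputs."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    diff = candidate - reference
    return {
        "mean_error": float(diff.mean()),            # signed, so it may be negative
        "max_abs_error": float(np.abs(diff).max()),  # worst-case element
        "allclose": bool(np.allclose(candidate, reference, atol=atol)),
    }


# Synthetic stand-ins: "reference" plays the PyTorch output, "candidate" the JAX port.
rng = np.random.default_rng(42)
reference_logits = rng.standard_normal((2, 16, 1200)).astype(np.float32)
noise = (rng.standard_normal(reference_logits.shape) * 1e-7).astype(np.float32)
candidate_logits = reference_logits + noise

report = compare_outputs(reference_logits, candidate_logits)
print(report)
```

A mean error close to zero can hide large individual deviations, so the maximum absolute error is worth reporting alongside it.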

erfanzar commented 5 months ago

Sharding Falcon across multiple GPUs (this works for Mixtral too):

# Example of loading a model across multiple devices
import copy

import flax.traverse_util
# Requires the latest EasyDel version (0.0.43) or git+https://github.com/erfanzar/EasyDeL.git

import jax
import torch

try:
    from lib.python.EasyDel import get_modules_by_type
except ModuleNotFoundError:
    import sys
    from pathlib import Path

    cp = str(Path.cwd())
    sys.path.append(cp)
    from lib.python.EasyDel import get_modules_by_type

from fjformer import make_shard_and_gather_fns, match_partition_rules
from transformers import FalconForCausalLM

def main():
    torch.manual_seed(42)
    FalconConfig, FlaxFalconForCausalLM, transform_fn = get_modules_by_type("falcon")
    config = FalconConfig(
        vocab_size=1200,
        hidden_size=256,
        num_attention_heads=8,
        num_hidden_layers=2,
        gradient_checkpointing="",
        alibi=False,
    )

    torch_model = FalconForCausalLM(
        config=copy.deepcopy(config)
    )
    easy_model = FlaxFalconForCausalLM(
        config=config
    )

    partition_specs = match_partition_rules(config.get_partition_rules(True), easy_model.params_shape_tree)
    shard_fns, gather_fns = make_shard_and_gather_fns(
        partition_specs=partition_specs,
        dtype_specs=jax.numpy.float16
    )

    pytorch_dict = torch_model.state_dict()
    with config.jax_mesh():
        params = transform_fn(
            pytorch_dict,
            device=jax.devices("cpu")[0],  # Effectively unused here; passed only so that a key mismatch
            # does not raise a missing-kwarg error (no parameters are actually loaded onto the CPU).
            shard_fns=flax.traverse_util.flatten_dict(shard_fns)
        )
    print("Sharded Successfully")

if __name__ == "__main__":
    main()
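
For readers unfamiliar with what the shard functions in the script above ultimately do, the underlying mechanism can be illustrated with plain JAX. This is a minimal sketch, not EasyDel's or fjformer's implementation: the axis name `"dp"` and the array shape are illustrative assumptions, and the code also runs on a single device (where the one shard is simply the whole array).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D device mesh over every visible device; "dp" is an illustrative axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("dp",))
n = devices.size

# A stand-in weight matrix whose first dimension divides evenly across the devices.
weights = jnp.ones((8 * n, 4), dtype=jnp.float16)

# Split the first dimension over the "dp" axis and replicate the second.
sharding = NamedSharding(mesh, PartitionSpec("dp", None))
sharded = jax.device_put(weights, sharding)

print("global shape:", sharded.shape)
for shard in sharded.addressable_shards:
    print("device", shard.device, "holds a block of shape", shard.data.shape)
```

Each device ends up holding one 8 x 4 block, which is the kind of placement a set of partition rules describes for each parameter of a model.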
ThePerfectComputer commented 5 months ago

Cool. Can you serve up models for continuous batching to support multiple users with EasyDel? Also, which AMD GPUs are you testing on? I'm using AMD Mi50.

erfanzar commented 5 months ago

Yes. Which serving core do you want to use, Torch or JAX?

erfanzar commented 5 months ago

I am currently testing on the RX 580 in my old PC, as I don't have access to a large AMD GPU. However, I am pleased to share that it is working well in every environment.

ThePerfectComputer commented 5 months ago

OK - this is good to know. I will have to try this this week. If you want to make a quick $100, happy to pay you if you get it working on my GPU server. I can give you ssh access.

ThePerfectComputer commented 5 months ago

Mixtral models are also working fine as far as I have checked. You can give me the repo IDs for the MPT and Mixtral models so I can check those too.

What repo id were you referencing?

ThePerfectComputer commented 5 months ago

I'm guessing this script doesn't automatically make use of sharding and flash attention? https://github.com/erfanzar/EasyDeL/blob/main/examples/serving/causal-lm/llama-2-chat.py

ThePerfectComputer commented 5 months ago

Running

python -m examples.serving.causal-lm.llama-2-chat   --pretrained_model_name_or_path="meta-llama/Llama-2-7b-chat-hf" --max_length=4096   --max_new_tokens=2048 --max_compile_tokens=32 --temperature=0.6   --top_p=0.95 --top_k=50   --dtype="fp16" --use_prefix_tokenizer

against fff297dc2e96703bb19a42e8ee86c940549c27c6

fails with:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/ai-data/git/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 145, in <module>
    server = Llama2Host.from_torch_pretrained(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 428, in from_torch_pretrained
    return cls.from_parameters(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 507, in from_parameters
    server.compile(verbose=verbose)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 538, in compile
    for r, a in self.process(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 838, in process
    predicted_token = self.greedy_generate(**inputs_to_gen) if greedy else self.generate(**inputs_to_gen)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 589, in greedy_generate
    return self.greedy_generate_function(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
    ans = call(fun, *args)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 211, in greedy_generate
    predict = model.generate(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 421, in generate
    return self._greedy_search(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 640, in _greedy_search
    state = greedy_search_body_fn(state)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 616, in greedy_search_body_fn
    model_outputs = model(state.running_token, params=params, **state.model_kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 789, in __call__
    outputs = self.module.apply(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1911, in apply
    return apply(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/core/scope.py", line 1080, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 2572, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1104, in __call__
    outputs = self.model(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1001, in __call__
    outputs = self.layers(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 893, in __call__
    layer_outputs = block(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 558, in __call__
    attn_outputs = self.self_attn(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/transforms.py", line 353, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/core/lift.py", line 270, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py", line 285, in fun_remat
    jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals))
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py", line 375, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/transforms.py", line 345, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 390, in __call__
    attentions = self.attention_performer.__call__(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py", line 147, in __call__
    assert key_states.shape == (
AssertionError: 
query_states, key_states, value_states and bias shapes must be like
query_states Shape : [batch_size, num_attention_heads(32), q_seq_len,  head_dims(128)]
key_states   Shape : [batch_size, num_attention_heads(32), kv_seq_len, head_dims(128)]
value_states Shape : [batch_size, num_attention_heads(32), kv_seq_len, head_dims(128)]
bias         Shape : [batch_size, num_attention_heads(32), q_seq_len,  kv_seq_len]

I'm familiar with transformers and attention, so I could dig into this issue myself and see what's causing it, but I'd much rather pay somebody to dig into this if that's an option.
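
The assertion message in the traceback spells out the layout the attention module expects: heads before sequence length, i.e. [batch_size, num_attention_heads, seq_len, head_dims]. The thread does not say which tensor violated it. A minimal sketch of an equivalent check that reports the offending tensor is shown below; `check_attention_shapes` is a hypothetical helper, not EasyDel code, and the sequence-major key is only one example of a layout that would fail.

```python
import numpy as np


def check_attention_shapes(query, key, value, bias, *, batch, q_len, kv_len, num_heads, head_dim):
    """Return {name: (actual_shape, expected_shape)} for every tensor that does not match."""
    expected = {
        "query_states": (batch, num_heads, q_len, head_dim),
        "key_states": (batch, num_heads, kv_len, head_dim),
        "value_states": (batch, num_heads, kv_len, head_dim),
        "bias": (batch, num_heads, q_len, kv_len),
    }
    actual = {
        "query_states": query.shape,
        "key_states": key.shape,
        "value_states": value.shape,
        "bias": bias.shape,
    }
    return {name: (actual[name], expected[name]) for name in expected if actual[name] != expected[name]}


dims = dict(batch=1, q_len=4, kv_len=6, num_heads=32, head_dim=128)
query = np.zeros((1, 32, 4, 128), dtype=np.float16)
key = np.zeros((1, 32, 6, 128), dtype=np.float16)
value = np.zeros((1, 32, 6, 128), dtype=np.float16)
bias = np.zeros((1, 32, 4, 6), dtype=np.float16)

ok = check_attention_shapes(query, key, value, bias, **dims)

# Same key data in [batch, seq, heads, head_dim] order, which should be flagged.
key_seq_major = np.transpose(key, (0, 2, 1, 3))
bad = check_attention_shapes(query, key_seq_major, value, bias, **dims)

print("matching layout:", ok)
print("sequence-major key:", bad)
```

Printing the actual and expected shape per tensor narrows a failure like the one above to a specific input instead of a generic assertion.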

erfanzar commented 5 months ago

That's fine. You can email me at erfanzare810@gmail.com and share information about the project you want to handle.

ThePerfectComputer commented 5 months ago

Just sent you an e-mail from yehowshua@theperfect.computer