ThePerfectComputer closed this issue 5 months ago.
I can provide SSH access.
I've got EasyDel installed on an AMD GPU cluster. You need to use the latest ROCm JAX Docker builds (ROCm 6, JAX 0.4.23): https://hub.docker.com/r/rocm/jax-build/. But heads up: multi-GPU support doesn't work for most models. I tested Mixtral, Falcon, and MPT, and all of them failed. So far EasyDel has only worked for Llama 2 inference.
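Before debugging multi-GPU failures, it can help to confirm that the ROCm JAX build actually sees every GPU. A minimal sketch using only standard JAX calls (`jax.devices()`, `jax.local_device_count()`, `jax.default_backend()`); the helper name `describe_devices` is hypothetical:

```python
import jax


def describe_devices():
    """Summarize the devices JAX can see (hypothetical helper, not part of EasyDel)."""
    devices = jax.devices()
    return {
        # Platform name of the default backend; "cpu" means no accelerator was picked up.
        "backend": jax.default_backend(),
        "local_device_count": jax.local_device_count(),
        "devices": [str(d) for d in devices],
    }


if __name__ == "__main__":
    info = describe_devices()
    print(info)
    if info["local_device_count"] < 2:
        print("Only one device is visible; multi-GPU sharding cannot be exercised.")
```

If the backend reports `cpu` or the count is lower than the number of installed GPUs, the problem is likely in the container or driver setup rather than in the model code.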
Falcon models run fine, but there are some problems with the new flash attention and splash attention; those will be fixed in the next 24 hours.
Thank you for telling me that Falcon models didn't work; I forgot to check their attention mechanism after updating the project structure.
Falcon models now work, with an error of only -7.1724294e-08.
Mixtral models are also working fine as far as I checked. You can give me the repo IDs for the MPT and Mixtral models you used so I can check those too.
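An error figure like this usually comes from running the same input through the PyTorch and Flax models and comparing the logits; a signed value such as -7.17e-08 suggests a mean difference, though that is an assumption about how it was computed. A minimal NumPy sketch of such a check; `compare_logits` is a hypothetical helper, not an EasyDel function:

```python
import numpy as np


def compare_logits(torch_logits, flax_logits):
    """Return the signed mean and the max absolute difference between two logit arrays."""
    a = np.asarray(torch_logits, dtype=np.float64)
    b = np.asarray(flax_logits, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = b - a
    return {
        "mean_error": float(diff.mean()),  # signed, so it can be negative
        "max_abs_error": float(np.abs(diff).max()),
    }
```

With real models, the inputs would be something like `torch_out.logits.detach().numpy()` and `np.asarray(flax_out.logits)`. The max absolute error is the stricter of the two numbers, since a near-zero mean can hide large errors of opposite sign.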
Sharding a base Falcon model across multiple GPUs (this works for Mixtral too):
```python
# Example of loading a model across multiple devices
import copy

import flax.traverse_util
# Needs the latest version (0.0.43) or git+https://github.com/erfanzar/EasyDeL.git
import jax
import torch

try:
    from lib.python.EasyDel import get_modules_by_type
except ModuleNotFoundError:
    # Running from a repo checkout: make the repo root importable
    import sys
    from pathlib import Path

    cp = Path.cwd().__str__()
    sys.path.append(cp)
    from lib.python.EasyDel import get_modules_by_type
from fjformer import make_shard_and_gather_fns, match_partition_rules
from transformers import FalconForCausalLM


def main():
    torch.manual_seed(42)
    FalconConfig, FlaxFalconForCausalLM, transform_fn = get_modules_by_type("falcon")
    config = FalconConfig(
        vocab_size=1200,
        hidden_size=256,
        num_attention_heads=8,
        num_hidden_layers=2,
        gradient_checkpointing="",
        alibi=False,
    )
    torch_model = FalconForCausalLM(
        config=copy.deepcopy(config)
    )
    easy_model = FlaxFalconForCausalLM(
        config=config
    )
    # Map the config's partition rules onto the parameter tree
    partition_specs = match_partition_rules(config.get_partition_rules(True), easy_model.params_shape_tree)
    shard_fns, gather_fns = make_shard_and_gather_fns(
        partition_specs=partition_specs,
        dtype_specs=jax.numpy.float16
    )
    pytorch_dict = torch_model.state_dict()
    with config.jax_mesh():
        params = transform_fn(
            pytorch_dict,
            # Not actually used: passed only so a key mismatch does not raise a missing-kwarg
            # error. No params are loaded onto the CPU.
            device=jax.devices("cpu")[0],
            shard_fns=flax.traverse_util.flatten_dict(shard_fns)
        )
    print("Sharded Successfully")


if __name__ == "__main__":
    main()
```
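As used above, `match_partition_rules` pairs the config's partition rules with the parameter tree, and `make_shard_and_gather_fns` turns the resulting partition specs into per-parameter shard and gather functions. The JAX primitives this kind of sharding typically builds on are `Mesh`, `NamedSharding`, and `PartitionSpec`. A minimal sketch placing a single weight matrix across all visible devices; it is not fjformer's implementation, and the helper `shard_matrix` and the mesh axis name `"mp"` are illustrative choices:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_matrix(x, axis_name="mp"):
    """Place a 2-D array on all devices, splitting its last axis across a 1-D mesh."""
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=(axis_name,))
    n = len(devices)
    if x.shape[-1] % n != 0:
        # This simple spec needs the column count to divide evenly across devices.
        raise ValueError(f"last dim {x.shape[-1]} not divisible by {n} devices")
    sharding = NamedSharding(mesh, P(None, axis_name))  # rows replicated, columns split
    return jax.device_put(x, sharding)


if __name__ == "__main__":
    w = jnp.ones((256, 1024), dtype=jnp.float16)
    w_sharded = shard_matrix(w)
    # Each shard holds 1024 / num_devices columns.
    for shard in w_sharded.addressable_shards:
        print(shard.device, shard.data.shape)
```

On a single device the matrix is placed whole; with 4 GPUs, each shard of the `(256, 1024)` matrix would hold `(256, 256)`.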
Cool. Can you serve models with continuous batching to support multiple users with EasyDel? Also, which AMD GPUs are you testing on? I'm using an AMD MI50.
Yes. Which serving core do you want to use, Torch or JAX?
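Continuous batching means that, instead of waiting for a whole batch to finish, the server admits a new request into a batch slot as soon as an earlier request completes. A toy, framework-agnostic scheduler illustrating the idea; `Request`, `continuous_batching`, and `step_fn` are hypothetical names, not EasyDel's serving API, and `step_fn` stands in for one decode step of a real model:

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class Request:
    rid: int
    max_new_tokens: int
    generated: list = field(default_factory=list)


def continuous_batching(requests, num_slots, step_fn):
    """Toy scheduler: refill free batch slots from the queue before every decode step.

    step_fn(active_requests) must return one new token per active request.
    """
    queue = deque(requests)
    slots = [None] * num_slots
    finished = []
    steps = 0
    while queue or any(s is not None for s in slots):
        # Admit waiting requests into any free slot.
        for i in range(num_slots):
            if slots[i] is None and queue:
                slots[i] = queue.popleft()
        active = [s for s in slots if s is not None]
        tokens = step_fn(active)  # one decode step for the whole batch
        steps += 1
        for req, tok in zip(active, tokens):
            req.generated.append(tok)
        # Free the slots of requests that reached their token budget.
        for i, req in enumerate(slots):
            if req is not None and len(req.generated) >= req.max_new_tokens:
                finished.append(req)
                slots[i] = None
    return finished, steps


if __name__ == "__main__":
    reqs = [Request(0, 3), Request(1, 1), Request(2, 2)]
    done, steps = continuous_batching(reqs, num_slots=2, step_fn=lambda batch: [0] * len(batch))
    print([r.rid for r in done], steps)
```

With two slots, these three requests finish in 3 decode steps, because request 2 takes over request 1's slot after the first step. A static batcher that waits for requests 0 and 1 to finish before starting request 2 would need 5.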
I am currently using my old PC with an RX 580, as I don't have access to a large AMD GPU. However, I am pleased to report that it is working well in every environment.
OK, this is good to know. I will try this this week. If you want to make a quick $100, I'm happy to pay you if you get it working on my GPU server. I can give you SSH access.
What repo id were you referencing?
I'm guessing this script doesn't automatically make use of sharding and flash attention? https://github.com/erfanzar/EasyDeL/blob/main/examples/serving/causal-lm/llama-2-chat.py
Running

```shell
python -m examples.serving.causal-lm.llama-2-chat --pretrained_model_name_or_path="meta-llama/Llama-2-7b-chat-hf" --max_length=4096 --max_new_tokens=2048 --max_compile_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype="fp16" --use_prefix_tokenizer
```

against commit fff297dc2e96703bb19a42e8ee86c940549c27c6 fails with:
```
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/ai-data/git/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 145, in <module>
    server = Llama2Host.from_torch_pretrained(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 428, in from_torch_pretrained
    return cls.from_parameters(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 507, in from_parameters
    server.compile(verbose=verbose)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 538, in compile
    for r, a in self.process(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 838, in process
    predicted_token = self.greedy_generate(**inputs_to_gen) if greedy else self.generate(**inputs_to_gen)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 589, in greedy_generate
    return self.greedy_generate_function(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
    ans = call(fun, *args)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 211, in greedy_generate
    predict = model.generate(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 421, in generate
    return self._greedy_search(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 640, in _greedy_search
    state = greedy_search_body_fn(state)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 616, in greedy_search_body_fn
    model_outputs = model(state.running_token, params=params, **state.model_kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 789, in __call__
    outputs = self.module.apply(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1911, in apply
    return apply(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/core/scope.py", line 1080, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 2572, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1104, in __call__
    outputs = self.model(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1001, in __call__
    outputs = self.layers(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 893, in __call__
    layer_outputs = block(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 558, in __call__
    attn_outputs = self.self_attn(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/transforms.py", line 353, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/core/lift.py", line 270, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py", line 285, in fun_remat
    jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals))
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py", line 375, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/transforms.py", line 345, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 584, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/flax/linen/module.py", line 1101, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 390, in __call__
    attentions = self.attention_performer.__call__(
  File "/home/yehowshua/.envs/ai-dev/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py", line 147, in __call__
    assert key_states.shape == (
AssertionError:
query_states, key_states, value_states and bias shapes must be like
query_states Shape : [batch_size, num_attention_heads(32), q_seq_len, head_dims(128)]
key_states Shape : [batch_size, num_attention_heads(32), kv_seq_len, head_dims(128)]
value_states Shape : [batch_size, num_attention_heads(32), kv_seq_len, head_dims(128)]
bias Shape : [batch_size, num_attention_heads(32), q_seq_len, kv_seq_len]
```
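The assertion spells out the layout the attention module expects: all tensors in `[batch, heads, seq, head_dim]` order, with key and value sharing `kv_seq_len` and the bias shaped `[batch, heads, q_seq_len, kv_seq_len]`. A minimal JAX sketch of scaled dot-product attention in that layout, with the same kind of shape checks, can be a reference when printing the shapes actually passed in just before the failing call. It is a generic implementation, not EasyDel's `easy_attention` code, and the name `attention_bhsd` is hypothetical:

```python
import jax
import jax.numpy as jnp


def attention_bhsd(query, key, value, bias):
    """Scaled dot-product attention in the [batch, heads, seq, head_dim] layout.

    query: [b, h, q_len, d]; key/value: [b, h, kv_len, d]; bias: [b, h, q_len, kv_len].
    """
    b, h, q_len, d = query.shape
    kv_len = key.shape[2]
    # Key and value must share batch, heads, and head_dim with the query.
    if key.shape != (b, h, kv_len, d) or value.shape != key.shape:
        raise ValueError(f"key/value shapes {key.shape}, {value.shape} do not match query {query.shape}")
    if bias.shape != (b, h, q_len, kv_len):
        raise ValueError(f"bias shape {bias.shape} != {(b, h, q_len, kv_len)}")
    scores = jnp.einsum("bhqd,bhkd->bhqk", query, key) * (d ** -0.5)
    weights = jax.nn.softmax(scores + bias, axis=-1)  # normalize over kv positions
    return jnp.einsum("bhqk,bhkd->bhqd", weights, value)


if __name__ == "__main__":
    b, h, d = 1, 32, 128  # Llama-2-7B-like: 32 heads, head_dim 128
    q = jnp.zeros((b, h, 1, d))   # one new token during decoding
    k = jnp.zeros((b, h, 16, d))  # 16 cached positions
    v = jnp.zeros((b, h, 16, d))
    bias = jnp.zeros((b, h, 1, 16))
    print(attention_bhsd(q, k, v, bias).shape)  # (1, 32, 1, 128)
```

During greedy generation the query covers only the new token while key and value cover the cached sequence, so a mismatch in any of the heads, sequence, or head_dim axes is what a check like this would flag.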
I'm familiar with transformers and attention, so I could dig into this issue myself and see what's causing it, but I'd much rather pay somebody to dig into it if that's an option.
That's OK. You can email me at erfanzare810@gmail.com and share information about the project you want me to handle.
Just sent you an e-mail from yehowshua@theperfect.computer
Not exactly a bug, but I'm about to try to get EasyDel working on some AMD GPU servers I've got and might need some help. Would it be possible to pay for support to get EasyDel working on these servers?