erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Running every model gives the error "Shapes must be 1D sequences of concrete values of integer type" #76

Closed: jchauhan closed this issue 5 months ago

jchauhan commented 5 months ago

To Reproduce: running the serving example with any model gives the following error:

python -m examples.serving.causal-lm.llama-2-chat   --pretrained_model_name_or_path="mediocredev/open-llama-3b-v2-chat" --max_length=4096   --max_new_tokens=2048 --max_compile_tokens=32 --temperature=0.6   --top_p=0.95 --top_k=50   --dtype="fp32" --use_prefix_tokenizer
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.21it/s]
JAXServerConfig(host='0.0.0.0', port=2059, batch_size=1, contains_auto_format=False, max_length=4096, max_new_tokens=2048, max_compile_tokens=32, temperature=0.6, top_p=0.95, top_k=50, logging=False, mesh_axes_names=['dp', 'fsdp', 'tp', 'sp'], mesh_axes_shape=[(1, -1, 1, 1)], generation_ps=PartitionSpec('dp', 'fsdp'), dtype='fp32', stream_tokens_for_gradio=True, use_prefix_tokenizer=True, pre_compile=True)
Traceback (most recent call last):
  File "/home/***/.pyenv/versions/3.10.0/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/***/.pyenv/versions/3.10.0/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/***/research/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 195, in <module>
    server = Llama2Host.load_from_torch(
  File "/home/***/research/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 66, in load_from_torch
    return cls.load_from_params(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 476, in load_from_params
    server = cls(config=config)
  File "/home/***/research/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 40, in __init__
    super().__init__(config=config)
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py", line 123, in __init__
    array = jnp.ones((len(jax.devices()), 1)).reshape(self.config.mesh_axes_shape)
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 155, in _reshape
    newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 123, in _compute_newshape
    newshape = core.canonicalize_shape(newshape)  # type: ignore[arg-type]
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 2130, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [(1, -1, 1, 1)].
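The printed config shows mesh_axes_shape=[(1, -1, 1, 1)], which is a list wrapping a single tuple. At jax_serve.py line 123 that value goes straight into reshape(), and JAX's canonicalize_shape accepts only a flat sequence of integers, so it rejects the nested form. The sketch below is a minimal illustration of that diagnosis and is not the maintainer's actual fix, which this thread does not show. The function normalize_mesh_shape is hypothetical and not part of the EasyDeL API. It unwraps a single nested shape, resolves the one -1 axis against a given device count, and returns a flat tuple of ints that reshape would accept.

```python
from typing import Sequence, Tuple, Union

# A mesh shape given either flat, e.g. (1, -1, 1, 1), or wrapped, e.g. [(1, -1, 1, 1)].
ShapeLike = Union[Sequence[int], Sequence[Sequence[int]]]


def normalize_mesh_shape(mesh_axes_shape: ShapeLike, num_devices: int) -> Tuple[int, ...]:
    """Hypothetical helper: return a flat tuple of ints with -1 resolved."""
    shape = mesh_axes_shape
    # Unwrap a single nested sequence such as [(1, -1, 1, 1)].
    if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
        shape = shape[0]
    # int() raises TypeError if any element is still a nested sequence.
    shape = tuple(int(d) for d in shape)

    if shape.count(-1) > 1:
        raise ValueError(f"at most one axis may be -1, got {shape}")
    if any(d <= 0 and d != -1 for d in shape):
        raise ValueError(f"axis sizes must be positive or -1, got {shape}")

    # Product of the explicitly sized axes.
    known = 1
    for d in shape:
        if d != -1:
            known *= d

    if -1 in shape:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices cannot fill mesh {shape}")
        return tuple(num_devices // known if d == -1 else d for d in shape)

    if known != num_devices:
        raise ValueError(f"mesh {shape} needs {known} devices, have {num_devices}")
    return shape
```

With 8 devices, normalize_mesh_shape([(1, -1, 1, 1)], 8) yields (1, 8, 1, 1), which could be passed to jnp.ones((len(jax.devices()), 1)).reshape(...). Resolving -1 explicitly is optional, since reshape also accepts a single -1, but doing so catches device-count mismatches with a clearer message than the reshape error.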
erfanzar commented 5 months ago

Try again; it's fixed for the example.