erfanzar / EasyDeL

Accelerate and optimize performance with streamlined training and serving options in JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0
209 stars, 24 forks

Error while serving model as per documentation, #80

Closed: jchauhan closed this issue 10 months ago.

jchauhan commented 10 months ago

Describe the bug

Serving a toy model with EasyDeL fails with the following exception:

To Reproduce

Context

```
python sev-tiny.py
Traceback (most recent call last):
  File "/home/***/research/EasyDeL/sev-tiny.py", line 23, in <module>
    params=model.params,
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 271, in params
    raise ValueError(
ValueError: `params` cannot be accessed from model when the model is created with `_do_init=False`. You must call `init_weights` manually and store the params outside of the model and pass it explicitly where needed.
```
erfanzar commented 10 months ago

You are trying to access the params of the frozen model, which are empty because it was created with `_do_init=False`. I'll create an example for you in the next commit.

erfanzar commented 10 months ago

I have tested all of the examples and all of them work fine. Use this as a reference example:


```python
from EasyDel import (
    JAXServer,
    JAXServerConfig
)

server_config = JAXServerConfig(
    top_k=40,
    top_p=0.9,
    temperature=0.8,
    max_length=2048,
    max_new_tokens=4096,
    max_compile_tokens=128
)

model_id = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

# Load the PyTorch checkpoint of `model_id` and convert it for serving with JAX.
server = JAXServer.from_torch_pretrained(
    server_config,
    model_id
)

# `process` streams the generated text; print only the part that is new since the last step.
tkns = 0
for response, tokens_generated in server.process(
    "```python\ndef make_pytorch_linear("
):
    print(response[tkns:], end="")
    tkns = len(response)  # number of characters already printed
```
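The streaming loop assumes that `server.process` yields the full text generated so far together with a token count; that is an assumption about EasyDeL's behavior, not something stated in this thread. Under that assumption, the offset of already-printed characters must be set to the length of the latest response rather than incremented. The library-free sketch below uses a hypothetical `fake_process` generator as a stand-in for `server.process`.

```python
from typing import Iterator, Tuple


def fake_process(prompt: str) -> Iterator[Tuple[str, int]]:
    """Hypothetical stand-in for server.process (prompt is ignored here).

    Yields (full text generated so far, number of tokens generated so far).
    """
    pieces = ["def make_pytorch_linear(", "in_features", ", out_features", "):"]
    text = ""
    for count, piece in enumerate(pieces, start=1):
        text += piece
        yield text, count


def stream_new_text(stream: Iterator[Tuple[str, int]]) -> str:
    """Print only the newly generated suffix of each cumulative response."""
    printed_chars = 0
    shown = []
    for response, _tokens_generated in stream:
        new_text = response[printed_chars:]
        print(new_text, end="", flush=True)
        shown.append(new_text)
        # Set, not +=: each response already contains the earlier text.
        printed_chars = len(response)
    print()
    return "".join(shown)


output = stream_new_text(fake_process("```python\ndef make_pytorch_linear("))
```

With `printed_chars += len(response)` instead, the offset overshoots after the second chunk, so everything generated after it is silently dropped.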