Closed: coolrazor007 closed this issue 11 months ago
Sorry about that; in the last version (last night) I changed JaxServer and PyTorchServer. I'll update the examples soon.
fixed
I pulled down main and ran it again but I'm getting the same error. More context:
root@c392a887d2d0:/app/EasyDeL# python3 -m examples.serving.causal-lm.llama-2-chat --repo_id='mistralai/Mistral-7B-v0.1' --max_length=4096 --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype='fp16' --use_prefix_tokenizer
repo id : mistralai/Mistral-7B-v0.1
contains auto format : False
max length : 4096
max new tokens : 2048
max stream tokens : 32
temperature : 0.6
top p : 0.95
top k : 50
logging : False
mesh axes names : ['dp', 'fsdp', 'mp']
mesh axes shape : [1, -1, 1]
dtype : fp16
use prefix tokenizer : True
You are using a model of type mistral to instantiate a model of type llama. This is not supported for all configurations of models and can yield errors.
You are using a model of type mistral to instantiate a model of type llama. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████| 2/2 [00:15<00:00, 7.86s/it]
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 210, in <module>
    server = Llama2Host.load_from_torch(
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 54, in load_from_torch
    return cls.load_from_params(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/serve/jax_serve.py", line 328, in load_from_params
    server = cls(config=config)
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 38, in __init__
    super().__init__(config=config)
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/serve/jax_serve.py", line 109, in __init__
    assert config is None or isinstance(config, JaxServerConfig), 'config can be None or JaxServerConfig Type'
AssertionError: config can be None or JaxServerConfig Type
Here's the part where you create your config for JaxServer:
configs = JaxServerConfig(
contains_auto_format=args.contains_auto_format,
max_length=args.max_length,
max_new_tokens=args.max_new_tokens,
max_stream_tokens=args.max_stream_tokens,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
logging=args.logging,
mesh_axes_names=args.mesh_axes_names,
mesh_axes_shape=args.mesh_axes_shape,
dtype=args.dtype,
use_prefix_tokenizer=args.use_prefix_tokenizer
)
and this is where you will get an error:
assert config is None or isinstance(config, JaxServerConfig), 'config can be None or JaxServerConfig Type'
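In case it helps narrow this down, my guess (just an assumption on my side) is that the example script and the pip-installed EasyDel may be loading two different copies of the JaxServerConfig class, which would make isinstance fail even though the config looks right. A quick check, assuming the class lives in EasyDel.serve.jax_serve as the traceback suggests:

import EasyDel
from EasyDel.serve.jax_serve import JaxServerConfig

print(EasyDel.__file__)            # which installation actually got imported
print(JaxServerConfig.__module__)  # module defining the class the assert checks
print(isinstance(configs, JaxServerConfig))  # configs from above; False confirms a class mismatch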
I ran the same code as you did and the result was successful (on a Kaggle TPU and in my own workspace).
Can you please provide more information, like your Python version and your jax and jaxlib versions?
I'm currently running on a TPU-VM at GCP, using the Ubuntu 22.04 VM image with an Ubuntu 22.04 Docker container. I'm running your code in the container but have confirmed that jax can see the TPUs with this:

import jax
print(jax.device_count())
# 8
Python 3.10.12. Dunno if this matters, but "python3 ..." works whereas "python ..." does not. I could make a symlink or an alias if need be.
Here is my pip freeze (probably too much context): absl-py==2.0.0 aiofiles==23.2.1 aiohttp==3.8.6 aiosignal==1.3.1 altair==5.1.2 annotated-types==0.6.0 anyio==3.7.1 appdirs==1.4.4 asttokens==2.4.1 async-timeout==4.0.3 attrs==23.1.0 cachetools==5.3.2 certifi==2023.7.22 charset-normalizer==3.3.1 chex==0.1.84 click==8.1.7 cloud-tpu-client==0.10 colorama==0.4.6 contextlib2==21.6.0 contourpy==1.1.1 cycler==0.12.1 datasets==2.14.6 decorator==5.1.1 diffusers==0.21.4 dill==0.3.7 docker-pycreds==0.4.0 EasyDeL==0.0.34 einops==0.7.0 etils==1.5.2 exceptiongroup==1.1.3 executing==2.0.1 fastapi==0.104.1 ffmpy==0.3.1 filelock==3.13.1 FJUtils==0.0.21 flax==0.7.5 fonttools==4.43.1 frozenlist==1.4.0 fsspec==2023.10.0 ftfy==6.1.1 gitdb==4.0.11 GitPython==3.1.40 google-api-core==1.34.0 google-api-python-client==1.8.0 google-auth==2.23.4 google-auth-httplib2==0.1.1 google-auth-oauthlib==1.1.0 googleapis-common-protos==1.61.0 gradio==4.0.2 gradio_client==0.7.0 grpcio==1.59.2 h11==0.14.0 httpcore==0.18.0 httplib2==0.22.0 httpx==0.25.0 huggingface-hub==0.17.3 idna==3.4 importlib-metadata==6.8.0 importlib-resources==6.1.0 ipython==8.17.2 jax==0.4.19 jaxlib==0.4.19 jedi==0.19.1 Jinja2==3.1.2 jsonschema==4.19.2 jsonschema-specifications==2023.7.1 kiwisolver==1.4.5 libtpu-nightly==0.1.dev20231018 Markdown==3.5.1 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.8.1 matplotlib-inline==0.1.6 mdurl==0.1.2 ml-collections==0.1.1 ml-dtypes==0.3.1 mpmath==1.3.0 msgpack==1.0.7 multidict==6.0.4 multiprocess==0.70.15 nest-asyncio==1.5.8 networkx==3.2.1 numpy==1.26.1 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.3.52 nvidia-nvtx-cu12==12.1.105 oauth2client==4.1.3 oauthlib==3.2.2 opt-einsum==3.3.0 optax==0.1.7 orbax-checkpoint==0.4.1 orjson==3.9.10 packaging==23.2 pandas==2.1.2 parso==0.8.3 pathtools==0.1.2 pexpect==4.8.0 Pillow==10.1.0 prompt-toolkit==3.0.39 protobuf==3.20.3 psutil==5.9.6 ptyprocess==0.7.0 pure-eval==0.2.2 pyarrow==14.0.0 pyasn1==0.5.0 pyasn1-modules==0.3.0 pydantic==2.4.2 pydantic_core==2.10.1 pydub==0.25.1 Pygments==2.16.1 pyparsing==3.1.1 python-dateutil==2.8.2 python-multipart==0.0.6 pytz==2023.3.post1 PyYAML==6.0.1 referencing==0.30.2 regex==2023.10.3 requests==2.31.0 requests-oauthlib==1.3.1 rich==13.6.0 rpds-py==0.10.6 rsa==4.9 safetensors==0.4.0 scipy==1.11.3 semantic-version==2.10.0 sentry-sdk==1.33.1 setproctitle==1.3.3 shellingham==1.5.4 six==1.16.0 smmap==5.0.1 sniffio==1.3.0 stack-data==0.6.3 starlette==0.27.0 sympy==1.12 tensorboard==2.15.0 tensorboard-data-server==0.7.2 tensorstore==0.1.46 tokenizers==0.14.1 tomlkit==0.12.0 toolz==0.12.0 torch==2.1.0 torch-xla==2.1.0 torchvision==0.16.0 tqdm==4.66.1 traitlets==5.13.0 transformers==4.34.1 triton==2.1.0 typer==0.9.0 typing==3.7.4.3 typing_extensions==4.8.0 tzdata==2023.3 uritemplate==3.0.1 urllib3==2.0.7 uvicorn==0.23.2 wandb==0.15.12 wcwidth==0.2.9 websockets==11.0.3 Werkzeug==3.0.1 xxhash==3.4.1 yarl==1.9.2 zipp==3.17.0
Do you want me to give you a Docker image to run, or should I edit the code for you? Because no matter how I try, I can't get any error ;\
So on a fresh environment what's the process exactly? This is what I'm trying now:
git clone [EasyDeL]
cd easydel
./install.sh
(when I tried "python -m examples.serving.causal-lm.llama-2-chat" it complained about missing packages)
pip install -r requirements.txt
python -m examples.serving.causal-lm.llama-2-chat --repo_id='mistralai/Mistral-7B-v0.1' --max_length=4096 --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype='fp16' --use_prefix_tokenizer
When it loads I click the gradio.live link and try a message. It errors. Here's the error on the console:
This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/gradio/routes.py", line 488, in run_predict
output = await app.get_blocks().process_api(
File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1431, in process_api
result = await self.call_function(
File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1103, in call_function
prediction = await anyio.to_thread.run_sync(
File "/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py", line 33, in run_sync
return await get_asynclib().run_sync_in_worker_thread(
File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
return await future
File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 807, in run
result = context.run(func, *args)
File "/usr/local/lib/python3.10/dist-packages/gradio/utils.py", line 707, in wrapper
response = f(*args, **kwargs)
TypeError: Llama2Host.process_gradio_chat() takes from 5 to 6 positional arguments but 7 were given
fixed
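For anyone else hitting this: the TypeError means Gradio handed the callback one more positional input than the method accepts. A hedged sketch of the general workaround (the parameter names here are made up, not the real signature) is to absorb the extras:

# Hypothetical signature: accept and ignore extra positional inputs that a
# newer Gradio version may pass to the chat callback.
def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, *extra):
    del extra  # not used by this handler
    ...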
I keep trying and it just won't work for me. If you have a Dockerfile handy, that may help.
Here's the basic format I keep following:
git clone [EasyDeL]
cd easydel
pip install easydel
./install.sh
ERROR:
python -m examples.serving.causal-lm.llama-2-chat --repo_id='mistralai/Mistral-7B-v0.1' --max_length=4096 --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype='fp16' --use_prefix_tokenizer
Traceback (most recent call last):
File "/usr/local/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/local/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 3, in <module>
import EasyDel
File "/usr/local/lib/python3.9/site-packages/EasyDel/__init__.py", line 1, in <module>
from .utils import make_shard_and_gather_fns, get_mesh
File "/usr/local/lib/python3.9/site-packages/EasyDel/utils/__init__.py", line 2, in <module>
from .utils import get_mesh, Timers, Timer, prefix_str, prefix_print, names_in_mesh, with_sharding_constraint, \
File "/usr/local/lib/python3.9/site-packages/EasyDel/utils/utils.py", line 1, in <module>
import jax
File "/usr/local/lib/python3.9/site-packages/jax/__init__.py", line 39, in <module>
from jax import config as _config_module
File "/usr/local/lib/python3.9/site-packages/jax/config.py", line 15, in <module>
from jax._src.config import config as _deprecated_config # noqa: F401
File "/usr/local/lib/python3.9/site-packages/jax/_src/config.py", line 28, in <module>
from jax._src import lib
File "/usr/local/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 74, in <module>
version = check_jaxlib_version(
File "/usr/local/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 63, in check_jaxlib_version
raise RuntimeError(msg)
RuntimeError: jaxlib is version 0.4.10, but this version of jax requires version >= 0.4.14.
So I do this:
pip install --force-reinstall -v "jaxlib>=0.4.14" "jax>=0.4.14"
Which shows this:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
easydel 0.0.35 requires datasets==2.14.3, but you have datasets 2.14.6 which is incompatible.
Successfully installed importlib-metadata-6.8.0 jax-0.4.20 jaxlib-0.4.20 ml-dtypes-0.3.1 numpy-1.26.1 opt-einsum-3.3.0 scipy-1.11.3 zipp-3.17.0
Then I get this:
root@v2-8-tpu-vm-4:/app/EasyDeL# python -m examples.serving.causal-lm.llama-2-chat --repo_id='mistralai/Mistral-7B-v0.1' --max_length=4096 --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype='fp16' --use_prefix_tokenizer
Traceback (most recent call last):
File "/usr/local/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/local/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 3, in <module>
import EasyDel
File "/usr/local/lib/python3.9/site-packages/EasyDel/__init__.py", line 2, in <module>
from .modules import (FlaxLlamaModel, FlaxLlamaForCausalLM, LlamaConfig,
File "/usr/local/lib/python3.9/site-packages/EasyDel/modules/__init__.py", line 1, in <module>
from .llama import FlaxLlamaModel, FlaxLlamaForCausalLM, LlamaConfig
File "/usr/local/lib/python3.9/site-packages/EasyDel/modules/llama/__init__.py", line 1, in <module>
from .modelling_llama_flax import FlaxLlamaModel, FlaxLlamaForCausalLM, LlamaConfig
File "/usr/local/lib/python3.9/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 16, in <module>
from fjformer.attention import efficient_attention
File "/usr/local/lib/python3.9/site-packages/fjformer/__init__.py", line 17, in <module>
from fjformer.datasets import (
File "/usr/local/lib/python3.9/site-packages/fjformer/datasets/__init__.py", line 1, in <module>
from .datasets import get_dataloader
File "/usr/local/lib/python3.9/site-packages/fjformer/datasets/datasets.py", line 4, in <module>
from torch.utils.data import DataLoader
ModuleNotFoundError: No module named 'torch'
To fix that I do this:
pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
Then the serve script will load but it won't see the TPUs:
python -m examples.serving.causal-lm.llama-2-chat --repo_id='mistralai/Mistral-7B-v0.1' --max_length=4096 --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype='fp16' --use_prefix_tokenizer
A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.
Downloading (…)lve/main/config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 571/571 [00:00<00:00, 55.2kB/s]
...
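For what it's worth, here's the quick check I use to see which backend jax actually picked, since the fallback to CPU is otherwise silent (plain jax API, nothing EasyDeL-specific):

import jax

# If jaxlib/libtpu are mismatched, jax silently falls back to CPU;
# these two calls make the fallback visible.
print(jax.default_backend())  # expect 'tpu'
print(jax.devices())          # expect TpuDevice entries, not CpuDevice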
I'd appreciate a Dockerfile/docker-compose.yml or any advice you may have. I've tried other JAX examples, Lightning, llama.cpp, Axolotl, and other approaches to no avail. Your repo has gotten me the closest, I think.
Current Dockerfile:
FROM python:3.9-bullseye
# Set work directory
WORKDIR /app
RUN apt-get update
RUN apt-get install -y nano
RUN pip install --upgrade pip
CMD ["/bin/bash", "-c", "sleep infinity"]
Current docker-compose.yml:
version: '3'
services:
  xla_custom:  # Test out client applications here
    container_name: xla_custom
    build:
      context: .
      dockerfile: Dockerfile
    env_file:
      - .env
    volumes:
      - ./:/app/:rw
      - /nfs_share/models/:/models/
    restart: unless-stopped
    devices:
      - "/dev/tpu_common_0:/dev/tpu_common_0"
      - "/dev/tpu_common_1:/dev/tpu_common_1"
      - "/dev/tpu_common_2:/dev/tpu_common_2"
      - "/dev/tpu_common_3:/dev/tpu_common_3"
      - "/dev/accel0:/dev/accel0"
      - "/dev/accel1:/dev/accel1"
      - "/dev/accel2:/dev/accel2"
      - "/dev/accel3:/dev/accel3"
    network_mode: "host"
    privileged: true
Current .env:
# Disable CUDA for PyTorch and ensure the pre-built wheel works
USE_CUDA="0"
USE_MPI="0"
PJRT_DEVICE="TPU"
# Whether to build for TPUVM mode
TPUVM_MODE=1
BUNDLE_LIBTPU=1
JAX_PLATFORMS=''
You actually don't need a Dockerfile to run the code; it's far simpler than you think. First of all, you are using Llama2Host to launch a Mistral model. For a Mistral model you have to customize the Serve class using the provided docs, or else fine-tune the Mistral model with the Llama prompting method (the Llama 2 Hugging Face prompting style).
But you should still be fine to go; see the sketch below.
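As a rough illustration, customizing the host for Mistral mostly means changing how the prompt is formatted. The hook name below is an assumption based on the llama-2-chat example, not a documented API, so treat it as a sketch:

# Hypothetical sketch: a Mistral-style host. `format_chat` is an assumed
# override point; check the JaxServer example/docs for the real one.
class MistralHost(Llama2Host):
    def format_chat(self, history, prompt):
        # Mistral-instruct style prompting instead of the Llama 2 template
        text = ''
        for user_msg, model_msg in history:
            text += f'[INST] {user_msg} [/INST] {model_msg}</s>'
        text += f'[INST] {prompt} [/INST]'
        return text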
Here's a basic Dockerfile you may need to set up the env:
FROM python:3.11
LABEL authors="Erfan Zare Chavoshi"

WORKDIR /app

RUN apt-get update && apt-get upgrade -y -q
RUN apt-get install golang -y -q

# Pick the accelerator at build time: --build-arg device=tpu for TPU VMs
ARG device
RUN if [ "$device" = "tpu" ]; then \
        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu \
        && pip install "jax[tpu]==0.4.10" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
    else \
        pip install torch torchvision torchaudio \
        && pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html; \
    fi

# Install dependencies (version specifiers are quoted so the shell
# doesn't treat '>' as a redirect)
RUN pip install \
    chex \
    typing \
    "jax>=0.4.10" \
    "jaxlib>=0.4.10" \
    flax \
    "fjformer>=0.0.7" \
    "transformers>=4.33.0" \
    einops \
    optax \
    msgpack \
    ipython \
    tqdm \
    pydantic==2.4.2 \
    datasets==2.14.3 \
    setuptools \
    gradio \
    distrax \
    rlax \
    "wandb>=0.15.9" \
    tensorboard \
    pydantic_core==2.11.0

# Copy your application code
COPY . /app

ENTRYPOINT ["top", "-b"]
# Customize the rest ...
But you can simply run these commands, and then you should be fine to go:
!pip install fjformer==0.0.7 protobuf==3.20.0 gradio==3.40.0 -q
!pip install jax[tpu]==0.4.10 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q
!apt-get update && apt-get upgrade -y -q
!apt-get install golang -y -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
and then clone the repo and run the code from there, or simply just run pip install easydel==0.0.36
or
%cd /kaggle/working
!git clone https://github.com/erfanzar/EasyDeL.git
%cd EasyDeL
!git pull
%cd lib/python
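Either way, here's a quick sanity check that the install actually resolved to the version you expect (standard library only, no EasyDeL-specific API assumed):

from importlib.metadata import version

print(version('EasyDeL'))  # should print 0.0.36 if the pinned install worked
import EasyDel             # and the import itself should now succeed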
This should work fine. And about your .env variables: you actually don't need them; EasyDeL will automatically set the args for these configs. You should only pass the sharding dims for the mesh and the backend type to EasyDeL, like:
backend = 'tpu'
# partitioning is managed with a 3D mesh ('DP', 'FSDP', 'MP');
# for example, to make full use of the FSDP method, use
dims = '1,-1,1'
# this will automatically derive the mesh size for the FSDP method
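If you're curious what that dims string turns into, here's a sketch of the idea with plain jax/numpy (my illustration, not EasyDeL's actual internals; the assumption is that -1 is filled with the remaining device count, like a numpy reshape):

import numpy as np
import jax

# '1,-1,1' -> reshape(1, -1, 1); on a v3-8 TPU this yields shape (1, 8, 1),
# i.e. all 8 devices land on the 'fsdp' axis.
devices = np.array(jax.devices()).reshape(1, -1, 1)
mesh = jax.sharding.Mesh(devices, ('dp', 'fsdp', 'mp'))
print(mesh.shape)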
And you don't need torch_xla[tpu] to run; torch is only used in EasyDel for dataloaders.
I'm very interested in getting this working. I am trying to get Mistral running on TPU VMs at GCP.
Using your example command: python3 -m examples.serving.causal-lm.llama-2-chat --repo_id='mistralai/Mistral-7B-v0.1' --max_length=4096 --max_new_tokens=2048 --max_stream_tokens=32 --temperature=0.6 --top_p=0.95 --top_k=50 --dtype='fp16' --use_prefix_tokenizer
I get this error after it downloads the model and loads the checkpoint shards:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 210, in <module>
    server = Llama2Host.load_from_torch(
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 54, in load_from_torch
    return cls.load_from_params(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/serve/jax_serve.py", line 328, in load_from_params
    server = cls(config=config)
  File "/app/EasyDeL/examples/serving/causal-lm/llama-2-chat.py", line 38, in __init__
    super().__init__(config=config)
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/serve/jax_serve.py", line 109, in __init__
    assert config is None or isinstance(config, JaxServerConfig), 'config can be None or JaxServerConfig Type'
AssertionError: config can be None or JaxServerConfig Type