JetStream Engine implementation in PyTorch
The latest release version is tagged with jetstream-v0.2.3. If you are running the release version, please follow the README of that version here:
https://github.com/google/jetstream-pytorch/blob/jetstream-v0.2.3/README.md
Command-line flags may have changed between the release version and HEAD.
gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
Follow the steps in
git clone https://github.com/google/jetstream-pytorch.git
git checkout jetstream-v0.2.3
(Optional) Create a virtual env using venv or conda, and activate it.
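For example, a minimal venv setup could look like the following (the directory name .venv is just an illustration):
python -m venv .venv
source .venv/bin/activate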
cd jetstream-pytorch
source install_everything.sh
Follow the instructions here:
After you have downloaded the weights, a tokenizer.model file will also have been downloaded; this is the tokenizer that we will use.
Please sign the agreement on the Hugging Face website to access the Gemma checkpoints. Download the Gemma PyTorch checkpoint using huggingface-cli. The Gemma tokenizer is included in the checkpoint.
# Install huggingface-cli and login if it's not set up.
pip install -U "huggingface_hub[cli]"
huggingface-cli login
huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
Please sign the agreement on the Hugging Face website to access the Mixtral checkpoints. Download the Mixtral PyTorch checkpoint using huggingface-cli. The Mixtral tokenizer is included in the checkpoint.
huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
There is limited support (only Llama models as of now) for accessing checkpoints on GCS. Accessing GCS takes a long time, so storing checkpoints locally is recommended.
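If you do read a Llama checkpoint straight from GCS, the input checkpoint directory can be pointed at a bucket path; a sketch, assuming gs:// paths are accepted directly (the bucket and path below are hypothetical):
export input_ckpt_dir=gs://your-bucket/path/to/llama-weights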
export input_ckpt_dir=Original llama weights directory
export output_ckpt_dir=The output directory
export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
export quantize_weights=True # Whether to quantize weights
export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}; "int8_per_channel" is the default option if not specified.
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type --quantize_weights=$quantize_weights
Set tokenizer path
export tokenizer_path=tokenizer model file path
# Llama 2 7B
python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
# Llama 2 13B
python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
# Llama 3 8B
python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
# Llama 70B
python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
# Gemma 7B
python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
# Mixtral 8x7B
python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
Here is an example of running the server with the Llama 2 7B config.
python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
Now you can send gRPC requests to it.
Optional flags:
--shard_on_batch=1
This shards the model on the batch dimension, i.e. it runs in data-parallel mode instead of model-parallel mode. It will ignore the sharding config. This is recommended for the Gemma 2B model, because Gemma 2B is small enough to fit on a single TPU chip.
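As an illustration, an interactive run of Gemma 2B with batch sharding might look like the following (the --size=2b value and the other flag values here are assumptions, not taken from the commands above):
python run_interactive.py --model_name=gemma --size=2b --batch_size=64 --max_cache_length=2048 --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --shard_on_batch=1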
--sharding_config=<path>
This uses an alternative sharding config instead of the ones in the default_shardings directory.
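For example, to point the server at your own sharding config file (my_shardings/llama_custom.yaml is a hypothetical path):
python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=my_shardings/llama_custom.yaml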
Below are the steps to run the server with Ray:
Log in to the host 0 VM and start the Ray head with the command below:
ray start --head
Log in to the other host VMs and start a Ray worker on each with the command below:
ray start --address='$ip:$port'
Note: Get the address IP and port information from the Ray head's output.
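For example, if the Ray head reports an address of 10.0.0.1:6379 (an illustrative value), each worker would run:
ray start --address='10.0.0.1:6379'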
Here is an example of running the server with Ray for the Llama 2 7B model:
python run_server_with_ray.py --tpu_chips=16 --model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
Start the server, then go to the deps/JetStream folder (downloaded during install_everything.sh).
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs --warmup-mode=sampled --model=$model_name
Please look at deps/JetStream/benchmarks/README.md for more information.
If running on GKE:
kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-singlehost.yaml
or
kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-multihost.yaml
Single-host (Llama2 7B):
export RAY_ADDRESS=http://localhost:8265
kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &
ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py --tpu_chips=4 --num_hosts=1 --size=7b --model_name=llama-2 --batch_size=32 --max_cache_length=2048 --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
Multi-host (Llama2 70B):
export RAY_ADDRESS=http://localhost:8265
kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &
ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py --tpu_chips=8 --num_hosts=2 --size=70b --model_name=llama-2 --batch_size=8 --max_cache_length=2048 --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
Port-forward to port 8888 for gRPC:
kubectl port-forward svc/example-cluster-kuberay-head-svc 8888:8888 &
Sample python script:
import grpc
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc

prompt = "What are the top 5 languages?"

# Connect to the JetStream server (port-forwarded to localhost:8888 above).
channel = grpc.insecure_channel("localhost:8888")
stub = jetstream_pb2_grpc.OrchestratorStub(channel)

# Build a decode request for the prompt.
request = jetstream_pb2.DecodeRequest(
    text_content=jetstream_pb2.DecodeRequest.TextContent(text=prompt),
    priority=0,
    max_tokens=2000,
)

# Decode returns a stream of responses; collect the generated text pieces.
response = stub.Decode(request)
output = []
for resp in response:
    output.extend(resp.stream_content.samples[0].text)

text_output = "".join(output)
print(f"Prompt: {prompt}")
print(f"Response: {text_output}")