google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Error Running `run_ray_serve_interleave` with Llama3 8B #169

Open ryanaoleary opened 1 month ago

ryanaoleary commented 1 month ago

I'm receiving an error when attempting to run:

ray job submit -- python run_ray_serve_interleave.py  --tpu_chips=4 --num_hosts=1 --size=8B --model_name=llama-3 --batch_size=8 --max_cache_length=2048 --tokenizer_path=$tokenizer_path --checkpoint_path=$output_ckpt_dir --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"

on a single-host v4 TPU with 2x2x1 topology. The error is:

ray.exceptions.ActorDiedError: The actor died because of an error raised in its creation task, ray::PyTorchRayWorker.__init__() (pid=5137, ip=10.168.0.16, actor_id=243ec964a2f41eae1707d84404000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7abe4074add0>)
  File "/home/ray/jetstream-pytorch/jetstream_pt/ray_worker.py", line 200, in __init__
    pt_model = model_exportable.Transformer(args, env)
  File "/home/ray/jetstream-pytorch/jetstream_pt/third_party/llama/model_exportable.py", line 192, in __init__
    self.tok_embeddings = Embedding(
  File "/home/ray/jetstream-pytorch/jetstream_pt/layers.py", line 57, in __init__
    table = torch.ones(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
    result = fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_refs/__init__.py", line 4774, in ones
    size = utils.extract_shape_from_varargs(size)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 854, in extract_shape_from_varargs
    validate_shape(shape)  # type: ignore[arg-type]
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 588, in validate_shape
    validate_dim_length(l)
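
The traceback ends inside torch's shape validation, which suggests the embedding table is being constructed with an invalid dimension, possibly a config value that doesn't line up for a Llama 3 checkpoint. A minimal sketch of that class of failure, where the vocab size and hidden dim are placeholder values and not the actual config fields:

import torch

# Placeholder values: a misconfigured vocab size (assumed here as -1) makes
# the embedding-table shape invalid, and torch's shape validation rejects it.
vocab_size = -1
hidden_dim = 4096
try:
    table = torch.ones((vocab_size, hidden_dim), dtype=torch.bfloat16)
except RuntimeError as e:
    print("shape validation failed:", e)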

This seems to be related to this logic in ray_worker.py that creates the pt_model:

env_data.model_type = "llama-2-" + param_size
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)

Should the model_type for Llama models be hardcoded to llama-2, or should it be derived from the --model_name argument?
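
For illustration only, one alternative would be to build the string from the values already passed on the command line rather than hardcoding the prefix; the helper below is an assumption about how that could look, not existing code in the repo:

def build_model_type(model_name: str, param_size: str) -> str:
    # Illustrative helper (not in jetstream-pytorch): combine the CLI
    # --model_name and --size values instead of hardcoding "llama-2-".
    return f"{model_name}-{param_size}"

print(build_model_type("llama-3", "8B"))   # -> llama-3-8B
print(build_model_type("llama-2", "13B"))  # -> llama-2-13B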