I'm receiving an error when attempting to run:

on a single-host v4 TPU of 2x2x1 topology. The error is:

```
ray.exceptions.ActorDiedError: The actor died because of an error raised in its creation task, ray::PyTorchRayWorker.__init__() (pid=5137, ip=10.168.0.16, actor_id=243ec964a2f41eae1707d84404000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7abe4074add0>)
  File "/home/ray/jetstream-pytorch/jetstream_pt/ray_worker.py", line 200, in __init__
    pt_model = model_exportable.Transformer(args, env)
  File "/home/ray/jetstream-pytorch/jetstream_pt/third_party/llama/model_exportable.py", line 192, in __init__
    self.tok_embeddings = Embedding(
  File "/home/ray/jetstream-pytorch/jetstream_pt/layers.py", line 57, in __init__
    table = torch.ones(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
    result = fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_refs/__init__.py", line 4774, in ones
    size = utils.extract_shape_from_varargs(size)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 854, in extract_shape_from_varargs
    validate_shape(shape)  # type: ignore[arg-type]
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 588, in validate_shape
    validate_dim_length(l)
```
This seems to be related to the logic in `ray_worker.py` that creates the `pt_model`: should the `model_type` for llama models be hardcoded to `llama-2`?
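The traceback stops in PyTorch's `validate_dim_length`, which rejects negative tensor sizes, so one plausible reading is that `Embedding` was asked to allocate its table with a negative dimension, for example a vocabulary size left at a `-1` placeholder because the hardcoded model type did not match the checkpoint being loaded. The pure-Python sketch below illustrates that hypothesis under stated assumptions: `ModelArgs`, `MODEL_ARGS`, `resolve_model_args`, `check_dim_length`, and `embedding_shape` are hypothetical names, not jetstream-pytorch's API, and the table holds only the publicly documented Llama 2 vocabulary size (32,000) as an illustrative value.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ModelArgs:
    # Hypothetical stand-in for a model config. The -1 default mimics a
    # "fill in later" placeholder for the vocabulary size (an assumption).
    dim: int = 4096
    vocab_size: int = -1


# Hypothetical per-model-type overrides; only "llama-2" is known here.
MODEL_ARGS = {
    "llama-2": {"vocab_size": 32000},
}


def resolve_model_args(model_type: str) -> ModelArgs:
    """Build args for model_type; unknown types keep the placeholder defaults."""
    overrides = MODEL_ARGS.get(model_type, {})
    return replace(ModelArgs(), **overrides)


def check_dim_length(length: int) -> None:
    # Same intent as the check at the bottom of the traceback:
    # every tensor dimension must be a non-negative integer.
    if not isinstance(length, int) or length < 0:
        raise ValueError(f"invalid tensor dimension: {length!r}")


def embedding_shape(args: ModelArgs) -> tuple:
    """Shape of the token-embedding table, validated before any allocation."""
    shape = (args.vocab_size, args.dim)
    for length in shape:
        check_dim_length(length)
    return shape


if __name__ == "__main__":
    print(embedding_shape(resolve_model_args("llama-2")))  # (32000, 4096)
    try:
        embedding_shape(resolve_model_args("some-other-model"))
    except ValueError as err:
        print(err)  # invalid tensor dimension: -1
```

Passing the model type in, rather than hardcoding it, lets a mismatch surface as an explicit error about the offending dimension instead of an opaque failure deep inside `torch.ones`.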