I have several Android device nodes that I use for model inference. When the nodes start loading the model, the first node loads it successfully, but the second node fails with the error below:
loaded weights in 48509.05 ms, 5.85 GB loaded at 0.12 GB/s
[exo TUI banner truncated: "Exo Cluster (3 nodes)"]
Error processing tensor for shard Shard(model_id='/nasroot/models/Meta-Llama-3-8B/', start_layer=0, end_layer=10, n_layers=32): size mismatched, can't reshape self.shape=(1, 1, 4096, 128256, 4096) -> new_shape=(1, 1, 32, 128)
Traceback (most recent call last):
  File "/root/exo/exo/orchestration/standard_node.py", line 221, in _process_tensor
    result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
  File "/root/exo/exo/inference/tinygrad/inference.py", line 93, in infer_tensor
    output = self.model(input_tensor, start_pos, TEMPERATURE)
  File "/root/exo/exo/inference/tinygrad/models/llama.py", line 188, in __call__
    return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/engine/jit.py", line 150, in __call__
    self.ret = self.fxn(*args, **kwargs)
  File "/root/exo/exo/inference/tinygrad/models/llama.py", line 177, in forward
    h = layer(h, start_pos, freqs_cis, mask)
  File "/root/exo/exo/inference/tinygrad/models/llama.py", line 100, in __call__
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
  File "/root/exo/exo/inference/tinygrad/models/llama.py", line 56, in __call__
    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/tensor.py", line 3123, in _wrapper
    ret = fn(*args, **kwargs)
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/tensor.py", line 779, in reshape
    return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/tensor.py", line 38, in apply
    ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/function.py", line 190, in forward
    return x.reshape(shape)
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/lazy.py", line 214, in reshape
    def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/shape/shapetracker.py", line 111, in reshape
    if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/shape/view.py", line 277, in reshape
    raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
ValueError: size mismatched, can't reshape self.shape=(1, 1, 4096, 128256, 4096) -> new_shape=(1, 1, 32, 128)
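From the error alone, the arithmetic cannot work out: llama.py line 56 reshapes the query projection to (batch, seq, n_heads, head_dim) = (1, 1, 32, 128), i.e. 32 * 128 = 4096 values per token, but the tensor that arrived is 5-dimensional with 1 * 1 * 4096 * 128256 * 4096 elements. Notably, 128256 is Llama 3's vocabulary size and 4096 its hidden size, so the offending tensor does not look like a (batch, seq, hidden) activation at all. A minimal sketch of that size check, using numpy purely for illustration (the real code path is tinygrad's Tensor.reshape):

# Worked check of the reshape arithmetic behind the failure above.
import numpy as np

n_heads, head_dim = 32, 128          # Llama-3-8B attention config
bsz, seqlen = 1, 1

# What llama.py line 56 expects: n_heads * head_dim = 4096 values per
# token, which reshapes cleanly to (1, 1, 32, 128).
xq_ok = np.zeros((bsz, seqlen, n_heads * head_dim), dtype=np.float16)
xq_ok.reshape(bsz, seqlen, n_heads, head_dim)   # works: 4096 == 32 * 128

# What the node actually received: a 5-D tensor whose element count
# (1 * 1 * 4096 * 128256 * 4096) cannot equal 1 * 1 * 32 * 128, so any
# reshape to (1, 1, 32, 128) must raise, exactly as in the traceback.
bad_shape = (1, 1, 4096, 128256, 4096)
assert np.prod(bad_shape, dtype=np.int64) != bsz * seqlen * n_heads * head_dim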
SendTensor tensor shard=Shard(model_id='/nasroot/models/Meta-Llama-3-8B/', start_layer=8, end_layer=16, n_layers=32) tensor=array([[[ 0.000971,  0.1025  , -0.01109 , ...,  0.3887  ,  0.2261  , -0.02382 ],
  [ 0.0868  ,  0.01169 , -0.06076 , ..., -0.213   , -0.167   ,  0.01929 ],
  [ 0.0471  , -0.05304 ,  0.01163 , ..., -0.0695  , -0.0814  ,  0.007477],
  ...,
  [-0.01105 , -0.01181 , -0.02911 , ...,  0.0731  ,  0.01648 , -0.02768 ],
  [ 0.03265 ,  0.0093  ,  0.011566, ..., -0.0092  , -0.01581 , -0.04117 ],
  [ 0.0252  , -0.0665  ,  0.02354 , ...,  0.05563 , -0.006683, -0.03708 ]]], dtype=float16) request_id='dd039cd4-721f-4597-8028-b221c08209bf' result: None
Broadcasting opaque status: request_id='dd039cd4-721f-4597-8028-b221c08209bf' status='{"type": "node_status", "node_id": "59ae1de7-3254-4340-9401-f12b2b04a6d6", "status": "start_process_tensor", "base_shard": {"model_id": "/nasroot/models/Meta-Llama-3-8B/", "start_layer": 8, "end_layer": 16, "n_layers": 32}, "shard": {"model_id": "/nasroot/models/Meta-Llama-3-8B/", "start_layer": 0, "end_layer": 10, "n_layers": 32}, "tensor_size": 122880, "tensor_shape": [1, 30, 4096], "request_id": "dd039cd4-721f-4597-8028-b221c08209bf", "inference_state": "{\\"start_pos\\": 0, \\"n_captured_toks\\": 30}"}'
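One detail in this status line stands out: base_shard covers layers 8-16 while shard covers layers 0-10, so two nodes appear to claim layers 8-10 at once, which might be how a malformed tensor ends up at the attention block. A small hypothetical sanity check (Shard fields copied from the log; exo's real Shard class may differ, and I am assuming end_layer is inclusive, as the logs suggest):

# Hypothetical overlap check for the layer ranges announced by the nodes.
from dataclasses import dataclass

@dataclass
class Shard:
    model_id: str
    start_layer: int
    end_layer: int
    n_layers: int

def check_partition(shards: list[Shard]) -> list[str]:
    """Return overlap problems in a proposed layer assignment."""
    problems = []
    ordered = sorted(shards, key=lambda s: s.start_layer)
    for a, b in zip(ordered, ordered[1:]):
        if b.start_layer <= a.end_layer:  # inclusive ranges overlap
            problems.append(f"layers {b.start_layer}-{a.end_layer} assigned twice")
    return problems

shards = [
    Shard("/nasroot/models/Meta-Llama-3-8B/", 0, 10, 32),
    Shard("/nasroot/models/Meta-Llama-3-8B/", 8, 16, 32),
]
print(check_partition(shards))  # -> ['layers 8-10 assigned twice']

Whether this overlap is the cause or just a symptom, it seems worth ruling out first.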
Environment:
Android devices: Xiaomi 14 Pro (Termux, rooted, proot-distro)
Model: Meta-Llama-3-8B
Run exo cmd: DEBUG=9 SUPPORT_BF16=0 python main.py --node-port=8001
I'm not very familiar with the technical details of tensor processing and model loading. Any guidance or explanation on the root cause of this issue and potential solutions would be greatly appreciated.
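In case it helps anyone localize this, a defensive check just before the model call in infer_tensor (the frame at exo/inference/tinygrad/inference.py line 93 in the traceback) might reveal which node sends the malformed tensor. This is only a hypothetical sketch, not exo's actual API:

# Hypothetical debugging aid, not part of exo: reject anything that is not a
# (batch, seq, hidden) activation before it reaches the model, so the sender
# of the malformed 5-D tensor shows up immediately in the logs.
def validate_activation(tensor, hidden_size=4096):
    if tensor.ndim != 3 or tensor.shape[-1] != hidden_size:
        raise ValueError(
            f"expected a (batch, seq, {hidden_size}) activation, got shape {tuple(tensor.shape)}; "
            "check the sending node's shard assignment"
        )

# e.g. called right before: output = self.model(input_tensor, start_pos, TEMPERATURE)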
Thank you for your help!