exo-explore / exo

Run your own AI cluster at home with everyday devices 📱💻 🖥️⌚
GNU General Public License v3.0

Error processing tensor for shard when loading model on Android devices #160

Open artistlu opened 3 weeks ago

artistlu commented 3 weeks ago

I run model inference on several Android devices acting as exo nodes. When the nodes start loading the model, the first node loads it successfully, but the second node then fails with the following error:

loaded weights in 48509.05 ms, 5.85 GB loaded at 0.12 GB/s
[exo TUI panel: Exo Cluster (3 nodes)]
Error processing tensor for shard Shard(model_id='/nasroot/models/Meta-Llama-3-8B/', start_layer=0, end_layer=10, n_layers=32): size mismatched, can't reshape self.shape=(1, 1, 4096, 128256, 4096) -> new_shape=(1, 1, 32, 128)
Traceback (most recent call last):
  File "/root/exo/exo/orchestration/standard_node.py", line 221, in _process_tensor
    result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
  File "/root/exo/exo/inference/tinygrad/inference.py", line 93, in infer_tensor
    output = self.model(input_tensor, start_pos, TEMPERATURE)
  File "/root/exo/exo/inference/tinygrad/models/llama.py", line 188, in __call__
    return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/engine/jit.py", line 150, in __call__
    self.ret = self.fxn(*args, **kwargs)
  File "/root/exo/exo/inference/tinygrad/models/llama.py", line 177, in forward
    h = layer(h, start_pos, freqs_cis, mask)
  File "/root/exo/exo/inference/tinygrad/models/llama.py", line 100, in __call__
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
  File "/root/exo/exo/inference/tinygrad/models/llama.py", line 56, in __call__
    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/tensor.py", line 3123, in _wrapper
    ret = fn(*args, **kwargs)
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/tensor.py", line 779, in reshape
    return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/tensor.py", line 38, in apply
    ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/function.py", line 190, in forward
    return x.reshape(shape)
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/lazy.py", line 214, in reshape
    def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/shape/shapetracker.py", line 111, in reshape
    if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
  File "/root/miniconda3/envs/exo/lib/python3.12/site-packages/tinygrad/shape/view.py", line 277, in reshape
    raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
ValueError: size mismatched, can't reshape self.shape=(1, 1, 4096, 128256, 4096) -> new_shape=(1, 1, 32, 128)
SendTensor tensor shard=Shard(model_id='/nasroot/models/Meta-Llama-3-8B/', start_layer=8, end_layer=16, n_layers=32) tensor=array([[[ 0.000971,  0.1025  , -0.01109 , ...,  0.3887  ,  0.2261  ,
         -0.02382 ],
        [ 0.0868  ,  0.01169 , -0.06076 , ..., -0.213   , -0.167   ,
          0.01929 ],
        [ 0.0471  , -0.05304 ,  0.01163 , ..., -0.0695  , -0.0814  ,
          0.007477],
        ...,
        [-0.01105 , -0.01181 , -0.02911 , ...,  0.0731  ,  0.01648 ,
         -0.02768 ],
        [ 0.03265 ,  0.0093  ,  0.011566, ..., -0.0092  , -0.01581 ,
         -0.04117 ],
        [ 0.0252  , -0.0665  ,  0.02354 , ...,  0.05563 , -0.006683,
         -0.03708 ]]], dtype=float16) request_id='dd039cd4-721f-4597-8028-b221c08209bf' result: None
Broadcasting opaque status: request_id='dd039cd4-721f-4597-8028-b221c08209bf' status='{"type": "node_status", "node_id": "59ae1de7-3254-4340-9401-f12b2b04a6d6", "status": "start_process_tensor", "base_shard": {"model_id": "/nasroot/models/Meta-Llama-3-8B/", "start_layer": 8, "end_layer": 16, "n_layers": 32}, "shard": {"model_id": "/nasroot/models/Meta-Llama-3-8B/", "start_layer": 0, "end_layer": 10, "n_layers": 32}, "tensor_size": 122880, "tensor_shape": [1, 30, 4096], "request_id": "dd039cd4-721f-4597-8028-b221c08209bf", "inference_state": "{\\"start_pos\\": 0, \\"n_captured_toks\\": 30}"}'
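Two details in the log above can be checked by hand. The ValueError is tinygrad refusing a reshape whose element counts differ. new_shape=(1, 1, 32, 128) is the per-head query view built at llama.py line 56 (self.n_heads=32 and self.head_dim=128, which multiply to the 4096 hidden size seen in tensor_shape). By contrast, self.shape=(1, 1, 4096, 128256, 4096) holds vastly more elements. Separately, the status message pairs a base_shard covering layers 8 to 16 with a local shard covering layers 0 to 10.

The sketch below is plain Python and does not import exo or tinygrad. can_reshape and layer_ranges_overlap are hypothetical helpers written only for illustration. Treating end_layer as inclusive is an assumption that the log does not confirm.

```python
import json
from math import prod

# Values copied from the ValueError in the traceback above.
observed_shape = (1, 1, 4096, 128256, 4096)  # self.shape
target_shape = (1, 1, 32, 128)               # (batch, seq_len, n_heads, head_dim)


def can_reshape(old_shape, new_shape):
    """Hypothetical helper: a reshape only reorders elements, so counts must match."""
    return prod(old_shape) == prod(new_shape)


# Status payload copied from the "Broadcasting opaque status" log line.
# inference_state is itself a JSON-encoded string nested inside the JSON message.
STATUS = (
    r'{"type": "node_status", "node_id": "59ae1de7-3254-4340-9401-f12b2b04a6d6", '
    r'"status": "start_process_tensor", '
    r'"base_shard": {"model_id": "/nasroot/models/Meta-Llama-3-8B/", '
    r'"start_layer": 8, "end_layer": 16, "n_layers": 32}, '
    r'"shard": {"model_id": "/nasroot/models/Meta-Llama-3-8B/", '
    r'"start_layer": 0, "end_layer": 10, "n_layers": 32}, '
    r'"tensor_size": 122880, "tensor_shape": [1, 30, 4096], '
    r'"request_id": "dd039cd4-721f-4597-8028-b221c08209bf", '
    r'"inference_state": "{\"start_pos\": 0, \"n_captured_toks\": 30}"}'
)


def layer_ranges_overlap(a, b):
    """Hypothetical helper. ASSUMES end_layer is inclusive: [start, end] ranges
    overlap unless one ends before the other starts."""
    return a["start_layer"] <= b["end_layer"] and b["start_layer"] <= a["end_layer"]


msg = json.loads(STATUS)
state = json.loads(msg["inference_state"])  # second decode for the nested JSON string

incoming = tuple(msg["tensor_shape"])
# tensor_size in the log is the element count of tensor_shape (1 * 30 * 4096).
print("size matches tensor_size:", prod(incoming) == msg["tensor_size"])
# A [1, 30, 4096] hidden state splits cleanly into 32 heads of 128 ...
print("hidden state -> (1, 30, 32, 128):", can_reshape(incoming, (1, 30, 32, 128)))
# ... but the tensor that actually reached the failing reshape cannot.
print("observed -> target:", can_reshape(observed_shape, target_shape))
print("base_shard/shard overlap:", layer_ranges_overlap(msg["base_shard"], msg["shard"]))
print("local shard starts at layer:", msg["shard"]["start_layer"])
print("captured tokens:", state["n_captured_toks"])
```

The first checks show that an ordinary [1, 30, 4096] hidden state reshapes without trouble. That suggests the failing tensor already had an unexpected shape before the reshape. The overlapping ranges, with the local shard starting at layer 0, may be worth comparing against how each node partitions the model. This is a reading of the log, not a confirmed root cause.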

Environment:

Android devices: xiaomi14pro (Termux, root-distro)
Model: Meta-Llama-3-8B
exo command: DEBUG=9 SUPPORT_BF16=0 python main.py --node-port=8001

I'm not very familiar with the technical details of tensor processing and model loading. Any guidance on the root cause of this issue, or an explanation of it, along with potential solutions would be greatly appreciated.

Thank you for your help!